diff --git a/track.py b/track.py
index f557c3d..26406f0 100644
--- a/track.py
+++ b/track.py
@@ -1,8 +1,11 @@
 import os
 import os.path as osp
+import pickle
 import cv2
 import logging
 import argparse
+
+import tqdm
 import motmetrics as mm
 import torch
@@ -38,7 +41,7 @@ def write_results(filename, results, data_type):
     logger.info('save results to {}'.format(filename))
 
 
-def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
+def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
     '''
        Processes the video sequence given and provides the output of tracking result (write the results in video file)
@@ -59,7 +62,9 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
     result_filename : String
         The name(path) of the file for storing results.
 
     save_dir : String
-        Path to the folder for storing the frames containing bounding box information (Result frames).
+        Path to the folder for storing the frames containing bounding box information (Result frames). If given, the per-frame embeddings are also saved there as pickles.
+    save_figures : bool
+        If set, a cropped image and a pickled embedding are saved for every individual detection.
 
     show_image : bool
         Option for shhowing individial frames during run-time.
@@ -79,15 +84,20 @@
     tracker = JDETracker(opt, frame_rate=frame_rate)
     timer = Timer()
     results = []
-    frame_id = 0
-    for path, img, img0 in dataloader:
-        if frame_id % 20 == 0:
-            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
-
+    frame_id = -1
+    for path, img, img0 in tqdm.tqdm(dataloader):
+        frame_id += 1
+        # progress is now reported by tqdm, replacing the periodic logger.info() call
+        # skip frames whose embeddings were already saved, so interrupted runs can resume
+        frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl') if save_dir is not None else None
+        if frame_pickle_fn is not None and os.path.exists(frame_pickle_fn):
+            continue
         # run tracking
         timer.tic()
         blob = torch.from_numpy(img).cuda().unsqueeze(0)
-        online_targets = tracker.update(blob, img0)
+        # online_targets: all targets that have not yet timed out
+        # frame_embeddings: the embeddings of the objects detected in the current frame
+        online_targets, frame_embeddings = tracker.update(blob, img0)
         online_tlwhs = []
         online_ids = []
         for t in online_targets:
@@ -106,10 +116,27 @@
         if show_image:
             cv2.imshow('online_im', online_im)
         if save_dir is not None:
-            cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
-        frame_id += 1
+            base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
+            if save_img:
+                cv2.imwrite(base_fn + '.jpg', online_im)
+            if save_figures:
+                for i, fe in enumerate(frame_embeddings):
+                    tlwh, curr_feat = fe
+                    x, y, w, h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
+                    crop_img = img0[y:y+h, x:x+w]
+                    cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
+                    with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
+                        pickle.dump(fe, fp)
+
+            with open(frame_pickle_fn, 'wb') as fp:
+                pickle.dump(frame_embeddings, fp)
+
 
     # save results
-    write_results(result_filename, results, data_type)
+    if result_filename is not None:
+        write_results(result_filename, results, data_type)
+
     return frame_id, timer.average_time, timer.calls
@@ -131,7 +158,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
     for seq in seqs:
         output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
-        logger.info('start seq: {}'.format(seq))
+        # logger.info('start seq: {}'.format(seq))
        dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
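For downstream use, here is a minimal sketch of how the pickles written by eval_seq() could be read back. The helper name load_frame_embeddings is hypothetical; it only assumes the file layout established above (one '<frame_id:05d>.pcl' per frame holding the list of [tlwh, curr_feat] pairs, plus optional per-detection '<frame_id:05d>-<i>.pcl' files).

import glob
import os.path as osp
import pickle

def load_frame_embeddings(save_dir):
    """Yield (frame_id, embeddings) for every per-frame pickle in save_dir.

    'embeddings' is the list of [tlwh, curr_feat] pairs that eval_seq()
    dumped for that frame; per-detection files such as '00012-3.pcl' are
    skipped by filtering out stems that contain a dash.
    """
    for fn in sorted(glob.glob(osp.join(save_dir, '*.pcl'))):
        stem = osp.splitext(osp.basename(fn))[0]
        if '-' in stem:  # per-detection pickle, not a per-frame one
            continue
        with open(fn, 'rb') as fp:
            yield int(stem), pickle.load(fp)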
diff --git a/tracker/multitracker.py b/tracker/multitracker.py
index eae4683..f5f4cfd 100644
--- a/tracker/multitracker.py
+++ b/tracker/multitracker.py
@@ -203,6 +203,8 @@ class JDETracker(object):
         lost_stracks = []  # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
         removed_stracks = []
 
+        frame_embeddings = []
+
         t1 = time.time()
         ''' Step 1: Network forward, get detections & embeddings'''
         with torch.no_grad():
@@ -220,8 +222,12 @@
             detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
                           (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
+
+            # Surfacing Suspicion: collect (bbox, embedding) for every detection in this frame
+            frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
+
         else:
             detections = []
+            frame_embeddings = []
 
         t2 = time.time()
         # print('Forward: {} s'.format(t2-t1))
@@ -346,7 +352,7 @@
         logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
         logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
         # print('Final {} s'.format(t5-t4))
-        return output_stracks
+        return output_stracks, frame_embeddings
 
 def joint_stracks(tlista, tlistb):
     exists = {}
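Note that JDETracker.update() now returns a 2-tuple, so any caller other than eval_seq() must be updated to unpack both values. A minimal illustration, assuming tracker, blob, and img0 are set up as in eval_seq() above:

online_targets, frame_embeddings = tracker.update(blob, img0)
for tlwh, curr_feat in frame_embeddings:
    # tlwh is the detection's bounding box (top-left x, top-left y, width, height);
    # curr_feat is its appearance-embedding vector as a numpy array
    print(tlwh, curr_feat.shape)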