Store embeddings
This commit is contained in:
parent
38e7d181f9
commit
0f4b6044c4
2 changed files with 46 additions and 13 deletions
49
track.py
49
track.py
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import cv2
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
import tqdm
|
||||
import motmetrics as mm
|
||||
|
||||
import torch
|
||||
|
@ -38,7 +41,7 @@ def write_results(filename, results, data_type):
|
|||
logger.info('save results to {}'.format(filename))
|
||||
|
||||
|
||||
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
|
||||
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
|
||||
'''
|
||||
Processes the video sequence given and provides the output of tracking result (write the results in video file)
|
||||
|
||||
|
@ -59,7 +62,9 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
|||
The name(path) of the file for storing results.
|
||||
|
||||
save_dir : String
|
||||
Path to the folder for storing the frames containing bounding box information (Result frames).
|
||||
Path to the folder for storing the frames containing bounding box information (Result frames). If given, featuers will be save there as pickle
|
||||
save_figures : bool
|
||||
If set, individual crops of all embedded figures will be saved
|
||||
|
||||
show_image : bool
|
||||
Option for shhowing individial frames during run-time.
|
||||
|
@ -79,15 +84,20 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
|||
tracker = JDETracker(opt, frame_rate=frame_rate)
|
||||
timer = Timer()
|
||||
results = []
|
||||
frame_id = 0
|
||||
for path, img, img0 in dataloader:
|
||||
if frame_id % 20 == 0:
|
||||
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
|
||||
|
||||
frame_id = -1
|
||||
for path, img, img0 in tqdm.tqdm(dataloader):
|
||||
frame_id += 1
|
||||
# if frame_id % 20 == 0:
|
||||
# logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
|
||||
frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
|
||||
if os.path.exists(frame_pickle_fn):
|
||||
continue
|
||||
# run tracking
|
||||
timer.tic()
|
||||
blob = torch.from_numpy(img).cuda().unsqueeze(0)
|
||||
online_targets = tracker.update(blob, img0)
|
||||
# online targets: all tartgets that are not timed out
|
||||
# frame_embeddings: the embeddings of objects visible only in the current frame
|
||||
online_targets, frame_embeddings = tracker.update(blob, img0)
|
||||
online_tlwhs = []
|
||||
online_ids = []
|
||||
for t in online_targets:
|
||||
|
@ -106,10 +116,27 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
|||
if show_image:
|
||||
cv2.imshow('online_im', online_im)
|
||||
if save_dir is not None:
|
||||
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
|
||||
frame_id += 1
|
||||
base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
|
||||
if save_img:
|
||||
cv2.imwrite(base_fn+'.jpg', online_im)
|
||||
if save_figures:
|
||||
for i, fe in enumerate(frame_embeddings):
|
||||
tlwh, curr_feat = fe
|
||||
x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
|
||||
# print(x,y,w,h, tlwh)
|
||||
crop_img = img0[y:y+h, x:x+w]
|
||||
cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
|
||||
with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
|
||||
pickle.dump(fe, fp)
|
||||
|
||||
with open(frame_pickle_fn, 'wb') as fp:
|
||||
pickle.dump(frame_embeddings, fp)
|
||||
|
||||
|
||||
# save results
|
||||
if result_filename is not None:
|
||||
write_results(result_filename, results, data_type)
|
||||
|
||||
return frame_id, timer.average_time, timer.calls
|
||||
|
||||
|
||||
|
@ -131,7 +158,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
|
|||
for seq in seqs:
|
||||
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
|
||||
|
||||
logger.info('start seq: {}'.format(seq))
|
||||
# logger.info('start seq: {}'.format(seq))
|
||||
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
|
||||
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
|
||||
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
|
||||
|
|
|
@ -203,6 +203,8 @@ class JDETracker(object):
|
|||
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
|
||||
removed_stracks = []
|
||||
|
||||
frame_embeddings = []
|
||||
|
||||
t1 = time.time()
|
||||
''' Step 1: Network forward, get detections & embeddings'''
|
||||
with torch.no_grad():
|
||||
|
@ -220,8 +222,12 @@ class JDETracker(object):
|
|||
|
||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
|
||||
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
|
||||
# Surfacing Suspicion: extract features + frame id + bbox
|
||||
frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
|
||||
|
||||
else:
|
||||
detections = []
|
||||
frame_embeddings = []
|
||||
|
||||
t2 = time.time()
|
||||
# print('Forward: {} s'.format(t2-t1))
|
||||
|
@ -346,7 +352,7 @@ class JDETracker(object):
|
|||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
||||
# print('Final {} s'.format(t5-t4))
|
||||
return output_stracks
|
||||
return output_stracks, frame_embeddings
|
||||
|
||||
def joint_stracks(tlista, tlistb):
|
||||
exists = {}
|
||||
|
|
Loading…
Reference in a new issue