Store embeddings
This commit is contained in:
parent
38e7d181f9
commit
0f4b6044c4
2 changed files with 46 additions and 13 deletions
51
track.py
51
track.py
|
@ -1,8 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
import pickle
|
||||||
import cv2
|
import cv2
|
||||||
import logging
|
import logging
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import tqdm
|
||||||
import motmetrics as mm
|
import motmetrics as mm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -38,7 +41,7 @@ def write_results(filename, results, data_type):
|
||||||
logger.info('save results to {}'.format(filename))
|
logger.info('save results to {}'.format(filename))
|
||||||
|
|
||||||
|
|
||||||
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
|
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
|
||||||
'''
|
'''
|
||||||
Processes the video sequence given and provides the output of tracking result (write the results in video file)
|
Processes the video sequence given and provides the output of tracking result (write the results in video file)
|
||||||
|
|
||||||
|
@ -59,7 +62,9 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
||||||
The name(path) of the file for storing results.
|
The name(path) of the file for storing results.
|
||||||
|
|
||||||
save_dir : String
|
save_dir : String
|
||||||
Path to the folder for storing the frames containing bounding box information (Result frames).
|
Path to the folder for storing the frames containing bounding box information (Result frames). If given, featuers will be save there as pickle
|
||||||
|
save_figures : bool
|
||||||
|
If set, individual crops of all embedded figures will be saved
|
||||||
|
|
||||||
show_image : bool
|
show_image : bool
|
||||||
Option for shhowing individial frames during run-time.
|
Option for shhowing individial frames during run-time.
|
||||||
|
@ -79,15 +84,20 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
||||||
tracker = JDETracker(opt, frame_rate=frame_rate)
|
tracker = JDETracker(opt, frame_rate=frame_rate)
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
results = []
|
results = []
|
||||||
frame_id = 0
|
frame_id = -1
|
||||||
for path, img, img0 in dataloader:
|
for path, img, img0 in tqdm.tqdm(dataloader):
|
||||||
if frame_id % 20 == 0:
|
frame_id += 1
|
||||||
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
|
# if frame_id % 20 == 0:
|
||||||
|
# logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
|
||||||
|
frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
|
||||||
|
if os.path.exists(frame_pickle_fn):
|
||||||
|
continue
|
||||||
# run tracking
|
# run tracking
|
||||||
timer.tic()
|
timer.tic()
|
||||||
blob = torch.from_numpy(img).cuda().unsqueeze(0)
|
blob = torch.from_numpy(img).cuda().unsqueeze(0)
|
||||||
online_targets = tracker.update(blob, img0)
|
# online targets: all tartgets that are not timed out
|
||||||
|
# frame_embeddings: the embeddings of objects visible only in the current frame
|
||||||
|
online_targets, frame_embeddings = tracker.update(blob, img0)
|
||||||
online_tlwhs = []
|
online_tlwhs = []
|
||||||
online_ids = []
|
online_ids = []
|
||||||
for t in online_targets:
|
for t in online_targets:
|
||||||
|
@ -106,10 +116,27 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
||||||
if show_image:
|
if show_image:
|
||||||
cv2.imshow('online_im', online_im)
|
cv2.imshow('online_im', online_im)
|
||||||
if save_dir is not None:
|
if save_dir is not None:
|
||||||
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
|
base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
|
||||||
frame_id += 1
|
if save_img:
|
||||||
|
cv2.imwrite(base_fn+'.jpg', online_im)
|
||||||
|
if save_figures:
|
||||||
|
for i, fe in enumerate(frame_embeddings):
|
||||||
|
tlwh, curr_feat = fe
|
||||||
|
x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
|
||||||
|
# print(x,y,w,h, tlwh)
|
||||||
|
crop_img = img0[y:y+h, x:x+w]
|
||||||
|
cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
|
||||||
|
with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
|
||||||
|
pickle.dump(fe, fp)
|
||||||
|
|
||||||
|
with open(frame_pickle_fn, 'wb') as fp:
|
||||||
|
pickle.dump(frame_embeddings, fp)
|
||||||
|
|
||||||
|
|
||||||
# save results
|
# save results
|
||||||
write_results(result_filename, results, data_type)
|
if result_filename is not None:
|
||||||
|
write_results(result_filename, results, data_type)
|
||||||
|
|
||||||
return frame_id, timer.average_time, timer.calls
|
return frame_id, timer.average_time, timer.calls
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,7 +158,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
|
||||||
for seq in seqs:
|
for seq in seqs:
|
||||||
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
|
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
|
||||||
|
|
||||||
logger.info('start seq: {}'.format(seq))
|
# logger.info('start seq: {}'.format(seq))
|
||||||
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
|
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
|
||||||
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
|
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
|
||||||
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
|
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
|
||||||
|
|
|
@ -203,6 +203,8 @@ class JDETracker(object):
|
||||||
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
|
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
|
||||||
removed_stracks = []
|
removed_stracks = []
|
||||||
|
|
||||||
|
frame_embeddings = []
|
||||||
|
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
''' Step 1: Network forward, get detections & embeddings'''
|
''' Step 1: Network forward, get detections & embeddings'''
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -220,8 +222,12 @@ class JDETracker(object):
|
||||||
|
|
||||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
|
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
|
||||||
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
|
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
|
||||||
|
# Surfacing Suspicion: extract features + frame id + bbox
|
||||||
|
frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
detections = []
|
detections = []
|
||||||
|
frame_embeddings = []
|
||||||
|
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
# print('Forward: {} s'.format(t2-t1))
|
# print('Forward: {} s'.format(t2-t1))
|
||||||
|
@ -346,7 +352,7 @@ class JDETracker(object):
|
||||||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
||||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
||||||
# print('Final {} s'.format(t5-t4))
|
# print('Final {} s'.format(t5-t4))
|
||||||
return output_stracks
|
return output_stracks, frame_embeddings
|
||||||
|
|
||||||
def joint_stracks(tlista, tlistb):
|
def joint_stracks(tlista, tlistb):
|
||||||
exists = {}
|
exists = {}
|
||||||
|
|
Loading…
Reference in a new issue