Store embeddings

This commit is contained in:
Ruben van de Ven 2023-04-05 17:17:15 +02:00
parent 38e7d181f9
commit 0f4b6044c4
2 changed files with 46 additions and 13 deletions

View File

@ -1,8 +1,11 @@
import os
import os.path as osp
import pickle
import cv2
import logging
import argparse
import tqdm
import motmetrics as mm
import torch
@ -38,7 +41,7 @@ def write_results(filename, results, data_type):
logger.info('save results to {}'.format(filename))
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
'''
Processes the video sequence given and provides the output of tracking result (write the results in video file)
@ -59,7 +62,9 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
The name(path) of the file for storing results.
save_dir : String
Path to the folder for storing the frames containing bounding box information (Result frames).
Path to the folder for storing the frames containing bounding box information (Result frames). If given, featuers will be save there as pickle
save_figures : bool
If set, individual crops of all embedded figures will be saved
show_image : bool
Option for shhowing individial frames during run-time.
@ -79,15 +84,20 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
tracker = JDETracker(opt, frame_rate=frame_rate)
timer = Timer()
results = []
frame_id = 0
for path, img, img0 in dataloader:
if frame_id % 20 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
frame_id = -1
for path, img, img0 in tqdm.tqdm(dataloader):
frame_id += 1
# if frame_id % 20 == 0:
# logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
if os.path.exists(frame_pickle_fn):
continue
# run tracking
timer.tic()
blob = torch.from_numpy(img).cuda().unsqueeze(0)
online_targets = tracker.update(blob, img0)
# online targets: all tartgets that are not timed out
# frame_embeddings: the embeddings of objects visible only in the current frame
online_targets, frame_embeddings = tracker.update(blob, img0)
online_tlwhs = []
online_ids = []
for t in online_targets:
@ -106,10 +116,27 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
frame_id += 1
base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
if save_img:
cv2.imwrite(base_fn+'.jpg', online_im)
if save_figures:
for i, fe in enumerate(frame_embeddings):
tlwh, curr_feat = fe
x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
# print(x,y,w,h, tlwh)
crop_img = img0[y:y+h, x:x+w]
cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
pickle.dump(fe, fp)
with open(frame_pickle_fn, 'wb') as fp:
pickle.dump(frame_embeddings, fp)
# save results
write_results(result_filename, results, data_type)
if result_filename is not None:
write_results(result_filename, results, data_type)
return frame_id, timer.average_time, timer.calls
@ -131,7 +158,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
for seq in seqs:
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
logger.info('start seq: {}'.format(seq))
# logger.info('start seq: {}'.format(seq))
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()

View File

@ -203,6 +203,8 @@ class JDETracker(object):
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
removed_stracks = []
frame_embeddings = []
t1 = time.time()
''' Step 1: Network forward, get detections & embeddings'''
with torch.no_grad():
@ -220,8 +222,12 @@ class JDETracker(object):
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
# Surfacing Suspicion: extract features + frame id + bbox
frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
else:
detections = []
frame_embeddings = []
t2 = time.time()
# print('Forward: {} s'.format(t2-t1))
@ -346,7 +352,7 @@ class JDETracker(object):
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
# print('Final {} s'.format(t5-t4))
return output_stracks
return output_stracks, frame_embeddings
def joint_stracks(tlista, tlistb):
exists = {}