Store embeddings

This commit is contained in:
Ruben van de Ven 2023-04-05 17:17:15 +02:00
parent 38e7d181f9
commit 0f4b6044c4
2 changed files with 46 additions and 13 deletions

View file

@ -1,8 +1,11 @@
import os import os
import os.path as osp import os.path as osp
import pickle
import cv2 import cv2
import logging import logging
import argparse import argparse
import tqdm
import motmetrics as mm import motmetrics as mm
import torch import torch
@ -38,7 +41,7 @@ def write_results(filename, results, data_type):
logger.info('save results to {}'.format(filename)) logger.info('save results to {}'.format(filename))
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30): def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
''' '''
Processes the video sequence given and provides the output of tracking result (write the results in video file) Processes the video sequence given and provides the output of tracking result (write the results in video file)
@ -59,7 +62,9 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
The name(path) of the file for storing results. The name(path) of the file for storing results.
save_dir : String save_dir : String
Path to the folder for storing the frames containing bounding box information (Result frames). Path to the folder for storing the frames containing bounding box information (Result frames). If given, featuers will be save there as pickle
save_figures : bool
If set, individual crops of all embedded figures will be saved
show_image : bool show_image : bool
Option for shhowing individial frames during run-time. Option for shhowing individial frames during run-time.
@ -79,15 +84,20 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
tracker = JDETracker(opt, frame_rate=frame_rate) tracker = JDETracker(opt, frame_rate=frame_rate)
timer = Timer() timer = Timer()
results = [] results = []
frame_id = 0 frame_id = -1
for path, img, img0 in dataloader: for path, img, img0 in tqdm.tqdm(dataloader):
if frame_id % 20 == 0: frame_id += 1
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time))) # if frame_id % 20 == 0:
# logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
if os.path.exists(frame_pickle_fn):
continue
# run tracking # run tracking
timer.tic() timer.tic()
blob = torch.from_numpy(img).cuda().unsqueeze(0) blob = torch.from_numpy(img).cuda().unsqueeze(0)
online_targets = tracker.update(blob, img0) # online targets: all tartgets that are not timed out
# frame_embeddings: the embeddings of objects visible only in the current frame
online_targets, frame_embeddings = tracker.update(blob, img0)
online_tlwhs = [] online_tlwhs = []
online_ids = [] online_ids = []
for t in online_targets: for t in online_targets:
@ -106,10 +116,27 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
if show_image: if show_image:
cv2.imshow('online_im', online_im) cv2.imshow('online_im', online_im)
if save_dir is not None: if save_dir is not None:
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im) base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
frame_id += 1 if save_img:
cv2.imwrite(base_fn+'.jpg', online_im)
if save_figures:
for i, fe in enumerate(frame_embeddings):
tlwh, curr_feat = fe
x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
# print(x,y,w,h, tlwh)
crop_img = img0[y:y+h, x:x+w]
cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
pickle.dump(fe, fp)
with open(frame_pickle_fn, 'wb') as fp:
pickle.dump(frame_embeddings, fp)
# save results # save results
if result_filename is not None:
write_results(result_filename, results, data_type) write_results(result_filename, results, data_type)
return frame_id, timer.average_time, timer.calls return frame_id, timer.average_time, timer.calls
@ -131,7 +158,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
for seq in seqs: for seq in seqs:
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
logger.info('start seq: {}'.format(seq)) # logger.info('start seq: {}'.format(seq))
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size) dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
result_filename = os.path.join(result_root, '{}.txt'.format(seq)) result_filename = os.path.join(result_root, '{}.txt'.format(seq))
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read() meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()

View file

@ -203,6 +203,8 @@ class JDETracker(object):
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing) lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
removed_stracks = [] removed_stracks = []
frame_embeddings = []
t1 = time.time() t1 = time.time()
''' Step 1: Network forward, get detections & embeddings''' ''' Step 1: Network forward, get detections & embeddings'''
with torch.no_grad(): with torch.no_grad():
@ -220,8 +222,12 @@ class JDETracker(object):
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])] (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
# Surfacing Suspicion: extract features + frame id + bbox
frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
else: else:
detections = [] detections = []
frame_embeddings = []
t2 = time.time() t2 = time.time()
# print('Forward: {} s'.format(t2-t1)) # print('Forward: {} s'.format(t2-t1))
@ -346,7 +352,7 @@ class JDETracker(object):
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks])) logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks])) logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
# print('Final {} s'.format(t5-t4)) # print('Final {} s'.format(t5-t4))
return output_stracks return output_stracks, frame_embeddings
def joint_stracks(tlista, tlistb): def joint_stracks(tlista, tlistb):
exists = {} exists = {}