Compare commits
10 commits
e9fa1e8014
...
635eeb1e5e
Author | SHA1 | Date | |
---|---|---|---|
|
635eeb1e5e | ||
|
913d67e019 | ||
|
291504263c | ||
|
ecbe041041 | ||
|
5deac894a5 | ||
|
0f4b6044c4 | ||
|
38e7d181f9 | ||
|
b9a31cfd29 | ||
|
9aa9e9c709 | ||
|
abf8085ab7 |
8 changed files with 2239 additions and 18 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -110,3 +110,5 @@ venv.bak/
|
|||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
OUT/
|
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"root":"/home/wangzd/datasets/MOT",
|
||||
"root":"/Towards-Realtime-MOT/datasets/MOT",
|
||||
"train":
|
||||
{
|
||||
"mot17":"./data/mot17.train",
|
||||
|
|
9
demo.py
9
demo.py
|
@ -49,9 +49,14 @@ def track(opt):
|
|||
n_frame = 0
|
||||
|
||||
logger.info('Starting tracking...')
|
||||
dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
|
||||
if os.path.isdir(opt.input_video):
|
||||
print('Use image sequence')
|
||||
dataloader = datasets.LoadImages(opt.input_video, opt.img_size)
|
||||
frame_rate = 30 # hack for now; see https://motchallenge.net/data/MOT16/
|
||||
else:
|
||||
dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
|
||||
frame_rate = dataloader.frame_rate
|
||||
result_filename = os.path.join(result_root, 'results.txt')
|
||||
frame_rate = dataloader.frame_rate
|
||||
|
||||
frame_dir = None if opt.output_format=='text' else osp.join(result_root, 'frame')
|
||||
try:
|
||||
|
|
|
@ -1,5 +1,20 @@
|
|||
FROM pytorch/pytorch:1.3-cuda10.1-cudnn7-devel
|
||||
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
|
||||
|
||||
RUN apt update && apt install -y ffmpeg libsm6 libxrender-dev
|
||||
RUN pip install Cython
|
||||
RUN pip install opencv-python cython_bbox motmetrics numba matplotlib sklearn
|
||||
RUN pip install lap
|
||||
RUN pip install umap-learn
|
||||
ENV NUMBA_CACHE_DIR=/tmp/numba_cache
|
||||
RUN pip install bokeh
|
||||
RUN pip install ipykernel
|
||||
RUN pip install ipython
|
||||
# Vscode bug: https://github.com/microsoft/vscode-jupyter/issues/8552
|
||||
RUN pip install ipywidgets==7.7.2
|
||||
#RUN pip install panel jupyter_bokeh
|
||||
|
||||
|
||||
# for bokeh
|
||||
EXPOSE 5006
|
||||
|
||||
CMD python -m ipykernel_launcher -f $DOCKERNEL_CONNECTION_FILE
|
||||
|
|
51
track.py
51
track.py
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import cv2
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
import motmetrics as mm
|
||||
|
||||
import torch
|
||||
|
@ -38,7 +41,7 @@ def write_results(filename, results, data_type):
|
|||
logger.info('save results to {}'.format(filename))
|
||||
|
||||
|
||||
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
|
||||
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
|
||||
'''
|
||||
Processes the video sequence given and provides the output of tracking result (write the results in video file)
|
||||
|
||||
|
@ -59,7 +62,9 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
|||
The name(path) of the file for storing results.
|
||||
|
||||
save_dir : String
|
||||
Path to the folder for storing the frames containing bounding box information (Result frames).
|
||||
Path to the folder for storing the frames containing bounding box information (Result frames). If given, featuers will be save there as pickle
|
||||
save_figures : bool
|
||||
If set, individual crops of all embedded figures will be saved
|
||||
|
||||
show_image : bool
|
||||
Option for shhowing individial frames during run-time.
|
||||
|
@ -79,15 +84,20 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
|||
tracker = JDETracker(opt, frame_rate=frame_rate)
|
||||
timer = Timer()
|
||||
results = []
|
||||
frame_id = 0
|
||||
for path, img, img0 in dataloader:
|
||||
if frame_id % 20 == 0:
|
||||
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
|
||||
|
||||
frame_id = -1
|
||||
for path, img, img0 in tqdm(dataloader):
|
||||
frame_id += 1
|
||||
# if frame_id % 20 == 0:
|
||||
# logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
|
||||
frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
|
||||
if os.path.exists(frame_pickle_fn):
|
||||
continue
|
||||
# run tracking
|
||||
timer.tic()
|
||||
blob = torch.from_numpy(img).cuda().unsqueeze(0)
|
||||
online_targets = tracker.update(blob, img0)
|
||||
# online targets: all tartgets that are not timed out
|
||||
# frame_embeddings: the embeddings of objects visible only in the current frame
|
||||
online_targets, frame_embeddings = tracker.update(blob, img0)
|
||||
online_tlwhs = []
|
||||
online_ids = []
|
||||
for t in online_targets:
|
||||
|
@ -106,10 +116,27 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
|||
if show_image:
|
||||
cv2.imshow('online_im', online_im)
|
||||
if save_dir is not None:
|
||||
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
|
||||
frame_id += 1
|
||||
base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
|
||||
if save_img:
|
||||
cv2.imwrite(base_fn+'.jpg', online_im)
|
||||
if save_figures:
|
||||
for i, fe in enumerate(frame_embeddings):
|
||||
tlwh, curr_feat = fe
|
||||
x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
|
||||
# print(x,y,w,h, tlwh)
|
||||
crop_img = img0[y:y+h, x:x+w]
|
||||
cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
|
||||
with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
|
||||
pickle.dump(fe, fp)
|
||||
|
||||
with open(frame_pickle_fn, 'wb') as fp:
|
||||
pickle.dump(frame_embeddings, fp)
|
||||
|
||||
|
||||
# save results
|
||||
write_results(result_filename, results, data_type)
|
||||
if result_filename is not None:
|
||||
write_results(result_filename, results, data_type)
|
||||
|
||||
return frame_id, timer.average_time, timer.calls
|
||||
|
||||
|
||||
|
@ -131,7 +158,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
|
|||
for seq in seqs:
|
||||
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
|
||||
|
||||
logger.info('start seq: {}'.format(seq))
|
||||
# logger.info('start seq: {}'.format(seq))
|
||||
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
|
||||
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
|
||||
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
|
||||
|
|
|
@ -203,6 +203,8 @@ class JDETracker(object):
|
|||
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
|
||||
removed_stracks = []
|
||||
|
||||
frame_embeddings = []
|
||||
|
||||
t1 = time.time()
|
||||
''' Step 1: Network forward, get detections & embeddings'''
|
||||
with torch.no_grad():
|
||||
|
@ -220,8 +222,12 @@ class JDETracker(object):
|
|||
|
||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
|
||||
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
|
||||
# Surfacing Suspicion: extract features + frame id + bbox
|
||||
frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
|
||||
|
||||
else:
|
||||
detections = []
|
||||
frame_embeddings = []
|
||||
|
||||
t2 = time.time()
|
||||
# print('Forward: {} s'.format(t2-t1))
|
||||
|
@ -346,7 +352,7 @@ class JDETracker(object):
|
|||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
||||
# print('Final {} s'.format(t5-t4))
|
||||
return output_stracks
|
||||
return output_stracks, frame_embeddings
|
||||
|
||||
def joint_stracks(tlista, tlistb):
|
||||
exists = {}
|
||||
|
|
|
@ -10,7 +10,7 @@ def get_logger(name='root'):
|
|||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
|
||||
|
|
2166
visualise_embeddings.ipynb
Normal file
2166
visualise_embeddings.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in a new issue