Compare commits
10 commits: e9fa1e8014 ... 635eeb1e5e
SHA1:
635eeb1e5e
913d67e019
291504263c
ecbe041041
5deac894a5
0f4b6044c4
38e7d181f9
b9a31cfd29
9aa9e9c709
abf8085ab7
8 changed files with 2239 additions and 18 deletions
.gitignore (vendored, 2 changes)

@@ -110,3 +110,5 @@ venv.bak/
 # mypy
 .mypy_cache/
 
+
+OUT/
cfg/ccmcpe.json
@@ -1,5 +1,5 @@
 {
-    "root":"/home/wangzd/datasets/MOT",
+    "root":"/Towards-Realtime-MOT/datasets/MOT",
     "train":
     {
         "mot17":"./data/mot17.train",
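For context, a minimal sketch of reading this dataset config; the file name follows the header above, and the key names come straight from the hunk:

import json

with open('cfg/ccmcpe.json') as f:   # path assumed; contents shown in the hunk above
    cfg = json.load(f)

print(cfg['root'])                   # '/Towards-Realtime-MOT/datasets/MOT'
print(cfg['train']['mot17'])         # './data/mot17.train'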
demo.py (9 changes)

@@ -49,9 +49,14 @@ def track(opt):
     n_frame = 0
 
     logger.info('Starting tracking...')
-    dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
+    if os.path.isdir(opt.input_video):
+        print('Use image sequence')
+        dataloader = datasets.LoadImages(opt.input_video, opt.img_size)
+        frame_rate = 30  # hack for now; see https://motchallenge.net/data/MOT16/
+    else:
+        dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
+        frame_rate = dataloader.frame_rate
     result_filename = os.path.join(result_root, 'results.txt')
-    frame_rate = dataloader.frame_rate
 
     frame_dir = None if opt.output_format=='text' else osp.join(result_root, 'frame')
     try:
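With this change demo.py accepts either a video file or a directory of frames through the same flag. A hedged usage sketch (flag names as in the upstream demo.py; paths are placeholders):

python demo.py --input-video /data/MOT16/train/MOT16-05/img1 --output-format video --output-root ./OUT
python demo.py --input-video ./videos/clip.mp4 --output-format text --output-root ./OUT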
Dockerfile
@@ -1,5 +1,20 @@
-FROM pytorch/pytorch:1.3-cuda10.1-cudnn7-devel
+FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
 
 RUN apt update && apt install -y ffmpeg libsm6 libxrender-dev
 RUN pip install Cython
 RUN pip install opencv-python cython_bbox motmetrics numba matplotlib sklearn
+RUN pip install lap
+RUN pip install umap-learn
+ENV NUMBA_CACHE_DIR=/tmp/numba_cache
+RUN pip install bokeh
+RUN pip install ipykernel
+RUN pip install ipython
+# Vscode bug: https://github.com/microsoft/vscode-jupyter/issues/8552
+RUN pip install ipywidgets==7.7.2
+#RUN pip install panel jupyter_bokeh
+
+
+# for bokeh
+EXPOSE 5006
+
+CMD python -m ipykernel_launcher -f $DOCKERNEL_CONNECTION_FILE
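A hedged sketch of building and running the image; the image name and mount are illustrative, the port matches the EXPOSE line above, and the CMD appears intended for launching via dockernel, which supplies $DOCKERNEL_CONNECTION_FILE:

docker build -t jde-embeddings .
docker run --gpus all -p 5006:5006 -v "$PWD":/Towards-Realtime-MOT jde-embeddings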
track.py (51 changes)
@@ -1,8 +1,11 @@
 import os
 import os.path as osp
+import pickle
 import cv2
 import logging
 import argparse
 
+from tqdm.auto import tqdm
 import motmetrics as mm
 
 import torch
@@ -38,7 +41,7 @@ def write_results(filename, results, data_type):
     logger.info('save results to {}'.format(filename))
 
 
-def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
+def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
     '''
        Processes the video sequence given and provides the output of tracking result (write the results in video file)
 
@@ -59,7 +62,9 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
         The name(path) of the file for storing results.
 
     save_dir : String
-        Path to the folder for storing the frames containing bounding box information (Result frames).
+        Path to the folder for storing the frames containing bounding box information (Result frames). If given, features will be saved there as pickles.
+    save_figures : bool
+        If set, individual crops of all embedded figures will be saved.
 
     show_image : bool
         Option for showing individual frames during run-time.
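A sketch of a call site using the new keyword arguments, given an opt namespace and dataloader as elsewhere in the repo ('mot' is the data_type the repo already uses; values are illustrative):

frame_id, avg_time, calls = eval_seq(
    opt, dataloader, 'mot',
    result_filename=None,     # now optional: skip writing a results .txt
    save_dir='OUT/frames',    # per-frame embeddings land here as .pcl files
    save_img=False,           # new: write annotated frames only when True
    save_figures=True,        # new: also write one crop + .pcl per detection
    show_image=False,
    frame_rate=30,
)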
@@ -79,15 +84,20 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
     tracker = JDETracker(opt, frame_rate=frame_rate)
     timer = Timer()
     results = []
-    frame_id = 0
-    for path, img, img0 in dataloader:
-        if frame_id % 20 == 0:
-            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
+    frame_id = -1
+    for path, img, img0 in tqdm(dataloader):
+        frame_id += 1
+        # if frame_id % 20 == 0:
+        #     logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
+        frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
+        if os.path.exists(frame_pickle_fn):
+            continue
         # run tracking
         timer.tic()
         blob = torch.from_numpy(img).cuda().unsqueeze(0)
-        online_targets = tracker.update(blob, img0)
+        # online_targets: all targets that are not timed out
+        # frame_embeddings: the embeddings of objects visible only in the current frame
+        online_targets, frame_embeddings = tracker.update(blob, img0)
         online_tlwhs = []
         online_ids = []
         for t in online_targets:
@@ -106,10 +116,27 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
         if show_image:
             cv2.imshow('online_im', online_im)
         if save_dir is not None:
-            cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
-        frame_id += 1
+            base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
+            if save_img:
+                cv2.imwrite(base_fn+'.jpg', online_im)
+            if save_figures:
+                for i, fe in enumerate(frame_embeddings):
+                    tlwh, curr_feat = fe
+                    x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
+                    # print(x,y,w,h, tlwh)
+                    crop_img = img0[y:y+h, x:x+w]
+                    cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
+                    with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
+                        pickle.dump(fe, fp)
+
+            with open(frame_pickle_fn, 'wb') as fp:
+                pickle.dump(frame_embeddings, fp)
 
     # save results
-    write_results(result_filename, results, data_type)
+    if result_filename is not None:
+        write_results(result_filename, results, data_type)
 
     return frame_id, timer.average_time, timer.calls
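The pickles written above can be read back without the tracker. A minimal sketch, assuming the save_dir used during tracking; the five-digit names are the per-frame pickles, while per-detection pickles carry a -<i> suffix:

import glob
import pickle

for fn in sorted(glob.glob('OUT/frames/?????.pcl')):
    with open(fn, 'rb') as fp:
        frame_embeddings = pickle.load(fp)      # list of [tlwh, curr_feat] pairs
    for tlwh, curr_feat in frame_embeddings:
        print(fn, tlwh, curr_feat.shape)        # bbox and its embedding vector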
@@ -131,7 +158,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
     for seq in seqs:
         output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
 
-        logger.info('start seq: {}'.format(seq))
+        # logger.info('start seq: {}'.format(seq))
         dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
         result_filename = os.path.join(result_root, '{}.txt'.format(seq))
         meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
tracker/multitracker.py
@@ -203,6 +203,8 @@ class JDETracker(object):
         lost_stracks = []  # The tracks which are not obtained in the current frame but are not removed. (Lost for some time less than the threshold for removing.)
         removed_stracks = []
 
+        frame_embeddings = []
+
         t1 = time.time()
         ''' Step 1: Network forward, get detections & embeddings'''
         with torch.no_grad():
@@ -220,8 +222,12 @@ class JDETracker(object):
 
             detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
                           (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
+            # Surfacing Suspicion: extract features + frame id + bbox
+            frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
+
         else:
             detections = []
+            frame_embeddings = []
 
         t2 = time.time()
         # print('Forward: {} s'.format(t2-t1))
@@ -346,7 +352,7 @@ class JDETracker(object):
         logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
         logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
         # print('Final {} s'.format(t5-t4))
-        return output_stracks
+        return output_stracks, frame_embeddings
 
 
 def joint_stracks(tlista, tlistb):
     exists = {}
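Per the diff above, update() now returns a plain list alongside the tracks, with one [tlwh, curr_feat] pair per detection in the current frame. A self-contained sketch of that shape (sizes illustrative; JDE embeddings are numpy vectors since curr_feat comes from f.numpy()):

import numpy as np

# what JDETracker.update() returns as its second value, one entry per detection
frame_embeddings = [
    [np.array([10., 20., 50., 120.]), np.random.rand(512)],   # feat dim illustrative
    [np.array([80., 15., 45., 110.]), np.random.rand(512)],
]

boxes = np.stack([tlwh for tlwh, _ in frame_embeddings])      # (n, 4) top-left x/y, width, height
feats = np.stack([feat for _, feat in frame_embeddings])      # (n, feat_dim)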
utils/log.py
@@ -10,7 +10,7 @@ def get_logger(name='root'):
     handler.setFormatter(formatter)
 
     logger = logging.getLogger(name)
-    logger.setLevel(logging.DEBUG)
+    logger.setLevel(logging.INFO)
     logger.addHandler(handler)
     return logger
visualise_embeddings.ipynb (2166 changes, new file)
File diff suppressed because one or more lines are too long
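The notebook's diff is suppressed, so its contents are not shown here. Given the umap-learn and bokeh installs added to the Dockerfile and the pickles written by track.py, a plausible minimal sketch of this kind of embedding visualisation; all paths, dimensions, and plot choices are assumptions, not taken from the notebook:

import glob
import pickle

import numpy as np
import umap
from bokeh.io import output_notebook
from bokeh.plotting import figure, show

feats = []
for fn in sorted(glob.glob('OUT/frames/?????.pcl')):   # per-frame pickles from track.py
    with open(fn, 'rb') as fp:
        feats.extend(feat for _, feat in pickle.load(fp))

# project the per-detection embeddings to 2-D for plotting
emb2d = umap.UMAP(n_components=2).fit_transform(np.stack(feats))

output_notebook()                                      # render inline in the notebook
p = figure(title='Detection embeddings (UMAP projection)')
p.scatter(emb2d[:, 0], emb2d[:, 1], size=3)
show(p)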