Compare commits

..

No commits in common. "635eeb1e5ec72ac0697a0b9845c3dce326bdcf90" and "e9fa1e8014a9577410ce99505b871e4bc5aa1cde" have entirely different histories.

8 changed files with 18 additions and 2239 deletions

2
.gitignore vendored
View file

@ -110,5 +110,3 @@ venv.bak/
# mypy # mypy
.mypy_cache/ .mypy_cache/
OUT/

View file

@ -1,5 +1,5 @@
{ {
"root":"/Towards-Realtime-MOT/datasets/MOT", "root":"/home/wangzd/datasets/MOT",
"train": "train":
{ {
"mot17":"./data/mot17.train", "mot17":"./data/mot17.train",

View file

@ -49,14 +49,9 @@ def track(opt):
n_frame = 0 n_frame = 0
logger.info('Starting tracking...') logger.info('Starting tracking...')
if os.path.isdir(opt.input_video): dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
print('Use image sequence')
dataloader = datasets.LoadImages(opt.input_video, opt.img_size)
frame_rate = 30 # hack for now; see https://motchallenge.net/data/MOT16/
else:
dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
frame_rate = dataloader.frame_rate
result_filename = os.path.join(result_root, 'results.txt') result_filename = os.path.join(result_root, 'results.txt')
frame_rate = dataloader.frame_rate
frame_dir = None if opt.output_format=='text' else osp.join(result_root, 'frame') frame_dir = None if opt.output_format=='text' else osp.join(result_root, 'frame')
try: try:

View file

@ -1,20 +1,5 @@
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel FROM pytorch/pytorch:1.3-cuda10.1-cudnn7-devel
RUN apt update && apt install -y ffmpeg libsm6 libxrender-dev RUN apt update && apt install -y ffmpeg libsm6 libxrender-dev
RUN pip install Cython RUN pip install Cython
RUN pip install opencv-python cython_bbox motmetrics numba matplotlib sklearn RUN pip install opencv-python cython_bbox motmetrics numba matplotlib sklearn
RUN pip install lap
RUN pip install umap-learn
ENV NUMBA_CACHE_DIR=/tmp/numba_cache
RUN pip install bokeh
RUN pip install ipykernel
RUN pip install ipython
# Vscode bug: https://github.com/microsoft/vscode-jupyter/issues/8552
RUN pip install ipywidgets==7.7.2
#RUN pip install panel jupyter_bokeh
# for bokeh
EXPOSE 5006
CMD python -m ipykernel_launcher -f $DOCKERNEL_CONNECTION_FILE

View file

@ -1,11 +1,8 @@
import os import os
import os.path as osp import os.path as osp
import pickle
import cv2 import cv2
import logging import logging
import argparse import argparse
from tqdm.auto import tqdm
import motmetrics as mm import motmetrics as mm
import torch import torch
@ -41,7 +38,7 @@ def write_results(filename, results, data_type):
logger.info('save results to {}'.format(filename)) logger.info('save results to {}'.format(filename))
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30): def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
''' '''
Processes the video sequence given and provides the output of tracking result (write the results in video file) Processes the video sequence given and provides the output of tracking result (write the results in video file)
@ -62,9 +59,7 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_im
The name(path) of the file for storing results. The name(path) of the file for storing results.
save_dir : String save_dir : String
Path to the folder for storing the frames containing bounding box information (Result frames). If given, featuers will be save there as pickle Path to the folder for storing the frames containing bounding box information (Result frames).
save_figures : bool
If set, individual crops of all embedded figures will be saved
show_image : bool show_image : bool
Option for shhowing individial frames during run-time. Option for shhowing individial frames during run-time.
@ -84,20 +79,15 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_im
tracker = JDETracker(opt, frame_rate=frame_rate) tracker = JDETracker(opt, frame_rate=frame_rate)
timer = Timer() timer = Timer()
results = [] results = []
frame_id = -1 frame_id = 0
for path, img, img0 in tqdm(dataloader): for path, img, img0 in dataloader:
frame_id += 1 if frame_id % 20 == 0:
# if frame_id % 20 == 0: logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
# logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
if os.path.exists(frame_pickle_fn):
continue
# run tracking # run tracking
timer.tic() timer.tic()
blob = torch.from_numpy(img).cuda().unsqueeze(0) blob = torch.from_numpy(img).cuda().unsqueeze(0)
# online targets: all tartgets that are not timed out online_targets = tracker.update(blob, img0)
# frame_embeddings: the embeddings of objects visible only in the current frame
online_targets, frame_embeddings = tracker.update(blob, img0)
online_tlwhs = [] online_tlwhs = []
online_ids = [] online_ids = []
for t in online_targets: for t in online_targets:
@ -116,27 +106,10 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_im
if show_image: if show_image:
cv2.imshow('online_im', online_im) cv2.imshow('online_im', online_im)
if save_dir is not None: if save_dir is not None:
base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id)) cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
if save_img: frame_id += 1
cv2.imwrite(base_fn+'.jpg', online_im)
if save_figures:
for i, fe in enumerate(frame_embeddings):
tlwh, curr_feat = fe
x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
# print(x,y,w,h, tlwh)
crop_img = img0[y:y+h, x:x+w]
cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
pickle.dump(fe, fp)
with open(frame_pickle_fn, 'wb') as fp:
pickle.dump(frame_embeddings, fp)
# save results # save results
if result_filename is not None: write_results(result_filename, results, data_type)
write_results(result_filename, results, data_type)
return frame_id, timer.average_time, timer.calls return frame_id, timer.average_time, timer.calls
@ -158,7 +131,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
for seq in seqs: for seq in seqs:
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
# logger.info('start seq: {}'.format(seq)) logger.info('start seq: {}'.format(seq))
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size) dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
result_filename = os.path.join(result_root, '{}.txt'.format(seq)) result_filename = os.path.join(result_root, '{}.txt'.format(seq))
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read() meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()

View file

@ -203,8 +203,6 @@ class JDETracker(object):
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing) lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
removed_stracks = [] removed_stracks = []
frame_embeddings = []
t1 = time.time() t1 = time.time()
''' Step 1: Network forward, get detections & embeddings''' ''' Step 1: Network forward, get detections & embeddings'''
with torch.no_grad(): with torch.no_grad():
@ -222,12 +220,8 @@ class JDETracker(object):
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])] (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
# Surfacing Suspicion: extract features + frame id + bbox
frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
else: else:
detections = [] detections = []
frame_embeddings = []
t2 = time.time() t2 = time.time()
# print('Forward: {} s'.format(t2-t1)) # print('Forward: {} s'.format(t2-t1))
@ -352,7 +346,7 @@ class JDETracker(object):
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks])) logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks])) logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
# print('Final {} s'.format(t5-t4)) # print('Final {} s'.format(t5-t4))
return output_stracks, frame_embeddings return output_stracks
def joint_stracks(tlista, tlistb): def joint_stracks(tlista, tlistb):
exists = {} exists = {}

View file

@ -10,7 +10,7 @@ def get_logger(name='root'):
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(logging.INFO) logger.setLevel(logging.DEBUG)
logger.addHandler(handler) logger.addHandler(handler)
return logger return logger

File diff suppressed because one or more lines are too long