Compare commits

..

No commits in common. "635eeb1e5ec72ac0697a0b9845c3dce326bdcf90" and "e9fa1e8014a9577410ce99505b871e4bc5aa1cde" have entirely different histories.

8 changed files with 18 additions and 2239 deletions

2
.gitignore vendored
View file

@ -110,5 +110,3 @@ venv.bak/
# mypy
.mypy_cache/
OUT/

View file

@ -1,5 +1,5 @@
{
"root":"/Towards-Realtime-MOT/datasets/MOT",
"root":"/home/wangzd/datasets/MOT",
"train":
{
"mot17":"./data/mot17.train",

View file

@ -49,14 +49,9 @@ def track(opt):
n_frame = 0
logger.info('Starting tracking...')
if os.path.isdir(opt.input_video):
print('Use image sequence')
dataloader = datasets.LoadImages(opt.input_video, opt.img_size)
frame_rate = 30 # hack for now; see https://motchallenge.net/data/MOT16/
else:
dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)
frame_rate = dataloader.frame_rate
result_filename = os.path.join(result_root, 'results.txt')
frame_rate = dataloader.frame_rate
frame_dir = None if opt.output_format=='text' else osp.join(result_root, 'frame')
try:

View file

@ -1,20 +1,5 @@
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
FROM pytorch/pytorch:1.3-cuda10.1-cudnn7-devel
RUN apt update && apt install -y ffmpeg libsm6 libxrender-dev
RUN pip install Cython
RUN pip install opencv-python cython_bbox motmetrics numba matplotlib sklearn
RUN pip install lap
RUN pip install umap-learn
ENV NUMBA_CACHE_DIR=/tmp/numba_cache
RUN pip install bokeh
RUN pip install ipykernel
RUN pip install ipython
# Vscode bug: https://github.com/microsoft/vscode-jupyter/issues/8552
RUN pip install ipywidgets==7.7.2
#RUN pip install panel jupyter_bokeh
# for bokeh
EXPOSE 5006
CMD python -m ipykernel_launcher -f $DOCKERNEL_CONNECTION_FILE

View file

@ -1,11 +1,8 @@
import os
import os.path as osp
import pickle
import cv2
import logging
import argparse
from tqdm.auto import tqdm
import motmetrics as mm
import torch
@ -41,7 +38,7 @@ def write_results(filename, results, data_type):
logger.info('save results to {}'.format(filename))
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
'''
Processes the video sequence given and provides the output of tracking result (write the results in video file)
@ -62,9 +59,7 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_im
The name(path) of the file for storing results.
save_dir : String
Path to the folder for storing the frames containing bounding box information (Result frames). If given, featuers will be save there as pickle
save_figures : bool
If set, individual crops of all embedded figures will be saved
Path to the folder for storing the frames containing bounding box information (Result frames).
show_image : bool
Option for shhowing individial frames during run-time.
@ -84,20 +79,15 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_im
tracker = JDETracker(opt, frame_rate=frame_rate)
timer = Timer()
results = []
frame_id = -1
for path, img, img0 in tqdm(dataloader):
frame_id += 1
# if frame_id % 20 == 0:
# logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
if os.path.exists(frame_pickle_fn):
continue
frame_id = 0
for path, img, img0 in dataloader:
if frame_id % 20 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
# run tracking
timer.tic()
blob = torch.from_numpy(img).cuda().unsqueeze(0)
# online targets: all tartgets that are not timed out
# frame_embeddings: the embeddings of objects visible only in the current frame
online_targets, frame_embeddings = tracker.update(blob, img0)
online_targets = tracker.update(blob, img0)
online_tlwhs = []
online_ids = []
for t in online_targets:
@ -116,27 +106,10 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_im
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
if save_img:
cv2.imwrite(base_fn+'.jpg', online_im)
if save_figures:
for i, fe in enumerate(frame_embeddings):
tlwh, curr_feat = fe
x,y,w,h = round(tlwh[0]), round(tlwh[1]), round(tlwh[2]), round(tlwh[3])
# print(x,y,w,h, tlwh)
crop_img = img0[y:y+h, x:x+w]
cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
pickle.dump(fe, fp)
with open(frame_pickle_fn, 'wb') as fp:
pickle.dump(frame_embeddings, fp)
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
frame_id += 1
# save results
if result_filename is not None:
write_results(result_filename, results, data_type)
return frame_id, timer.average_time, timer.calls
@ -158,7 +131,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
for seq in seqs:
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
# logger.info('start seq: {}'.format(seq))
logger.info('start seq: {}'.format(seq))
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()

View file

@ -203,8 +203,6 @@ class JDETracker(object):
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
removed_stracks = []
frame_embeddings = []
t1 = time.time()
''' Step 1: Network forward, get detections & embeddings'''
with torch.no_grad():
@ -222,12 +220,8 @@ class JDETracker(object):
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
# Surfacing Suspicion: extract features + frame id + bbox
frame_embeddings = [[track.tlwh, track.curr_feat] for track in detections]
else:
detections = []
frame_embeddings = []
t2 = time.time()
# print('Forward: {} s'.format(t2-t1))
@ -352,7 +346,7 @@ class JDETracker(object):
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
# print('Final {} s'.format(t5-t4))
return output_stracks, frame_embeddings
return output_stracks
def joint_stracks(tlista, tlistb):
exists = {}

View file

@ -10,7 +10,7 @@ def get_logger(name='root'):
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
return logger

File diff suppressed because one or more lines are too long