Towards-Realtime-MOT/track.py

244 lines
9.8 KiB
Python
Raw Normal View History

import os
import os.path as osp
2023-04-05 17:17:15 +02:00
import pickle
import cv2
import logging
import argparse
2023-04-05 17:17:15 +02:00
2023-04-13 13:11:58 +02:00
from tqdm.auto import tqdm
import motmetrics as mm
import torch
from tracker.multitracker import JDETracker
from utils import visualization as vis
from utils.log import logger
from utils.timer import Timer
from utils.evaluation import Evaluator
from utils.parse_config import parse_model_cfg
import utils.datasets as datasets
from utils.utils import *
def write_results(filename, results, data_type):
    '''
    Writes per-frame tracking results to a text file, one line per tracked box.

    Parameters
    ----------
    filename : String
        Path of the output file; it is opened in write mode.
    results : iterable
        Tuples of (frame_id, tlwhs, track_ids). Each tlwh unpacks to (x1, y1, w, h).
    data_type : String
        Either 'mot' or 'kitti'; selects the line format.

    Raises
    ------
    ValueError
        If data_type is neither 'mot' nor 'kitti'.
    '''
    if data_type == 'mot':
        row_template = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        row_template = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    with open(filename, 'w') as out:
        for frame_no, boxes, ids in results:
            if data_type == 'kitti':
                # The kitti format is written with the frame number lowered by one.
                frame_no -= 1
            for box, tid in zip(boxes, ids):
                # Entries with a negative track id are not written.
                if tid < 0:
                    continue
                left, top, width, height = box
                right = left + width
                bottom = top + height
                out.write(row_template.format(frame=frame_no, id=tid,
                                              x1=left, y1=top, x2=right, y2=bottom,
                                              w=width, h=height))
    logger.info('save results to {}'.format(filename))
2023-04-05 17:17:15 +02:00
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, save_img=False, save_figures=False, show_image=True, frame_rate=30):
    '''
    Runs the JDE tracker over every frame yielded by the dataloader, collects the
    online targets per frame and optionally writes result files, annotated frames,
    per-detection crops and pickled embeddings.

    Parameters
    ----------
    opt : Namespace
        Contains information passed as commandline arguments (opt.min_box_area is read here).
    dataloader : iterable
        Yields (path, img, img0) per frame; img is fed to the network, img0 is used
        for drawing and cropping.
    data_type : String
        Type of dataset corresponding (similar) to the given video.
    result_filename : String
        The name (path) of the file for storing results. Nothing is written if None.
    save_dir : String
        Path to the folder for storing output. If given, the frame embeddings are
        saved there as one pickle per frame, and frames whose pickle already
        exists are skipped.
    save_img : bool
        If set (and save_dir is given), the annotated frame is saved as a .jpg.
    save_figures : bool
        If set (and save_dir is given), individual crops of all embedded figures
        are saved, together with one pickle per crop.
    show_image : bool
        Option for showing individual frames during run-time.
    frame_rate : int
        Frame-rate of the given video.

    Returns
    -------
    frame_id : int
        Zero-based index of the last frame yielded by the dataloader (-1 if none).
    average_time : float
        timer.average_time after the last processed frame.
    calls : int
        timer.calls after the last processed frame.
    '''
    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results = []
    frame_id = -1
    for path, img, img0 in tqdm(dataloader):
        frame_id += 1

        # The per-frame pickle only exists when an output folder was given. Building
        # its path unconditionally raised TypeError for the default save_dir=None.
        frame_pickle_fn = None
        if save_dir is not None:
            frame_pickle_fn = os.path.join(save_dir, f'{frame_id:05d}.pcl')
            # A frame whose pickle already exists is skipped entirely.
            # NOTE(review): a skipped frame is not passed to the tracker and gets no
            # entry in `results` -- confirm this is acceptable when resuming a run.
            if os.path.exists(frame_pickle_fn):
                continue

        # run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)
        # online_targets: all targets that are not timed out
        # frame_embeddings: the embeddings of objects visible only in the current frame
        online_targets, frame_embeddings = tracker.update(blob, img0)
        online_tlwhs = []
        online_ids = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            # Keep a box only if its area exceeds opt.min_box_area and its
            # width/height ratio does not exceed 1.6.
            vertical = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
        timer.toc()

        # Results are stored with a frame number one higher than frame_id.
        results.append((frame_id + 1, online_tlwhs, online_ids))

        if show_image or save_dir is not None:
            # Guard the division so a zero average time cannot raise ZeroDivisionError.
            online_im = vis.plot_tracking(img0, online_tlwhs, online_ids, frame_id=frame_id,
                                          fps=1. / max(1e-5, timer.average_time))
        if show_image:
            cv2.imshow('online_im', online_im)
        if save_dir is not None:
            base_fn = os.path.join(save_dir, '{:05d}'.format(frame_id))
            if save_img:
                cv2.imwrite(base_fn+'.jpg', online_im)
            if save_figures:
                for i, fe in enumerate(frame_embeddings):
                    tlwh, curr_feat = fe
                    # int(round(...)) yields plain ints that are valid slice indices.
                    x, y, w, h = (int(round(float(v))) for v in tlwh[:4])
                    # Clamp to non-negative indices: a negative start or stop would be
                    # interpreted relative to the end of the axis and give a wrong
                    # (usually empty) crop.
                    crop_img = img0[max(y, 0):max(y + h, 0), max(x, 0):max(x + w, 0)]
                    # Writing an empty image fails, so boxes fully outside the frame
                    # get no .jpg; their embedding is still pickled below.
                    # NOTE(review): assumes img0 is a numpy array (it is passed to cv2).
                    if crop_img.size > 0:
                        cv2.imwrite(f'{base_fn}-{i}.jpg', crop_img)
                    with open(os.path.join(save_dir, f'{frame_id:05d}-{i}.pcl'), 'wb') as fp:
                        pickle.dump(fe, fp)
            with open(frame_pickle_fn, 'wb') as fp:
                pickle.dump(frame_embeddings, fp)

    # save results
    if result_filename is not None:
        write_results(result_filename, results, data_type)
    return frame_id, timer.average_time, timer.calls
def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',), exp_name='demo',
         save_images=False, save_videos=False, show_image=True):
    '''
    Runs tracking on every sequence in seqs, evaluates each result file and prints
    and saves a summary of the metrics.

    Parameters
    ----------
    opt : Namespace
        Commandline options; opt.cfg is read here and opt.img_size is set from it.
    data_root : String
        Folder containing one sub-folder per sequence (with 'img1' and 'seqinfo.ini').
        Results go to '<data_root>/../results/<exp_name>', frames and videos to
        '<data_root>/../outputs/<exp_name>/<seq>'.
    det_root : String
        Unused in this function; kept for interface compatibility.
    seqs : sequence of String
        Names of the sequences to process.
    exp_name : String
        Name of the experiment, used in the output paths and the summary file name.
    save_images : bool
        If set, annotated frames are written to the output folder.
    save_videos : bool
        If set, annotated frames are written and assembled into an .mp4 with ffmpeg.
    show_image : bool
        Passed on to eval_seq.
    '''
    # Local import: only needed for the optional ffmpeg call below.
    import subprocess

    logger.setLevel(logging.INFO)
    result_root = os.path.join(data_root, '..', 'results', exp_name)
    mkdir_if_missing(result_root)
    data_type = 'mot'

    # Read config
    cfg_dict = parse_model_cfg(opt.cfg)
    opt.img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]

    # run tracking
    accs = []
    n_frame = 0
    timer_avgs, timer_calls = [], []
    for seq in seqs:
        output_dir = os.path.join(data_root, '..', 'outputs', exp_name, seq) if save_images or save_videos else None
        dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        # Read the sequence info inside a context manager so the handle is closed.
        with open(os.path.join(data_root, seq, 'seqinfo.ini')) as ini_file:
            meta_info = ini_file.read()
        # The frame rate is the text between 'frameRate=' and the following
        # 'seqLength' line of seqinfo.ini.
        frame_rate = int(meta_info[meta_info.find('frameRate')+10:meta_info.find('\nseqLength')])
        # save_img must be forwarded: eval_seq only writes the annotated .jpg frames
        # when it is set, and the ffmpeg step below reads exactly those frames.
        nf, ta, tc = eval_seq(opt, dataloader, data_type, result_filename,
                              save_dir=output_dir, save_img=save_images or save_videos,
                              show_image=show_image, frame_rate=frame_rate)
        n_frame += nf
        timer_avgs.append(ta)
        timer_calls.append(tc)

        # eval
        logger.info('Evaluate seq: {}'.format(seq))
        evaluator = Evaluator(data_root, seq, data_type)
        accs.append(evaluator.eval_file(result_filename))
        if save_videos:
            output_video_path = osp.join(output_dir, '{}.mp4'.format(seq))
            # Argument list instead of a shell string, so paths containing spaces or
            # shell metacharacters are passed to ffmpeg unchanged.
            cmd = ['ffmpeg', '-f', 'image2', '-i', osp.join(output_dir, '%05d.jpg'),
                   '-c:v', 'copy', output_video_path]
            try:
                subprocess.run(cmd)
            except OSError as err:
                # Video export stays best-effort (e.g. ffmpeg is not installed).
                logger.warning('could not run ffmpeg: {}'.format(err))

    # Total time is the sum over sequences of (average time per call * number of calls).
    timer_avgs = np.asarray(timer_avgs)
    timer_calls = np.asarray(timer_calls)
    all_time = np.dot(timer_avgs, timer_calls)
    avg_time = all_time / np.sum(timer_calls)
    logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(all_time, 1.0 / avg_time))

    # get summary
    metrics = mm.metrics.motchallenge_metrics
    mh = mm.metrics.create()
    summary = Evaluator.get_summary(accs, seqs, metrics)
    strsummary = mm.io.render_summary(
        summary,
        formatters=mh.formatters,
        namemap=mm.io.motchallenge_metric_names
    )
    print(strsummary)
    Evaluator.save_summary(summary, os.path.join(result_root, 'summary_{}.xlsx'.format(exp_name)))
if __name__ == '__main__':
    # Command-line entry point: parse the options, choose the sequence list and
    # data root, then run tracking and evaluation through main().
    parser = argparse.ArgumentParser(prog='track.py')
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
    parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
    parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
    parser.add_argument('--min-box-area', type=float, default=200, help='filter out tiny boxes')
    parser.add_argument('--track-buffer', type=int, default=30, help='tracking buffer')
    # The help text used to be a copy of the --track-buffer one.
    parser.add_argument('--test-mot16', action='store_true',
                        help='run on the MOT16 test sequences instead of the MOT17 train sequences')
    parser.add_argument('--save-images', action='store_true', help='save tracking results (image)')
    parser.add_argument('--save-videos', action='store_true', help='save tracking results (video)')
    opt = parser.parse_args()
    print(opt, end='\n\n')

    if not opt.test_mot16:
        seqs = [
            'MOT17-02-SDP',
            'MOT17-04-SDP',
            'MOT17-05-SDP',
            'MOT17-09-SDP',
            'MOT17-10-SDP',
            'MOT17-11-SDP',
            'MOT17-13-SDP',
        ]
        data_root = '/home/wangzd/datasets/MOT/MOT17/images/train'
    else:
        seqs = [
            'MOT16-01',
            'MOT16-03',
            'MOT16-06',
            'MOT16-07',
            'MOT16-08',
            'MOT16-12',
            'MOT16-14',
        ]
        data_root = '/home/wangzd/datasets/MOT/MOT16/images/test'

    # The experiment is named after the folder that contains the weights file.
    # A bare file name has no parent folder (the old split('/')[-2] raised
    # IndexError for it), so fall back to main()'s default name.
    exp_name = osp.basename(osp.dirname(opt.weights)) or 'demo'

    main(opt,
         data_root=data_root,
         seqs=seqs,
         exp_name=exp_name,
         show_image=False,
         save_images=opt.save_images,
         save_videos=opt.save_videos)