24f351d1b5
* Added documentation * Added docstrings and comments * Removed unused imports * Removed unused imports * Added functionality of saving checkpoints during training process * Update train.py * Update multitracker.py
216 lines
8.5 KiB
Python
216 lines
8.5 KiB
Python
import os
|
|
import os.path as osp
|
|
import cv2
|
|
import logging
|
|
import argparse
|
|
import motmetrics as mm
|
|
|
|
import torch
|
|
from tracker.multitracker import JDETracker
|
|
from utils import visualization as vis
|
|
from utils.log import logger
|
|
from utils.timer import Timer
|
|
from utils.evaluation import Evaluator
|
|
from utils.parse_config import parse_model_cfg
|
|
import utils.datasets as datasets
|
|
from utils.utils import *
|
|
|
|
|
|
def write_results(filename, results, data_type):
    '''
    Dump per-frame tracking results to a text file, one line per tracked box.

    Parameters
    ----------
    filename : String
        Path of the output file; it is opened in write mode.

    results : iterable of (frame_id, tlwhs, track_ids)
        One entry per frame. `tlwhs` and `track_ids` are paired element-wise,
        and every tlwh unpacks to (x1, y1, w, h).

    data_type : String
        Either 'mot' or 'kitti'; selects the line layout.

    Raises
    ------
    ValueError
        If `data_type` is neither 'mot' nor 'kitti'.
    '''
    if data_type == 'mot':
        line_template = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        line_template = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    # For 'kitti' output the incoming frame numbers are shifted down by one.
    shift_frames = data_type == 'kitti'

    with open(filename, 'w') as out:
        for frame_no, boxes, ids in results:
            if shift_frames:
                frame_no -= 1
            for box, tid in zip(boxes, ids):
                # Entries with a negative id are not written.
                if tid < 0:
                    continue
                left, top, width, height = box
                out.write(line_template.format(
                    frame=frame_no,
                    id=tid,
                    x1=left,
                    y1=top,
                    x2=left + width,
                    y2=top + height,
                    w=width,
                    h=height,
                ))
    logger.info('save results to {}'.format(filename))
|
|
|
|
|
|
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
    '''
    Run the JDE tracker over an image sequence and write the tracking results to a text file.

    Every frame from `dataloader` is passed to a `JDETracker`; the boxes it reports are
    filtered by area and aspect ratio, collected per frame and finally written with
    `write_results`. Optionally the annotated frames are displayed and/or saved.

    Parameters
    ----------
    opt : Namespace
        Contains information passed as commandline arguments. `opt.min_box_area` is read
        here; the whole object is forwarded to `JDETracker`.

    dataloader : iterable
        Yields `(path, img, img0)` per frame. `img` is converted with `torch.from_numpy`
        and fed to the tracker; `img0` is passed to the tracker and used for drawing.

    data_type : String
        Output format forwarded to `write_results` ('mot' or 'kitti').

    result_filename : String
        The name(path) of the file for storing results.

    save_dir : String
        Path to the folder for storing the frames containing bounding box information
        (result frames). Nothing is saved when it is None.

    show_image : bool
        Option for showing individual frames during run-time.

    frame_rate : int
        Frame-rate of the given sequence, forwarded to `JDETracker`.

    Returns
    -------
    frame_id : int
        Number of frames processed.

    average_time : float
        `timer.average_time` after the last frame (presumably seconds per frame).

    calls : int
        `timer.calls` after the last frame.
    '''

    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results = []
    frame_id = 0
    for path, img, img0 in dataloader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))

        # run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)
        online_targets = tracker.update(blob, img0)
        online_tlwhs = []
        online_ids = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            # tlwh is (x, y, w, h): drop boxes that are much wider than tall
            # or whose area does not exceed the configured minimum.
            too_wide = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > opt.min_box_area and not too_wide:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
        timer.toc()

        # save results (stored frame numbers are 1-based)
        results.append((frame_id + 1, online_tlwhs, online_ids))
        if show_image or save_dir is not None:
            # Same guard as the log line above so a zero average time cannot divide by zero.
            online_im = vis.plot_tracking(img0, online_tlwhs, online_ids, frame_id=frame_id,
                                          fps=1. / max(1e-5, timer.average_time))
        if show_image:
            cv2.imshow('online_im', online_im)
            # imshow only repaints the window once the GUI event loop runs.
            cv2.waitKey(1)
        if save_dir is not None:
            cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
        frame_id += 1

    # save results
    write_results(result_filename, results, data_type)
    return frame_id, timer.average_time, timer.calls
|
|
|
|
|
|
def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',), exp_name='demo',
         save_images=False, save_videos=False, show_image=True):
    '''
    Track and evaluate every sequence in `seqs`, then print and save a metrics summary.

    Parameters
    ----------
    opt : Namespace
        Commandline options. `opt.cfg` is read, and `opt.img_size` is set here from the
        'width' and 'height' of the first entry of the parsed model config.

    data_root : String
        Folder containing one sub-folder per sequence (with `img1/` and `seqinfo.ini`).
        Results go to `<data_root>/../results/<exp_name>` and optional frames/videos to
        `<data_root>/../outputs/<exp_name>/<seq>`.

    det_root : String
        Not used by this function.

    seqs : sequence of String
        Names of the sequences to process.

    exp_name : String
        Experiment name used in the output paths and in the summary file name.

    save_images : bool
        Save the annotated frames of every sequence.

    save_videos : bool
        Save the annotated frames and additionally assemble them into an mp4 with ffmpeg.

    show_image : bool
        Forwarded to `eval_seq`.
    '''
    logger.setLevel(logging.INFO)
    result_root = os.path.join(data_root, '..', 'results', exp_name)
    mkdir_if_missing(result_root)
    data_type = 'mot'

    # Read config
    cfg_dict = parse_model_cfg(opt.cfg)
    opt.img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]

    # run tracking
    accs = []
    timer_avgs, timer_calls = [], []
    for seq in seqs:
        output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None

        logger.info('start seq: {}'.format(seq))
        dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        # Close the metadata file deterministically instead of leaking the handle.
        with open(os.path.join(data_root, seq, 'seqinfo.ini')) as f:
            meta_info = f.read()
        # The frame rate is the text between 'frameRate=' and the following 'seqLength' line.
        frame_rate = int(meta_info[meta_info.find('frameRate')+10:meta_info.find('\nseqLength')])
        _, ta, tc = eval_seq(opt, dataloader, data_type, result_filename,
                             save_dir=output_dir, show_image=show_image, frame_rate=frame_rate)
        timer_avgs.append(ta)
        timer_calls.append(tc)

        # eval
        logger.info('Evaluate seq: {}'.format(seq))
        evaluator = Evaluator(data_root, seq, data_type)
        accs.append(evaluator.eval_file(result_filename))
        if save_videos:
            output_video_path = osp.join(output_dir, '{}.mp4'.format(seq))
            cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -c:v copy {}'.format(output_dir, output_video_path)
            os.system(cmd_str)

    # Total time is the per-sequence average time weighted by its number of timer calls.
    timer_avgs = np.asarray(timer_avgs)
    timer_calls = np.asarray(timer_calls)
    all_time = np.dot(timer_avgs, timer_calls)
    avg_time = all_time / np.sum(timer_calls)
    logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(all_time, 1.0 / avg_time))

    # get summary
    metrics = mm.metrics.motchallenge_metrics
    mh = mm.metrics.create()
    summary = Evaluator.get_summary(accs, seqs, metrics)
    strsummary = mm.io.render_summary(
        summary,
        formatters=mh.formatters,
        namemap=mm.io.motchallenge_metric_names
    )
    print(strsummary)
    Evaluator.save_summary(summary, os.path.join(result_root, 'summary_{}.xlsx'.format(exp_name)))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the options, pick the benchmark sequences and run main().
    parser = argparse.ArgumentParser(prog='track.py')
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
    parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
    parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
    parser.add_argument('--min-box-area', type=float, default=200, help='filter out tiny boxes')
    parser.add_argument('--track-buffer', type=int, default=30, help='tracking buffer')
    parser.add_argument('--test-mot16', action='store_true',
                        help='run on the MOT16 test sequences instead of the MOT17 train sequences')
    parser.add_argument('--save-images', action='store_true', help='save tracking results (image)')
    parser.add_argument('--save-videos', action='store_true', help='save tracking results (video)')
    opt = parser.parse_args()
    print(opt, end='\n\n')

    if not opt.test_mot16:
        seqs = [
            'MOT17-02-SDP',
            'MOT17-04-SDP',
            'MOT17-05-SDP',
            'MOT17-09-SDP',
            'MOT17-10-SDP',
            'MOT17-11-SDP',
            'MOT17-13-SDP',
        ]
        data_root = '/home/wangzd/datasets/MOT/MOT17/images/train'
    else:
        seqs = [
            'MOT16-01',
            'MOT16-03',
            'MOT16-06',
            'MOT16-07',
            'MOT16-08',
            'MOT16-12',
            'MOT16-14',
        ]
        data_root = '/home/wangzd/datasets/MOT/MOT16/images/test'

    # The experiment name is the folder that directly contains the weights file,
    # so --weights must include at least one directory component.
    main(opt,
         data_root=data_root,
         seqs=seqs,
         exp_name=opt.weights.split('/')[-2],
         show_image=False,
         save_images=opt.save_images,
         save_videos=opt.save_videos)
|
|
|