can save video output now
This commit is contained in:
parent
d287deeca0
commit
acc0a09d0f
7 changed files with 21 additions and 821 deletions
|
@ -9,8 +9,6 @@ from utils.syncbn import SyncBN
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
|
|
||||||
ONNX_EXPORT = False
|
|
||||||
|
|
||||||
batch_norm=SyncBN #nn.BatchNorm2d
|
batch_norm=SyncBN #nn.BatchNorm2d
|
||||||
|
|
||||||
def create_modules(module_defs):
|
def create_modules(module_defs):
|
||||||
|
@ -184,9 +182,10 @@ class YOLOLayer(nn.Module):
|
||||||
lid = self.IDLoss(logits, tids.squeeze())
|
lid = self.IDLoss(logits, tids.squeeze())
|
||||||
|
|
||||||
# Sum loss components
|
# Sum loss components
|
||||||
loss = torch.exp(-self.s_r)*lbox + torch.exp(-self.s_c)*lconf + torch.exp(-self.s_id)*lid + \
|
#loss = torch.exp(-self.s_r)*lbox + torch.exp(-self.s_c)*lconf + torch.exp(-self.s_id)*lid + \
|
||||||
(self.s_r + self.s_c + self.s_id)
|
# (self.s_r + self.s_c + self.s_id)
|
||||||
loss *= 0.5
|
#loss *= 0.5
|
||||||
|
loss = lbox + lconf + lid
|
||||||
|
|
||||||
return loss, loss.item(), lbox.item(), lconf.item(), lid.item(), nT
|
return loss, loss.item(), lbox.item(), lconf.item(), lid.item(), nT
|
||||||
|
|
||||||
|
|
25
track.py
25
track.py
|
@ -5,7 +5,7 @@ import logging
|
||||||
import argparse
|
import argparse
|
||||||
import motmetrics as mm
|
import motmetrics as mm
|
||||||
|
|
||||||
from tracker.mot_tracker_kalman import AETracker
|
from tracker.multitracker import JDETracker
|
||||||
from utils import visualization as vis
|
from utils import visualization as vis
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
from utils.timer import Timer
|
from utils.timer import Timer
|
||||||
|
@ -44,7 +44,7 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
||||||
if save_dir is not None:
|
if save_dir is not None:
|
||||||
mkdirs(save_dir)
|
mkdirs(save_dir)
|
||||||
|
|
||||||
tracker = AETracker(opt, frame_rate=frame_rate)
|
tracker = JDETracker(opt, frame_rate=frame_rate)
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
results = []
|
results = []
|
||||||
frame_id = 0
|
frame_id = 0
|
||||||
|
@ -81,8 +81,8 @@ def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_im
|
||||||
return frame_id
|
return frame_id
|
||||||
|
|
||||||
|
|
||||||
def main(opt, data_root='/data/MOT16/train', det_root=None,
|
def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',), exp_name='demo',
|
||||||
seqs=('MOT16-05',), exp_name='demo', save_image=False, show_image=True):
|
save_images=False, save_videos=False, show_image=True):
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
result_root = os.path.join(data_root, '..', 'results', exp_name)
|
result_root = os.path.join(data_root, '..', 'results', exp_name)
|
||||||
mkdirs(result_root)
|
mkdirs(result_root)
|
||||||
|
@ -94,7 +94,7 @@ def main(opt, data_root='/data/MOT16/train', det_root=None,
|
||||||
n_frame = 0
|
n_frame = 0
|
||||||
timer.tic()
|
timer.tic()
|
||||||
for seq in seqs:
|
for seq in seqs:
|
||||||
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_image else None
|
output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
|
||||||
|
|
||||||
logger.info('start seq: {}'.format(seq))
|
logger.info('start seq: {}'.format(seq))
|
||||||
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
|
dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
|
||||||
|
@ -108,6 +108,10 @@ def main(opt, data_root='/data/MOT16/train', det_root=None,
|
||||||
logger.info('Evaluate seq: {}'.format(seq))
|
logger.info('Evaluate seq: {}'.format(seq))
|
||||||
evaluator = Evaluator(data_root, seq, data_type)
|
evaluator = Evaluator(data_root, seq, data_type)
|
||||||
accs.append(evaluator.eval_file(result_filename))
|
accs.append(evaluator.eval_file(result_filename))
|
||||||
|
if save_videos:
|
||||||
|
output_video_path = osp.join(output_dir, '{}.mp4'.format(seq))
|
||||||
|
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -c:v copy {}'.format(output_dir, output_video_path)
|
||||||
|
os.system(cmd_str)
|
||||||
timer.toc()
|
timer.toc()
|
||||||
logger.info('Time elapsed: {}, FPS {}'.format(timer.average_time, n_frame / timer.average_time))
|
logger.info('Time elapsed: {}, FPS {}'.format(timer.average_time, n_frame / timer.average_time))
|
||||||
|
|
||||||
|
@ -131,7 +135,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--batch-size', type=int, default=8, help='size of each image batch')
|
parser.add_argument('--batch-size', type=int, default=8, help='size of each image batch')
|
||||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
||||||
parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
|
parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
|
||||||
parser.add_argument('--img-size', type=int, default=(864,480), help='size of each image dimension')
|
parser.add_argument('--img-size', type=int, default=(1088, 608), help='size of each image dimension')
|
||||||
parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
|
parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
|
||||||
parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
|
parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
|
||||||
parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
|
parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
|
||||||
|
@ -139,7 +143,8 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--pixel-mean', type=float, default=[0,0,0], nargs='+', help='pixel mean')
|
parser.add_argument('--pixel-mean', type=float, default=[0,0,0], nargs='+', help='pixel mean')
|
||||||
parser.add_argument('--track-buffer', type=int, default=30, help='tracking buffer')
|
parser.add_argument('--track-buffer', type=int, default=30, help='tracking buffer')
|
||||||
parser.add_argument('--test-mot16', action='store_true', help='tracking buffer')
|
parser.add_argument('--test-mot16', action='store_true', help='tracking buffer')
|
||||||
parser.add_argument('--save-images', action='store_true', help='save tracking results')
|
parser.add_argument('--save-images', action='store_true', help='save tracking results (image)')
|
||||||
|
parser.add_argument('--save-videos', action='store_true', help='save tracking results (video)')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt, end='\n\n')
|
print(opt, end='\n\n')
|
||||||
|
|
||||||
|
@ -157,14 +162,14 @@ if __name__ == '__main__':
|
||||||
MOT16-08
|
MOT16-08
|
||||||
MOT16-12
|
MOT16-12
|
||||||
MOT16-14'''
|
MOT16-14'''
|
||||||
#seqs_str = 'MOT16-14'
|
|
||||||
data_root = '/home/wangzd/datasets/MOT/MOT16/test'
|
data_root = '/home/wangzd/datasets/MOT/MOT16/test'
|
||||||
seqs = [seq.strip() for seq in seqs_str.split()]
|
seqs = [seq.strip() for seq in seqs_str.split()]
|
||||||
|
|
||||||
main(opt,
|
main(opt,
|
||||||
data_root=data_root,
|
data_root=data_root,
|
||||||
seqs=seqs,
|
seqs=seqs,
|
||||||
exp_name='darknet53.864x480',
|
exp_name='darknet53',
|
||||||
show_image=False,
|
show_image=False,
|
||||||
save_image=opt.save_images)
|
save_images=opt.save_images,
|
||||||
|
save_videos=opt.save_videos)
|
||||||
|
|
||||||
|
|
|
@ -1,181 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from numba import jit
|
|
||||||
from collections import deque
|
|
||||||
import itertools
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lib.utils.log import logger
|
|
||||||
from lib.tracker import matching
|
|
||||||
from lib.utils.kalman_filter import KalmanFilter
|
|
||||||
from lib.model.faster_rcnn.resnet import resnet_deploy
|
|
||||||
from lib.model.utils.config import cfg
|
|
||||||
from lib.model.rpn.bbox_transform import clip_boxes, bbox_transform_inv
|
|
||||||
from lib.model.nms.nms_wrapper import nms
|
|
||||||
|
|
||||||
from .basetrack import BaseTrack, TrackState
|
|
||||||
|
|
||||||
|
|
||||||
class STrack(BaseTrack):
|
|
||||||
|
|
||||||
def __init__(self, tlwh, score, temp_feat):
|
|
||||||
|
|
||||||
# wait activate
|
|
||||||
self._tlwh = np.asarray(tlwh, dtype=np.float)
|
|
||||||
self.is_activated = False
|
|
||||||
self.score = score
|
|
||||||
self.tracklet_len = 0
|
|
||||||
self.temp_feat = temp_feat
|
|
||||||
|
|
||||||
def activate(self, frame_id):
|
|
||||||
"""Start a new tracklet"""
|
|
||||||
self.track_id = self.next_id()
|
|
||||||
self.time_since_update = 0
|
|
||||||
self.tracklet_len = 0
|
|
||||||
self.state = TrackState.Tracked
|
|
||||||
#self.is_activated = True
|
|
||||||
self.frame_id = frame_id
|
|
||||||
self.start_frame = frame_id
|
|
||||||
|
|
||||||
def re_activate(self, new_track, frame_id, new_id=False):
|
|
||||||
self._tlwh = new_track.tlwh
|
|
||||||
self.temp_feat = new_track.temp_feat
|
|
||||||
self.time_since_update = 0
|
|
||||||
self.tracklet_len = 0
|
|
||||||
self.state = TrackState.Tracked
|
|
||||||
self.is_activated = True
|
|
||||||
self.frame_id = frame_id
|
|
||||||
if new_id:
|
|
||||||
self.track_id = self.next_id()
|
|
||||||
|
|
||||||
def update(self, new_track, frame_id, update_feature=True):
|
|
||||||
"""
|
|
||||||
Update a matched track
|
|
||||||
:type new_track: STrack
|
|
||||||
:type frame_id: int
|
|
||||||
:type update_feature: bool
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
self.frame_id = frame_id
|
|
||||||
self.time_since_update = 0
|
|
||||||
self.tracklet_len += 1
|
|
||||||
|
|
||||||
self._tlwh = new_track.tlwh
|
|
||||||
self.state = TrackState.Tracked
|
|
||||||
self.is_activated = True
|
|
||||||
|
|
||||||
self.score = new_track.score
|
|
||||||
if update_feature:
|
|
||||||
self.temp_feat = new_track.temp_feat
|
|
||||||
|
|
||||||
@property
|
|
||||||
@jit
|
|
||||||
def tlwh(self):
|
|
||||||
"""Get current position in bounding box format `(top left x, top left y,
|
|
||||||
width, height)`.
|
|
||||||
"""
|
|
||||||
return self._tlwh.copy()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@jit
|
|
||||||
def tlbr(self):
|
|
||||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
|
||||||
`(top left, bottom right)`.
|
|
||||||
"""
|
|
||||||
ret = self.tlwh.copy()
|
|
||||||
ret[2:] += ret[:2]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@jit
|
|
||||||
def tlbr_to_tlwh(tlbr):
|
|
||||||
ret = np.asarray(tlbr).copy()
|
|
||||||
ret[2:] -= ret[:2]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@jit
|
|
||||||
def tlwh_to_tlbr(tlwh):
|
|
||||||
ret = np.asarray(tlwh).copy()
|
|
||||||
ret[2:] += ret[:2]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
|
|
||||||
|
|
||||||
|
|
||||||
class JDTracker(object):
|
|
||||||
def __init__(self, checksession=3, checkepoch=24, checkpoint=663, det_thresh=0.92, frame_rate=30):
|
|
||||||
self.classes = np.asarray(['__background__', 'pedestrian'])
|
|
||||||
|
|
||||||
self.fasterRCNN = resnet_deploy(self.classes, 101, pretrained=False, class_agnostic=False)
|
|
||||||
self.fasterRCNN.create_architecture()
|
|
||||||
|
|
||||||
input_dir = osp.join('models', 'res101', 'mot17det')
|
|
||||||
if not os.path.exists(input_dir):
|
|
||||||
raise Exception('There is no input directory for loading network from ' + input_dir)
|
|
||||||
load_name = os.path.join(input_dir,
|
|
||||||
'faster_rcnn_{}_{}_{}.pth'.format(checksession, checkepoch, checkpoint))
|
|
||||||
print("load checkpoint %s" % (load_name))
|
|
||||||
checkpoint = torch.load(load_name)
|
|
||||||
self.fasterRCNN.load_state_dict(checkpoint['model'], strict=False)
|
|
||||||
print('load model successfully!')
|
|
||||||
self.fasterRCNN.cuda()
|
|
||||||
self.fasterRCNN.eval()
|
|
||||||
|
|
||||||
self.frame_id = 0
|
|
||||||
self.det_thresh = det_thresh
|
|
||||||
self.buffer_size = int(frame_rate / 30.0 * cfg.TRACKING_BUFFER_SIZE)
|
|
||||||
self.max_time_lost = self.buffer_size
|
|
||||||
#self.fmap_buffer = deque([], maxlen=self.buffer_size)
|
|
||||||
|
|
||||||
def update(self, im_blob):
|
|
||||||
self.frame_id += 1
|
|
||||||
|
|
||||||
'''Forward'''
|
|
||||||
im_blob = im_blob.cuda()
|
|
||||||
im_info = torch.Tensor([[im_blob.shape[1], im_blob.shape[2], 1, ],]).float().cuda()
|
|
||||||
self.im_info = im_info
|
|
||||||
boxes, temp_feat, base_feat = self.predict(im_blob, im_info)
|
|
||||||
|
|
||||||
'''Detections'''
|
|
||||||
detections = [STrack(STrack.tlbr_to_tlwh((t, l, b, r)), s, f) for (t, l, b, r, s), f in zip(boxes, temp_feat)]
|
|
||||||
|
|
||||||
|
|
||||||
return detections
|
|
||||||
|
|
||||||
def predict(self, im_blob, im_info):
|
|
||||||
im_blob = im_blob.permute(0,3,1,2)
|
|
||||||
# Trivial input
|
|
||||||
gt_boxes = torch.zeros(1, 1, 6).to(im_blob)
|
|
||||||
num_boxes = gt_boxes[:, :, 0].squeeze()
|
|
||||||
with torch.no_grad():
|
|
||||||
rois, cls_prob, bbox_pred, base_feat = self.fasterRCNN(im_blob, im_info, gt_boxes, num_boxes)
|
|
||||||
scores = cls_prob.data
|
|
||||||
inds_first = torch.nonzero(scores[0, :, 1] > self.det_thresh).view(-1)
|
|
||||||
if inds_first.numel() > 0:
|
|
||||||
rois = rois[:, inds_first]
|
|
||||||
scores = scores[:,inds_first]
|
|
||||||
bbox_pred = bbox_pred[:, inds_first]
|
|
||||||
|
|
||||||
refined_rois = self.fasterRCNN.bbox_refine(rois, bbox_pred, im_info)
|
|
||||||
template_feat = self.fasterRCNN.roi_pool(base_feat, refined_rois)
|
|
||||||
pred_boxes = refined_rois.data[:, :, 1:5]
|
|
||||||
|
|
||||||
cls_scores = scores[0, :, 1]
|
|
||||||
_, order = torch.sort(cls_scores, 0, True)
|
|
||||||
cls_boxes = pred_boxes[0]
|
|
||||||
cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
|
|
||||||
cls_dets = cls_dets[order]
|
|
||||||
temp_feat = template_feat[order]
|
|
||||||
keep_first = nms(cls_dets, cfg.TEST.NMS, force_cpu=not cfg.USE_GPU_NMS).view(-1).long()
|
|
||||||
cls_dets = cls_dets[keep_first]
|
|
||||||
temp_feat = temp_feat[keep_first]
|
|
||||||
output_box = cls_dets.cpu().numpy()
|
|
||||||
else:
|
|
||||||
output_box = []
|
|
||||||
temp_feat = []
|
|
||||||
return output_box, temp_feat, base_feat
|
|
||||||
|
|
|
@ -25,8 +25,6 @@ def merge_matches(m1, m2, shape):
|
||||||
return match, unmatched_O, unmatched_Q
|
return match, unmatched_O, unmatched_Q
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _indices_to_matches(cost_matrix, indices, thresh):
|
def _indices_to_matches(cost_matrix, indices, thresh):
|
||||||
matched_cost = cost_matrix[tuple(zip(*indices))]
|
matched_cost = cost_matrix[tuple(zip(*indices))]
|
||||||
matched_mask = (matched_cost <= thresh)
|
matched_mask = (matched_cost <= thresh)
|
||||||
|
@ -94,23 +92,6 @@ def iou_distance(atracks, btracks):
|
||||||
|
|
||||||
return cost_matrix
|
return cost_matrix
|
||||||
|
|
||||||
#def embedding_distance(tracks, detections, metric='cosine'):
|
|
||||||
# """
|
|
||||||
# :param tracks: list[STrack]
|
|
||||||
# :param detections: list[BaseTrack]
|
|
||||||
# :param metric:
|
|
||||||
# :return: cost_matrix np.ndarray
|
|
||||||
# """
|
|
||||||
#
|
|
||||||
# cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float)
|
|
||||||
# if cost_matrix.size == 0:
|
|
||||||
# return cost_matrix
|
|
||||||
# det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float)
|
|
||||||
# for i, track in enumerate(tracks):
|
|
||||||
# #cost_matrix[i, :] = np.maximum(0.0, cdist(track.features, det_features, metric).min(axis=0))
|
|
||||||
# cost_matrix[i, :] = np.maximum(0.0, cdist(track.features, det_features, metric).min(axis=0))
|
|
||||||
# return cost_matrix
|
|
||||||
|
|
||||||
def embedding_distance(tracks, detections, metric='cosine'):
|
def embedding_distance(tracks, detections, metric='cosine'):
|
||||||
"""
|
"""
|
||||||
:param tracks: list[STrack]
|
:param tracks: list[STrack]
|
||||||
|
|
|
@ -1,473 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from numba import jit
|
|
||||||
from collections import deque
|
|
||||||
import itertools
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from utils.utils import *
|
|
||||||
from utils.log import logger
|
|
||||||
from models import *
|
|
||||||
from tracker import matching
|
|
||||||
from .basetrack import BaseTrack, TrackState
|
|
||||||
|
|
||||||
|
|
||||||
class STrack(BaseTrack):
|
|
||||||
|
|
||||||
def __init__(self, tlwh, score, temp_feat, buffer_size=30):
|
|
||||||
|
|
||||||
# wait activate
|
|
||||||
self._tlwh = np.asarray(tlwh, dtype=np.float)
|
|
||||||
self.is_activated = False
|
|
||||||
self.score = score
|
|
||||||
self.tracklet_len = 0
|
|
||||||
self.smooth_feat = None
|
|
||||||
self.update_features(temp_feat)
|
|
||||||
self.features = deque([], maxlen=buffer_size)
|
|
||||||
|
|
||||||
def update_features(self, feat):
|
|
||||||
print(1)
|
|
||||||
self.curr_feat = feat
|
|
||||||
if self.smooth_feat is None:
|
|
||||||
self.smooth_feat = feat
|
|
||||||
else:
|
|
||||||
self.smooth_feat = 0.9 *self.smooth_feat + 0.1 * feat
|
|
||||||
self.features.append(temp_feat)
|
|
||||||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
|
||||||
|
|
||||||
|
|
||||||
def activate(self, frame_id):
|
|
||||||
"""Start a new tracklet"""
|
|
||||||
self.track_id = self.next_id()
|
|
||||||
self.time_since_update = 0
|
|
||||||
self.tracklet_len = 0
|
|
||||||
self.state = TrackState.Tracked
|
|
||||||
#self.is_activated = True
|
|
||||||
self.frame_id = frame_id
|
|
||||||
self.start_frame = frame_id
|
|
||||||
|
|
||||||
def re_activate(self, new_track, frame_id, new_id=False):
|
|
||||||
self._tlwh = new_track.tlwh
|
|
||||||
#self.features.append(new_track.curr_feat)
|
|
||||||
self.update_features(new_track.curr_feat)
|
|
||||||
self.time_since_update = 0
|
|
||||||
self.tracklet_len = 0
|
|
||||||
self.state = TrackState.Tracked
|
|
||||||
self.is_activated = True
|
|
||||||
self.frame_id = frame_id
|
|
||||||
if new_id:
|
|
||||||
self.track_id = self.next_id()
|
|
||||||
|
|
||||||
def update(self, new_track, frame_id, update_feature=True):
|
|
||||||
"""
|
|
||||||
Update a matched track
|
|
||||||
:type new_track: STrack
|
|
||||||
:type frame_id: int
|
|
||||||
:type update_feature: bool
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
self.frame_id = frame_id
|
|
||||||
self.time_since_update = 0
|
|
||||||
self.tracklet_len += 1
|
|
||||||
|
|
||||||
self._tlwh = new_track.tlwh
|
|
||||||
self.state = TrackState.Tracked
|
|
||||||
self.is_activated = True
|
|
||||||
|
|
||||||
self.score = new_track.score
|
|
||||||
if update_feature:
|
|
||||||
#self.features.append( new_track.curr_feat)
|
|
||||||
self.update_features(new_track.curr_feat)
|
|
||||||
|
|
||||||
@property
|
|
||||||
@jit
|
|
||||||
def tlwh(self):
|
|
||||||
"""Get current position in bounding box format `(top left x, top left y,
|
|
||||||
width, height)`.
|
|
||||||
"""
|
|
||||||
return self._tlwh.copy()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@jit
|
|
||||||
def tlbr(self):
|
|
||||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
|
||||||
`(top left, bottom right)`.
|
|
||||||
"""
|
|
||||||
ret = self.tlwh.copy()
|
|
||||||
ret[2:] += ret[:2]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@jit
|
|
||||||
def tlbr_to_tlwh(tlbr):
|
|
||||||
ret = np.asarray(tlbr).copy()
|
|
||||||
ret[2:] -= ret[:2]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@jit
|
|
||||||
def tlwh_to_tlbr(tlwh):
|
|
||||||
ret = np.asarray(tlwh).copy()
|
|
||||||
ret[2:] += ret[:2]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
|
|
||||||
|
|
||||||
class IOUTracker(object):
|
|
||||||
def __init__(self, opt, frame_rate=30):
|
|
||||||
self.opt = opt
|
|
||||||
self.model = Darknet(opt.cfg, opt.img_size, nID=14455)
|
|
||||||
#load_darknet_weights(self.model, opt.weights)
|
|
||||||
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'])
|
|
||||||
self.model.cuda().eval()
|
|
||||||
|
|
||||||
self.tracked_stracks = [] # type: list[STrack]
|
|
||||||
self.lost_stracks = [] # type: list[STrack]
|
|
||||||
self.removed_stracks = [] # type: list[STrack]
|
|
||||||
|
|
||||||
self.frame_id = 0
|
|
||||||
self.det_thresh = opt.conf_thres
|
|
||||||
self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
|
|
||||||
self.max_time_lost = self.buffer_size
|
|
||||||
#self.fmap_buffer = deque([], maxlen=self.buffer_size)
|
|
||||||
|
|
||||||
def update(self, im_blob, img0):
|
|
||||||
self.frame_id += 1
|
|
||||||
activated_starcks = []
|
|
||||||
refind_stracks = []
|
|
||||||
lost_stracks = []
|
|
||||||
removed_stracks = []
|
|
||||||
|
|
||||||
t1 = time.time()
|
|
||||||
'''Forward'''
|
|
||||||
with torch.no_grad():
|
|
||||||
pred = self.model(im_blob)
|
|
||||||
pred = pred[pred[:, :, 4] > self.opt.conf_thres]
|
|
||||||
if len(pred) > 0:
|
|
||||||
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0]
|
|
||||||
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
|
|
||||||
'''Detections'''
|
|
||||||
detections = [STrack(STrack.tlbr_to_tlwh((t, l, b, r)), s, None) for (t, l, b, r, s) in dets[:, :5]]
|
|
||||||
else:
|
|
||||||
detections = []
|
|
||||||
|
|
||||||
t2 = time.time()
|
|
||||||
#print('Forward: {} s'.format(t2-t1))
|
|
||||||
|
|
||||||
|
|
||||||
'''matching for tracked targets'''
|
|
||||||
unconfirmed = []
|
|
||||||
tracked_stracks = [] # type: list[STrack]
|
|
||||||
for track in self.tracked_stracks:
|
|
||||||
if not track.is_activated:
|
|
||||||
unconfirmed.append(track)
|
|
||||||
else:
|
|
||||||
tracked_stracks.append(track)
|
|
||||||
|
|
||||||
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
|
|
||||||
#dists = self.track_matching(strack_pool, detections, base_feat)
|
|
||||||
dists = matching.iou_distance(strack_pool, detections)
|
|
||||||
#dists[np.where(iou_dists>0.4)] = 1.0
|
|
||||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
|
|
||||||
for itracked, idet in matches:
|
|
||||||
track = strack_pool[itracked]
|
|
||||||
det = detections[idet]
|
|
||||||
if track.state == TrackState.Tracked:
|
|
||||||
track.update(detections[idet], self.frame_id)
|
|
||||||
activated_starcks.append(track)
|
|
||||||
else:
|
|
||||||
track.re_activate(det, self.frame_id, new_id=False)
|
|
||||||
refind_stracks.append(track)
|
|
||||||
t3 = time.time()
|
|
||||||
#print('First match {} s'.format(t3-t2))
|
|
||||||
|
|
||||||
#'''Remained det/track, use IOU between dets and tracks to associate directly'''
|
|
||||||
#detections = [detections[i] for i in u_detection]
|
|
||||||
#r_tracked_stracks = [strack_pool[i] for i in u_track ]
|
|
||||||
#dists = matching.iou_distance(r_tracked_stracks, detections)
|
|
||||||
#matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
#for itracked, idet in matches:
|
|
||||||
# r_tracked_stracks[itracked].update(detections[idet], self.frame_id)
|
|
||||||
for it in u_track:
|
|
||||||
track = strack_pool[it]
|
|
||||||
if not track.state == TrackState.Lost:
|
|
||||||
track.mark_lost()
|
|
||||||
lost_stracks.append(track)
|
|
||||||
|
|
||||||
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
|
|
||||||
detections = [detections[i] for i in u_detection]
|
|
||||||
dists = matching.iou_distance(unconfirmed, detections)
|
|
||||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
for itracked, idet in matches:
|
|
||||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
|
||||||
for it in u_unconfirmed:
|
|
||||||
track = unconfirmed[it]
|
|
||||||
track.mark_removed()
|
|
||||||
removed_stracks.append(track)
|
|
||||||
|
|
||||||
"""step 4: init new stracks"""
|
|
||||||
for inew in u_detection:
|
|
||||||
track = detections[inew]
|
|
||||||
if track.score < self.det_thresh:
|
|
||||||
continue
|
|
||||||
track.activate(self.frame_id)
|
|
||||||
activated_starcks.append(track)
|
|
||||||
|
|
||||||
"""step 6: update state"""
|
|
||||||
for track in self.lost_stracks:
|
|
||||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
|
||||||
track.mark_removed()
|
|
||||||
removed_stracks.append(track)
|
|
||||||
t4 = time.time()
|
|
||||||
#print('Ramained match {} s'.format(t4-t3))
|
|
||||||
|
|
||||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
|
||||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
|
|
||||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
|
|
||||||
#self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
|
|
||||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
|
|
||||||
self.lost_stracks.extend(lost_stracks)
|
|
||||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
|
|
||||||
self.removed_stracks.extend(removed_stracks)
|
|
||||||
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
|
|
||||||
|
|
||||||
# get scores of lost tracks
|
|
||||||
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
|
|
||||||
|
|
||||||
logger.debug('===========Frame {}=========='.format(self.frame_id))
|
|
||||||
logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
|
|
||||||
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
|
|
||||||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
|
||||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
|
||||||
t5 = time.time()
|
|
||||||
#print('Final {} s'.format(t5-t4))
|
|
||||||
return output_stracks
|
|
||||||
|
|
||||||
|
|
||||||
class AETracker(object):
|
|
||||||
def __init__(self, opt, frame_rate=30):
|
|
||||||
self.opt = opt
|
|
||||||
self.model = Darknet(opt.cfg, opt.img_size, nID=14455)
|
|
||||||
# load_darknet_weights(self.model, opt.weights)
|
|
||||||
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'])
|
|
||||||
self.model.cuda().eval()
|
|
||||||
|
|
||||||
self.tracked_stracks = [] # type: list[STrack]
|
|
||||||
self.lost_stracks = [] # type: list[STrack]
|
|
||||||
self.removed_stracks = [] # type: list[STrack]
|
|
||||||
|
|
||||||
self.frame_id = 0
|
|
||||||
self.det_thresh = opt.conf_thres
|
|
||||||
self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
|
|
||||||
self.max_time_lost = self.buffer_size
|
|
||||||
|
|
||||||
def update(self, im_blob, img0):
|
|
||||||
self.frame_id += 1
|
|
||||||
activated_starcks = []
|
|
||||||
refind_stracks = []
|
|
||||||
lost_stracks = []
|
|
||||||
removed_stracks = []
|
|
||||||
|
|
||||||
t1 = time.time()
|
|
||||||
'''Forward'''
|
|
||||||
with torch.no_grad():
|
|
||||||
pred = self.model(im_blob)
|
|
||||||
pred = pred[pred[:, :, 4] > self.opt.conf_thres]
|
|
||||||
if len(pred) > 0:
|
|
||||||
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()
|
|
||||||
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
|
|
||||||
'''Detections'''
|
|
||||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
|
|
||||||
(tlbrs, f) in zip(dets[:, :5], dets[:, -self.model.emb_dim:])]
|
|
||||||
else:
|
|
||||||
detections = []
|
|
||||||
|
|
||||||
t2 = time.time()
|
|
||||||
# print('Forward: {} s'.format(t2-t1))
|
|
||||||
|
|
||||||
'''matching for tracked targets'''
|
|
||||||
unconfirmed = []
|
|
||||||
tracked_stracks = [] # type: list[STrack]
|
|
||||||
for track in self.tracked_stracks:
|
|
||||||
if not track.is_activated:
|
|
||||||
unconfirmed.append(track)
|
|
||||||
else:
|
|
||||||
tracked_stracks.append(track)
|
|
||||||
|
|
||||||
|
|
||||||
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
|
|
||||||
#strack_pool = tracked_stracks
|
|
||||||
dists = matching.embedding_distance(strack_pool, detections)
|
|
||||||
iou_dists = matching.iou_distance(strack_pool, detections)
|
|
||||||
dists[np.where(iou_dists>0.99)] = 1.0
|
|
||||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
|
|
||||||
for itracked, idet in matches:
|
|
||||||
track = strack_pool[itracked]
|
|
||||||
det = detections[idet]
|
|
||||||
if track.state == TrackState.Tracked:
|
|
||||||
track.update(detections[idet], self.frame_id)
|
|
||||||
activated_starcks.append(track)
|
|
||||||
else:
|
|
||||||
track.re_activate(det, self.frame_id, new_id=False)
|
|
||||||
refind_stracks.append(track)
|
|
||||||
|
|
||||||
|
|
||||||
# detections = [detections[i] for i in u_detection]
|
|
||||||
# dists = matching.embedding_distance(self.lost_stracks, detections)
|
|
||||||
# iou_dists = matching.iou_distance(self.lost_stracks, detections)
|
|
||||||
# dists[np.where(iou_dists>0.7)] = 1.0
|
|
||||||
#
|
|
||||||
# matches, u_track_lost, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
#
|
|
||||||
# for itracked, idet in matches:
|
|
||||||
# track = self.lost_stracks[itracked]
|
|
||||||
# det = detections[idet]
|
|
||||||
# if track.state == TrackState.Tracked:
|
|
||||||
# track.update(detections[idet], self.frame_id)
|
|
||||||
# activated_starcks.append(track)
|
|
||||||
# else:
|
|
||||||
# track.re_activate(det, self.frame_id, new_id=False)
|
|
||||||
# refind_stracks.append(track)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
'''Remained det/track, use IOU between dets and tracks to associate directly'''
|
|
||||||
detections = [detections[i] for i in u_detection]
|
|
||||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state==TrackState.Tracked ]
|
|
||||||
r_lost_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state!=TrackState.Tracked ]
|
|
||||||
dists = matching.iou_distance(r_tracked_stracks, detections)
|
|
||||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
|
|
||||||
|
|
||||||
for itracked, idet in matches:
|
|
||||||
track = r_tracked_stracks[itracked]
|
|
||||||
det = detections[idet]
|
|
||||||
if track.state == TrackState.Tracked:
|
|
||||||
track.update(det, self.frame_id)
|
|
||||||
activated_starcks.append(track)
|
|
||||||
else:
|
|
||||||
track.re_activate(det, self.frame_id, new_id=False)
|
|
||||||
refind_stracks.append(track)
|
|
||||||
|
|
||||||
for it in u_track:
|
|
||||||
track = r_tracked_stracks[it]
|
|
||||||
if not track.state == TrackState.Lost:
|
|
||||||
track.mark_lost()
|
|
||||||
lost_stracks.append(track)
|
|
||||||
|
|
||||||
# '''Remained det/track, use IOU between dets and tracks to associate directly'''
|
|
||||||
# detections = [detections[i] for i in u_detection]
|
|
||||||
# dists = matching.iou_distance(r_lost_stracks, detections)
|
|
||||||
# matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.25)
|
|
||||||
#
|
|
||||||
# for itracked, idet in matches:
|
|
||||||
# track = r_lost_stracks[itracked]
|
|
||||||
# det = detections[idet]
|
|
||||||
# if track.state == TrackState.Tracked:
|
|
||||||
# track.update(det, self.frame_id)
|
|
||||||
# activated_starcks.append(track)
|
|
||||||
# else:
|
|
||||||
# track.re_activate(det, self.frame_id, new_id=False)
|
|
||||||
# refind_stracks.append(track)
|
|
||||||
#
|
|
||||||
# for it in u_track:
|
|
||||||
# track = r_lost_stracks[it]
|
|
||||||
# if not track.state == TrackState.Lost:
|
|
||||||
# track.mark_lost()
|
|
||||||
# lost_stracks.append(track)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
|
|
||||||
detections = [detections[i] for i in u_detection]
|
|
||||||
dists = matching.iou_distance(unconfirmed, detections)
|
|
||||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
for itracked, idet in matches:
|
|
||||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
|
||||||
activated_starcks.append(unconfirmed[itracked])
|
|
||||||
for it in u_unconfirmed:
|
|
||||||
track = unconfirmed[it]
|
|
||||||
track.mark_removed()
|
|
||||||
removed_stracks.append(track)
|
|
||||||
|
|
||||||
"""step 4: init new stracks"""
|
|
||||||
for inew in u_detection:
|
|
||||||
track = detections[inew]
|
|
||||||
if track.score < self.det_thresh:
|
|
||||||
continue
|
|
||||||
track.activate(self.frame_id)
|
|
||||||
activated_starcks.append(track)
|
|
||||||
|
|
||||||
"""step 6: update state"""
|
|
||||||
for track in self.lost_stracks:
|
|
||||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
|
||||||
track.mark_removed()
|
|
||||||
removed_stracks.append(track)
|
|
||||||
t4 = time.time()
|
|
||||||
# print('Ramained match {} s'.format(t4-t3))
|
|
||||||
|
|
||||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
|
||||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
|
|
||||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
|
|
||||||
# self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
|
|
||||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
|
|
||||||
self.lost_stracks.extend(lost_stracks)
|
|
||||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
|
|
||||||
self.removed_stracks.extend(removed_stracks)
|
|
||||||
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
|
|
||||||
|
|
||||||
# get scores of lost tracks
|
|
||||||
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
|
|
||||||
|
|
||||||
logger.debug('===========Frame {}=========='.format(self.frame_id))
|
|
||||||
logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
|
|
||||||
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
|
|
||||||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
|
||||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
|
||||||
t5 = time.time()
|
|
||||||
# print('Final {} s'.format(t5-t4))
|
|
||||||
return output_stracks
|
|
||||||
|
|
||||||
def joint_stracks(tlista, tlistb):
|
|
||||||
exists = {}
|
|
||||||
res = []
|
|
||||||
for t in tlista:
|
|
||||||
exists[t.track_id] = 1
|
|
||||||
res.append(t)
|
|
||||||
for t in tlistb:
|
|
||||||
tid = t.track_id
|
|
||||||
if not exists.get(tid, 0):
|
|
||||||
exists[tid] = 1
|
|
||||||
res.append(t)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def sub_stracks(tlista, tlistb):
|
|
||||||
stracks = {}
|
|
||||||
for t in tlista:
|
|
||||||
stracks[t.track_id] = t
|
|
||||||
for t in tlistb:
|
|
||||||
tid = t.track_id
|
|
||||||
if stracks.get(tid, 0):
|
|
||||||
del stracks[tid]
|
|
||||||
return list(stracks.values())
|
|
||||||
|
|
||||||
def remove_duplicate_stracks(stracksa, stracksb):
|
|
||||||
pdist = matching.iou_distance(stracksa, stracksb)
|
|
||||||
pairs = np.where(pdist<0.15)
|
|
||||||
dupa, dupb = list(), list()
|
|
||||||
for p,q in zip(*pairs):
|
|
||||||
timep = stracksa[p].frame_id - stracksa[p].start_frame
|
|
||||||
timeq = stracksb[q].frame_id - stracksb[q].start_frame
|
|
||||||
if timep > timeq:
|
|
||||||
dupb.append(q)
|
|
||||||
else:
|
|
||||||
dupa.append(p)
|
|
||||||
resa = [t for i,t in enumerate(stracksa) if not i in dupa]
|
|
||||||
resb = [t for i,t in enumerate(stracksb) if not i in dupb]
|
|
||||||
return resa, resb
|
|
||||||
|
|
||||||
|
|
|
@ -149,144 +149,13 @@ class STrack(BaseTrack):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
|
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
|
||||||
|
|
||||||
class IOUTracker(object):
|
|
||||||
def __init__(self, opt, frame_rate=30):
|
|
||||||
self.opt = opt
|
|
||||||
self.model = Darknet(opt.cfg, opt.img_size, nID=14455)
|
|
||||||
#load_darknet_weights(self.model, opt.weights)
|
|
||||||
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'])
|
|
||||||
self.model.cuda().eval()
|
|
||||||
|
|
||||||
self.tracked_stracks = [] # type: list[STrack]
|
|
||||||
self.lost_stracks = [] # type: list[STrack]
|
|
||||||
self.removed_stracks = [] # type: list[STrack]
|
|
||||||
|
|
||||||
self.frame_id = 0
|
class JDETracker(object):
|
||||||
self.det_thresh = opt.conf_thres
|
|
||||||
self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
|
|
||||||
self.max_time_lost = self.buffer_size
|
|
||||||
#self.fmap_buffer = deque([], maxlen=self.buffer_size)
|
|
||||||
|
|
||||||
def update(self, im_blob, img0):
|
|
||||||
self.frame_id += 1
|
|
||||||
activated_starcks = []
|
|
||||||
refind_stracks = []
|
|
||||||
lost_stracks = []
|
|
||||||
removed_stracks = []
|
|
||||||
|
|
||||||
t1 = time.time()
|
|
||||||
'''Forward'''
|
|
||||||
with torch.no_grad():
|
|
||||||
pred = self.model(im_blob)
|
|
||||||
pred = pred[pred[:, :, 4] > self.opt.conf_thres]
|
|
||||||
if len(pred) > 0:
|
|
||||||
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0]
|
|
||||||
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
|
|
||||||
'''Detections'''
|
|
||||||
detections = [STrack(STrack.tlbr_to_tlwh((t, l, b, r)), s, None) for (t, l, b, r, s) in dets[:, :5]]
|
|
||||||
else:
|
|
||||||
detections = []
|
|
||||||
|
|
||||||
t2 = time.time()
|
|
||||||
#print('Forward: {} s'.format(t2-t1))
|
|
||||||
|
|
||||||
|
|
||||||
'''matching for tracked targets'''
|
|
||||||
unconfirmed = []
|
|
||||||
tracked_stracks = [] # type: list[STrack]
|
|
||||||
for track in self.tracked_stracks:
|
|
||||||
if not track.is_activated:
|
|
||||||
unconfirmed.append(track)
|
|
||||||
else:
|
|
||||||
tracked_stracks.append(track)
|
|
||||||
|
|
||||||
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
|
|
||||||
#dists = self.track_matching(strack_pool, detections, base_feat)
|
|
||||||
dists = matching.iou_distance(strack_pool, detections)
|
|
||||||
#dists[np.where(iou_dists>0.4)] = 1.0
|
|
||||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
|
|
||||||
for itracked, idet in matches:
|
|
||||||
track = strack_pool[itracked]
|
|
||||||
det = detections[idet]
|
|
||||||
if track.state == TrackState.Tracked:
|
|
||||||
track.update(detections[idet], self.frame_id)
|
|
||||||
activated_starcks.append(track)
|
|
||||||
else:
|
|
||||||
track.re_activate(det, self.frame_id, new_id=False)
|
|
||||||
refind_stracks.append(track)
|
|
||||||
t3 = time.time()
|
|
||||||
#print('First match {} s'.format(t3-t2))
|
|
||||||
|
|
||||||
#'''Remained det/track, use IOU between dets and tracks to associate directly'''
|
|
||||||
#detections = [detections[i] for i in u_detection]
|
|
||||||
#r_tracked_stracks = [strack_pool[i] for i in u_track ]
|
|
||||||
#dists = matching.iou_distance(r_tracked_stracks, detections)
|
|
||||||
#matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
#for itracked, idet in matches:
|
|
||||||
# r_tracked_stracks[itracked].update(detections[idet], self.frame_id)
|
|
||||||
for it in u_track:
|
|
||||||
track = strack_pool[it]
|
|
||||||
if not track.state == TrackState.Lost:
|
|
||||||
track.mark_lost()
|
|
||||||
lost_stracks.append(track)
|
|
||||||
|
|
||||||
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
|
|
||||||
detections = [detections[i] for i in u_detection]
|
|
||||||
dists = matching.iou_distance(unconfirmed, detections)
|
|
||||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
|
||||||
for itracked, idet in matches:
|
|
||||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
|
||||||
for it in u_unconfirmed:
|
|
||||||
track = unconfirmed[it]
|
|
||||||
track.mark_removed()
|
|
||||||
removed_stracks.append(track)
|
|
||||||
|
|
||||||
"""step 4: init new stracks"""
|
|
||||||
for inew in u_detection:
|
|
||||||
track = detections[inew]
|
|
||||||
if track.score < self.det_thresh:
|
|
||||||
continue
|
|
||||||
track.activate(self.kalman_filter, self.frame_id)
|
|
||||||
activated_starcks.append(track)
|
|
||||||
|
|
||||||
"""step 6: update state"""
|
|
||||||
for track in self.lost_stracks:
|
|
||||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
|
||||||
track.mark_removed()
|
|
||||||
removed_stracks.append(track)
|
|
||||||
t4 = time.time()
|
|
||||||
#print('Ramained match {} s'.format(t4-t3))
|
|
||||||
|
|
||||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
|
||||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
|
|
||||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
|
|
||||||
#self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
|
|
||||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
|
|
||||||
self.lost_stracks.extend(lost_stracks)
|
|
||||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
|
|
||||||
self.removed_stracks.extend(removed_stracks)
|
|
||||||
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
|
|
||||||
|
|
||||||
# get scores of lost tracks
|
|
||||||
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
|
|
||||||
|
|
||||||
logger.debug('===========Frame {}=========='.format(self.frame_id))
|
|
||||||
logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
|
|
||||||
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
|
|
||||||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
|
||||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
|
||||||
t5 = time.time()
|
|
||||||
#print('Final {} s'.format(t5-t4))
|
|
||||||
return output_stracks
|
|
||||||
|
|
||||||
|
|
||||||
class AETracker(object):
|
|
||||||
def __init__(self, opt, frame_rate=30):
|
def __init__(self, opt, frame_rate=30):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.model = Darknet(opt.cfg, opt.img_size, nID=14455)
|
self.model = Darknet(opt.cfg, opt.img_size, nID=14455)
|
||||||
# load_darknet_weights(self.model, opt.weights)
|
# load_darknet_weights(self.model, opt.weights)
|
||||||
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'])
|
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
|
||||||
self.model.cuda().eval()
|
self.model.cuda().eval()
|
||||||
|
|
||||||
self.tracked_stracks = [] # type: list[STrack]
|
self.tracked_stracks = [] # type: list[STrack]
|
0
utils/datasets.py
Executable file → Normal file
0
utils/datasets.py
Executable file → Normal file
Loading…
Reference in a new issue