Towards-Realtime-MOT/tracker/detector.py

182 lines
5.9 KiB
Python
Raw Normal View History

2019-09-27 05:37:47 +00:00
import numpy as np
from numba import jit
from collections import deque
import itertools
import os
import os.path as osp
import time
import torch
from lib.utils.log import logger
from lib.tracker import matching
from lib.utils.kalman_filter import KalmanFilter
from lib.model.faster_rcnn.resnet import resnet_deploy
from lib.model.utils.config import cfg
from lib.model.rpn.bbox_transform import clip_boxes, bbox_transform_inv
from lib.model.nms.nms_wrapper import nms
from .basetrack import BaseTrack, TrackState
class STrack(BaseTrack):
    """One detection / tracklet stored as a ``(x, y, w, h)`` bounding box.

    The box is kept in "tlwh" form (top-left x, top-left y, width, height)
    together with a detection score and a feature (``temp_feat``).
    Bookkeeping attributes such as ``track_id`` and ``state`` are assigned
    by :meth:`activate`, :meth:`re_activate` and :meth:`update`.
    """

    def __init__(self, tlwh, score, temp_feat):
        """Create a track that is waiting to be activated.

        :param tlwh: box as ``(top left x, top left y, width, height)``
        :param score: score attached to this detection
        :param temp_feat: feature stored alongside the box
        """
        # ``np.float`` was a plain alias of the builtin ``float`` and has been
        # removed from NumPy; ``np.float64`` yields the same dtype.
        self._tlwh = np.asarray(tlwh, dtype=np.float64)
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0
        self.temp_feat = temp_feat

    def activate(self, frame_id):
        """Start a new tracklet.

        NOTE: ``is_activated`` is intentionally left untouched here; it is
        only set to ``True`` by :meth:`update` and :meth:`re_activate`.

        :param frame_id: frame index at which the tracklet starts
        """
        # next_id() is not defined in this class -- presumably inherited
        # from BaseTrack.
        self.track_id = self.next_id()

        self.time_since_update = 0
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        """Resume this track from ``new_track``'s box and feature.

        :param new_track: track whose ``tlwh`` and ``temp_feat`` are adopted
        :param frame_id: frame index of the re-activation
        :param new_id: when true, a fresh ``track_id`` is drawn
        """
        self._tlwh = new_track.tlwh
        self.temp_feat = new_track.temp_feat

        self.time_since_update = 0
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()

    def update(self, new_track, frame_id, update_feature=True):
        """
        Update a matched track

        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.time_since_update = 0
        self.tracklet_len += 1

        self._tlwh = new_track.tlwh
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score
        if update_feature:
            self.temp_feat = new_track.temp_feat

    # The accessors and converters below are plain NumPy on four-element
    # arrays. They are deliberately not wrapped in numba ``@jit``: a ``self``
    # argument cannot be typed in nopython mode, and the results are the same
    # without compilation.

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.

        A copy is returned, so callers may modify the result freely.
        """
        return self._tlwh.copy()

    @property
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh  # already a private copy
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlbr_to_tlwh(tlbr):
        """Convert `(min x, min y, max x, max y)` to `(x, y, width, height)`.

        :param tlbr: array-like with four entries
        :return: new array; the input is not modified
        """
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    def tlwh_to_tlbr(tlwh):
        """Convert `(x, y, width, height)` to `(min x, min y, max x, max y)`.

        :param tlwh: array-like with four entries
        :return: new array; the input is not modified
        """
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        # end_frame is not defined in this class -- presumably provided by
        # BaseTrack.
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
class JDTracker(object):
    """Faster R-CNN based detector that returns detections as ``STrack`` objects.

    The network weights are read from
    ``models/res101/mot17det/faster_rcnn_<session>_<epoch>_<checkpoint>.pth``
    and the model is moved to the GPU, so CUDA is required.
    """

    def __init__(self, checksession=3, checkepoch=24, checkpoint=663, det_thresh=0.92, frame_rate=30):
        """Build the network and load its weights.

        :param checksession: session number used in the checkpoint file name
        :param checkepoch: epoch number used in the checkpoint file name
        :param checkpoint: checkpoint number used in the checkpoint file name
        :param det_thresh: minimum class-1 ('pedestrian') score for a box to be kept
        :param frame_rate: video frame rate; scales ``cfg.TRACKING_BUFFER_SIZE``
            relative to 30 fps
        :raises FileNotFoundError: if the model directory does not exist
        """
        self.classes = np.asarray(['__background__', 'pedestrian'])
        self.fasterRCNN = resnet_deploy(self.classes, 101, pretrained=False, class_agnostic=False)
        self.fasterRCNN.create_architecture()

        input_dir = osp.join('models', 'res101', 'mot17det')
        if not os.path.exists(input_dir):
            # FileNotFoundError is still caught by ``except Exception``.
            raise FileNotFoundError('There is no input directory for loading network from ' + input_dir)
        load_name = os.path.join(input_dir,
                                 'faster_rcnn_{}_{}_{}.pth'.format(checksession, checkepoch, checkpoint))
        print("load checkpoint %s" % (load_name))
        # Load onto the CPU first so that a checkpoint saved from another GPU
        # index still deserialises; the model is moved to the GPU below.
        # A separate local name keeps the ``checkpoint`` argument intact.
        state = torch.load(load_name, map_location='cpu')
        self.fasterRCNN.load_state_dict(state['model'], strict=False)
        print('load model successfully!')

        self.fasterRCNN.cuda()
        self.fasterRCNN.eval()

        self.frame_id = 0
        self.det_thresh = det_thresh
        self.buffer_size = int(frame_rate / 30.0 * cfg.TRACKING_BUFFER_SIZE)
        self.max_time_lost = self.buffer_size

    def update(self, im_blob):
        """Run detection on one frame and wrap every box in an ``STrack``.

        :param im_blob: image tensor; presumably ``(1, H, W, C)`` since
            :meth:`predict` permutes it with ``(0, 3, 1, 2)`` -- TODO confirm
        :return: list of ``STrack`` detections (empty if nothing passes the
            score threshold)
        """
        self.frame_id += 1

        # Forward
        im_blob = im_blob.cuda()
        # Built from dims 1 and 2 of the blob plus a constant 1
        # (presumably height, width, scale).
        im_info = torch.Tensor([[im_blob.shape[1], im_blob.shape[2], 1]]).float().cuda()
        self.im_info = im_info
        boxes, temp_feat, _ = self.predict(im_blob, im_info)

        # Detections: each row of ``boxes`` is four box coordinates followed
        # by the score.
        detections = [
            STrack(STrack.tlbr_to_tlwh((x1, y1, x2, y2)), score, feat)
            for (x1, y1, x2, y2, score), feat in zip(boxes, temp_feat)
        ]
        return detections

    def predict(self, im_blob, im_info):
        """Run the network and post-process its output.

        Boxes whose class-1 score exceeds ``self.det_thresh`` are refined,
        sorted by descending score and filtered with NMS.

        :param im_blob: image tensor, permuted here with ``(0, 3, 1, 2)``
        :param im_info: image info tensor passed through to the network
        :return: ``(output_box, temp_feat, base_feat)``. ``output_box`` is a
            NumPy array with one ``[box (4 values), score]`` row per kept
            detection and ``temp_feat`` holds the matching pooled features;
            both are empty lists when no box passes the threshold.
        """
        im_blob = im_blob.permute(0, 3, 1, 2)

        # Trivial input: the network call takes ground-truth arguments, so
        # all-zero placeholders are supplied.
        gt_boxes = torch.zeros(1, 1, 6).to(im_blob)
        num_boxes = gt_boxes[:, :, 0].squeeze()

        with torch.no_grad():
            rois, cls_prob, bbox_pred, base_feat = self.fasterRCNN(im_blob, im_info, gt_boxes, num_boxes)

        scores = cls_prob.data
        confident = torch.nonzero(scores[0, :, 1] > self.det_thresh).view(-1)
        if confident.numel() == 0:
            return [], [], base_feat

        rois = rois[:, confident]
        scores = scores[:, confident]
        bbox_pred = bbox_pred[:, confident]

        refined_rois = self.fasterRCNN.bbox_refine(rois, bbox_pred, im_info)
        template_feat = self.fasterRCNN.roi_pool(base_feat, refined_rois)
        # Columns 1..4 of each refined roi are used as the box coordinates.
        pred_boxes = refined_rois.data[:, :, 1:5]

        cls_scores = scores[0, :, 1]
        _, order = torch.sort(cls_scores, 0, True)  # descending by score

        cls_dets = torch.cat((pred_boxes[0], cls_scores.unsqueeze(1)), 1)[order]
        temp_feat = template_feat[order]

        kept = nms(cls_dets, cfg.TEST.NMS, force_cpu=not cfg.USE_GPU_NMS).view(-1).long()
        cls_dets = cls_dets[kept]
        temp_feat = temp_feat[kept]

        output_box = cls_dets.cpu().numpy()
        return output_box, temp_feat, base_feat