2020-01-09 15:48:17 +01:00
from numba import jit
from collections import deque
import torch
from utils . kalman_filter import KalmanFilter
2020-03-20 12:45:22 +01:00
from utils . log import logger
2020-01-09 15:48:17 +01:00
from models import *
from tracker import matching
from . basetrack import BaseTrack , TrackState
class STrack(BaseTrack):
    """A single tracklet: a Kalman-filtered box plus an appearance embedding.

    Until ``activate()`` seeds the Kalman filter, the box is kept as the raw
    ``tlwh`` it was created with. Afterwards the box is derived from
    ``self.mean[:4]``, which holds ``(center x, center y, aspect ratio, height)``
    as produced by ``tlwh_to_xyah``.
    """

    def __init__(self, tlwh, score, temp_feat, buffer_size=30):
        # Not yet activated: the box stays raw until activate() seeds the KF.
        # np.float64 replaces the np.float alias (== builtin float), removed in NumPy 1.24.
        self._tlwh = np.asarray(tlwh, dtype=np.float64)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0

        # Feature history and smoothing factor must exist BEFORE the first
        # update_features() call, which appends to self.features (and reads
        # self.alpha on later calls). Previously these were assigned after the
        # call, so the first append went to whatever `features` attribute the
        # base class provides (or raised AttributeError).
        self.smooth_feat = None
        self.features = deque([], maxlen=buffer_size)
        self.alpha = 0.9
        self.update_features(temp_feat)

    @staticmethod
    def _l2_normalize(vec):
        """Return `vec` scaled to unit L2 norm; an all-zero vector is returned as a copy."""
        norm = np.linalg.norm(vec)
        # Guard against division by zero, which would poison smooth_feat with NaNs forever.
        return vec / norm if norm > 0 else vec.copy()

    def update_features(self, feat):
        """Record a new appearance embedding and fold it into the smoothed one.

        `feat` is L2-normalized (without modifying the caller's array) and
        stored as ``curr_feat`` and in the ``features`` history. ``smooth_feat``
        becomes the exponential moving average
        ``alpha * smooth_feat + (1 - alpha) * feat``, re-normalized to unit length
        (or just `feat` on the first call).
        """
        feat = self._l2_normalize(np.asarray(feat))
        self.curr_feat = feat
        if self.smooth_feat is None:
            smooth = feat
        else:
            smooth = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        # _l2_normalize returns a new array, so smooth_feat never aliases curr_feat.
        self.smooth_feat = self._l2_normalize(smooth)
        self.features.append(feat)

    def predict(self):
        """Advance this track's Kalman state by one step."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # Zero state index 7 for tracks that are not currently tracked
            # (presumably the height velocity in the KF state -- confirm in KalmanFilter).
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks, kalman_filter):
        """Advance the Kalman state of all `stracks` with one batched predict call.

        Each track's ``mean`` and ``covariance`` are replaced in place.
        Does nothing for an empty list.
        """
        if len(stracks) > 0:
            # np.asarray over a list builds a fresh array, so edits below do not touch st.mean.
            multi_mean = np.asarray([st.mean for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][7] = 0  # same treatment as predict()
            multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet: assign an id and seed the Kalman filter from the stored box."""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # is_activated is deliberately left False here; update() and
        # re_activate() set it once the track is matched again.
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        """Revive a non-tracked track with a matched detection `new_track`.

        Corrects the Kalman state with the detection's box, absorbs its
        embedding and score, and marks the track Tracked and activated.
        If `new_id` is true, a fresh track id is assigned.
        """
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )
        self.update_features(new_track.curr_feat)
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        # Keep the reported score in sync with the matched detection, as update() does.
        self.score = new_track.score
        if new_id:
            self.track_id = self.next_id()

    def update(self, new_track, frame_id, update_feature=True):
        """
        Update a matched track with detection `new_track`.

        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return: None; the track is modified in place.
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score
        if update_feature:
            self.update_features(new_track.curr_feat)

    # The former numba @jit decorators are dropped from the box helpers below:
    # they wrap methods taking `self` / arbitrary array-likes (incl. torch
    # tensors) that numba cannot type in nopython mode, and the object-mode
    # fallback they relied on is gone in current numba.

    @property
    def tlwh(self):
        """Current box as `(top left x, top left y, width, height)`."""
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]          # aspect ratio * height -> width
        ret[:2] -= ret[2:] / 2    # center -> top-left
        return ret

    @property
    def tlbr(self):
        """Current box as `(min x, min y, max x, max y)`, i.e. `(top left, bottom right)`."""
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh):
        """Convert a box to `(center x, center y, aspect ratio, height)`, aspect ratio = width / height."""
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        """Current box in `(center x, center y, aspect ratio, height)` format."""
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    def tlbr_to_tlwh(tlbr):
        """Convert `(x1, y1, x2, y2)` to `(x1, y1, width, height)`; input is not modified."""
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    def tlwh_to_tlbr(tlwh):
        """Convert `(x1, y1, width, height)` to `(x1, y1, x2, y2)`; input is not modified."""
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
class JDETracker(object):
    """Joint Detection and Embedding (JDE) multi-object tracker.

    Each call to ``update()`` runs the network on a frame and associates the
    detections with existing tracklets in stages: fused embedding + motion
    cost, then IoU for leftover tracked tracks, then IoU for unconfirmed
    tracks. Remaining confident detections start new tracks.
    """

    def __init__(self, opt, frame_rate=30):
        self.opt = opt
        # NOTE(review): 14455 is presumably the identity count the embedding
        # head of the checkpoint was trained with -- confirm against the weights.
        self.model = Darknet(opt.cfg, nID=14455)
        self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
        self.model.cuda().eval()

        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.det_thresh = opt.conf_thres
        # opt.track_buffer is expressed at 30 FPS; scale it to the actual frame rate.
        self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
        self.max_time_lost = self.buffer_size

        self.kalman_filter = KalmanFilter()

    def _detect(self, im_blob, img0):
        """Run the network on `im_blob`; return this frame's detections as STracks.

        Boxes are rescaled from the network input size to ``img0.shape``.
        Returns an empty list when nothing survives thresholding / NMS.
        """
        with torch.no_grad():
            pred = self.model(im_blob)
        # Keep only proposals whose column-4 confidence exceeds conf_thres.
        pred = pred[pred[:, :, 4] > self.opt.conf_thres]
        if len(pred) == 0:
            return []

        dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0]
        if dets is None:
            # NOTE(review): defensive -- guards against NMS signalling "no boxes" with None.
            return []
        dets = dets.cpu()

        # scale_coords is relied on to rescale dets[:, :4] in place; its return
        # value was never used (the former `.round()` on it was discarded).
        scale_coords(self.opt.img_size, dets[:, :4], img0.shape)

        # Columns 0-3: box (x1, y1, x2, y2); 4: score; 6 onward: embedding.
        return [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30)
                for (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]

    def update(self, im_blob, img0):
        """Process one frame and return the confirmed, currently tracked STracks.

        Detects objects, associates them with the existing tracklets and
        maintains the tracked / lost / removed track lists.

        Parameters
        ----------
        im_blob : torch.Tensor
            Network input for this frame (the original notes a default shape of
            [1, 3, 608, 1088]; it depends on the configured input size).
        img0 : ndarray
            The original frame; its ``shape`` is used to rescale boxes.

        Returns
        -------
        output_stracks : list of STrack
            Tracks in ``self.tracked_stracks`` whose ``is_activated`` is set.
        """
        self.frame_id += 1
        activated_stracks = []  # tracks updated as Tracked in this frame
        refind_stracks = []     # previously lost tracks matched again in this frame
        lost_stracks = []       # tracked tracks that found no match in this frame
        removed_stracks = []    # tracks dropped in this frame

        # Step 1: network forward, get detections & embeddings.
        detections = self._detect(im_blob, img0)

        # Split current tracks into confirmed (is_activated) and unconfirmed ones.
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if track.is_activated:
                tracked_stracks.append(track)
            else:
                unconfirmed.append(track)

        # Step 2: first association -- embedding distance fused with KF motion,
        # over confirmed tracked tracks plus lost tracks.
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        STrack.multi_predict(strack_pool, self.kalman_filter)
        dists = matching.embedding_distance(strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                # A lost track was matched: bring it back.
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        # Step 3: second association by IoU, between still-Tracked tracks left
        # unmatched in step 2 and the detections left unmatched in step 2.
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [strack_pool[i] for i in u_track
                             if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
        # Tracked tracks still without a detection become lost.
        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        # Unconfirmed tracks (usually seen in only one frame) get one IoU match
        # attempt against the remaining detections; otherwise they are removed.
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_stracks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        # Step 4: start new tracks from the remaining sufficiently confident detections.
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)

        # Step 5: drop tracks that have been lost for more than max_time_lost frames.
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # Rebuild the persistent track lists from this frame's outcome.
        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        # Record this frame's removals BEFORE pruning the lost pool. Previously
        # the pruning used last frame's list, so a track removed in step 5 stayed
        # in lost_stracks for one more frame: it could be re-matched in step 2
        # and would be appended to removed_stracks a second time.
        self.removed_stracks.extend(removed_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)

        output_stracks = [track for track in self.tracked_stracks if track.is_activated]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format([track.track_id for track in activated_stracks]))
        logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))

        return output_stracks
def joint_stracks(tlista, tlistb):
    """Concatenate two track lists, skipping `tlistb` entries whose track_id is already present.

    Every element of `tlista` is kept, in order. Elements of `tlistb` are
    appended in order unless their track_id appeared earlier in either list.
    """
    merged = list(tlista)
    seen_ids = {strack.track_id for strack in merged}
    for strack in tlistb:
        if strack.track_id in seen_ids:
            continue
        seen_ids.add(strack.track_id)
        merged.append(strack)
    return merged
def sub_stracks(tlista, tlistb):
    """Return the tracks of `tlista` whose track_id does not occur in `tlistb`.

    Tracks are keyed by track_id: if `tlista` repeats an id, only the last
    such track is kept, at the position of the first occurrence. The order
    of `tlista` is otherwise preserved.
    """
    stracks = {t.track_id: t for t in tlista}
    for t in tlistb:
        # Remove by id membership. The previous `if stracks.get(tid, 0)` tested
        # the stored track object's truthiness rather than whether the id was present.
        stracks.pop(t.track_id, None)
    return list(stracks.values())
def remove_duplicate_stracks(stracksa, stracksb):
    """Drop near-duplicate tracks that appear in both lists.

    A pair (p from `stracksa`, q from `stracksb`) whose
    ``matching.iou_distance`` is below 0.15 is treated as the same object.
    Of the two, the one that has existed for fewer frames
    (``frame_id - start_frame``) is dropped; on a tie the one from
    `stracksa` is dropped.

    Returns:
        (resa, resb): the filtered lists, original order preserved.
    """
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    # Sets give O(1) membership in the filters below (lists were O(k) per test).
    dupa, dupb = set(), set()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.add(q)
        else:
            dupa.add(p)
    resa = [t for i, t in enumerate(stracksa) if i not in dupa]
    resb = [t for i, t in enumerate(stracksb) if i not in dupb]
    return resa, resb