Remove numba dependency and expose hparams flags.
This commit is contained in:
parent
4c0bdb0935
commit
c1b8084812
2 changed files with 31 additions and 29 deletions
|
@ -1,4 +1,3 @@
|
||||||
filterpy==1.4.5
|
filterpy==1.4.5
|
||||||
numba==0.49.0
|
|
||||||
scikit-image==0.14.0
|
scikit-image==0.14.0
|
||||||
lap==0.4.0
|
lap==0.4.0
|
||||||
|
|
47
sort.py
47
sort.py
|
@ -30,12 +30,6 @@ import time
|
||||||
import argparse
|
import argparse
|
||||||
from filterpy.kalman import KalmanFilter
|
from filterpy.kalman import KalmanFilter
|
||||||
|
|
||||||
try:
|
|
||||||
from numba import jit
|
|
||||||
except:
|
|
||||||
def jit(func):
|
|
||||||
return func
|
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,20 +44,22 @@ def linear_assignment(cost_matrix):
|
||||||
return np.array(list(zip(x, y)))
|
return np.array(list(zip(x, y)))
|
||||||
|
|
||||||
|
|
||||||
@jit
|
def iou_batch(bb_test, bb_gt):
|
||||||
def iou(bb_test, bb_gt):
|
|
||||||
"""
|
"""
|
||||||
Computes IUO between two bboxes in the form [x1,y1,x2,y2]
|
From SORT: Computes IUO between two bboxes in the form [l,t,w,h]
|
||||||
"""
|
"""
|
||||||
xx1 = np.maximum(bb_test[0], bb_gt[0])
|
bb_gt = np.expand_dims(bb_gt, 0)
|
||||||
yy1 = np.maximum(bb_test[1], bb_gt[1])
|
bb_test = np.expand_dims(bb_test, 1)
|
||||||
xx2 = np.minimum(bb_test[2], bb_gt[2])
|
|
||||||
yy2 = np.minimum(bb_test[3], bb_gt[3])
|
xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
|
||||||
|
yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
|
||||||
|
xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
|
||||||
|
yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
|
||||||
w = np.maximum(0., xx2 - xx1)
|
w = np.maximum(0., xx2 - xx1)
|
||||||
h = np.maximum(0., yy2 - yy1)
|
h = np.maximum(0., yy2 - yy1)
|
||||||
wh = w * h
|
wh = w * h
|
||||||
o = wh / ((bb_test[2] - bb_test[0]) * (bb_test[3] - bb_test[1])
|
o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
|
||||||
+ (bb_gt[2] - bb_gt[0]) * (bb_gt[3] - bb_gt[1]) - wh)
|
+ (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
|
||||||
return(o)
|
return(o)
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,11 +159,8 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
|
||||||
"""
|
"""
|
||||||
if(len(trackers)==0):
|
if(len(trackers)==0):
|
||||||
return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)
|
return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)
|
||||||
iou_matrix = np.zeros((len(detections),len(trackers)),dtype=np.float32)
|
|
||||||
|
|
||||||
for d,det in enumerate(detections):
|
iou_matrix = iou_batch(detections, trackers)
|
||||||
for t,trk in enumerate(trackers):
|
|
||||||
iou_matrix[d,t] = iou(det,trk)
|
|
||||||
|
|
||||||
if min(iou_matrix.shape) > 0:
|
if min(iou_matrix.shape) > 0:
|
||||||
a = (iou_matrix > iou_threshold).astype(np.int32)
|
a = (iou_matrix > iou_threshold).astype(np.int32)
|
||||||
|
@ -204,12 +197,13 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
|
||||||
|
|
||||||
|
|
||||||
class Sort(object):
|
class Sort(object):
|
||||||
def __init__(self, max_age=1, min_hits=3):
|
def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
|
||||||
"""
|
"""
|
||||||
Sets key parameters for SORT
|
Sets key parameters for SORT
|
||||||
"""
|
"""
|
||||||
self.max_age = max_age
|
self.max_age = max_age
|
||||||
self.min_hits = min_hits
|
self.min_hits = min_hits
|
||||||
|
self.iou_threshold = iou_threshold
|
||||||
self.trackers = []
|
self.trackers = []
|
||||||
self.frame_count = 0
|
self.frame_count = 0
|
||||||
|
|
||||||
|
@ -235,7 +229,7 @@ class Sort(object):
|
||||||
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
|
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
|
||||||
for t in reversed(to_del):
|
for t in reversed(to_del):
|
||||||
self.trackers.pop(t)
|
self.trackers.pop(t)
|
||||||
matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks)
|
matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks, self.iou_threshold)
|
||||||
|
|
||||||
# update matched trackers with assigned detections
|
# update matched trackers with assigned detections
|
||||||
for m in matched:
|
for m in matched:
|
||||||
|
@ -264,6 +258,13 @@ def parse_args():
|
||||||
parser.add_argument('--display', dest='display', help='Display online tracker output (slow) [False]',action='store_true')
|
parser.add_argument('--display', dest='display', help='Display online tracker output (slow) [False]',action='store_true')
|
||||||
parser.add_argument("--seq_path", help="Path to detections.", type=str, default='data')
|
parser.add_argument("--seq_path", help="Path to detections.", type=str, default='data')
|
||||||
parser.add_argument("--phase", help="Subdirectory in seq_path.", type=str, default='train')
|
parser.add_argument("--phase", help="Subdirectory in seq_path.", type=str, default='train')
|
||||||
|
parser.add_argument("--max_age",
|
||||||
|
help="Maximum number of frames to keep alive a track without associated detections.",
|
||||||
|
type=int, default=1)
|
||||||
|
parser.add_argument("--min_hits",
|
||||||
|
help="Minimum number of associated detections before track is initialised.",
|
||||||
|
type=int, default=3)
|
||||||
|
parser.add_argument("--iou_threshold", help="Minimum IOU for match.", type=float, default=0.3)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@ -287,7 +288,9 @@ if __name__ == '__main__':
|
||||||
os.makedirs('output')
|
os.makedirs('output')
|
||||||
pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt')
|
pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt')
|
||||||
for seq_dets_fn in glob.glob(pattern):
|
for seq_dets_fn in glob.glob(pattern):
|
||||||
mot_tracker = Sort() #create instance of the SORT tracker
|
mot_tracker = Sort(max_age=args.max_age,
|
||||||
|
min_hits=args.min_hits,
|
||||||
|
iou_threshold=args.iou_threshold) #create instance of the SORT tracker
|
||||||
seq_dets = np.loadtxt(seq_dets_fn, delimiter=',')
|
seq_dets = np.loadtxt(seq_dets_fn, delimiter=',')
|
||||||
seq = seq_dets_fn[pattern.find('*'):].split('/')[0]
|
seq = seq_dets_fn[pattern.find('*'):].split('/')[0]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue