From c1b808481200d7997413c6e831bc59c8b08b61b8 Mon Sep 17 00:00:00 2001 From: Alex Bewley Date: Sun, 19 Jul 2020 23:27:59 +0200 Subject: [PATCH] Remove numba dependency and expose hparams flags. --- requirements.txt | 1 - sort.py | 59 +++++++++++++++++++++++++----------------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/requirements.txt b/requirements.txt index 53332b4..5c13119 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ filterpy==1.4.5 -numba==0.49.0 scikit-image==0.14.0 lap==0.4.0 diff --git a/sort.py b/sort.py index 651a7f7..bb80efd 100644 --- a/sort.py +++ b/sort.py @@ -30,12 +30,6 @@ import time import argparse from filterpy.kalman import KalmanFilter -try: - from numba import jit -except: - def jit(func): - return func - np.random.seed(0) @@ -50,21 +44,23 @@ def linear_assignment(cost_matrix): return np.array(list(zip(x, y))) -@jit -def iou(bb_test, bb_gt): - """ - Computes IUO between two bboxes in the form [x1,y1,x2,y2] - """ - xx1 = np.maximum(bb_test[0], bb_gt[0]) - yy1 = np.maximum(bb_test[1], bb_gt[1]) - xx2 = np.minimum(bb_test[2], bb_gt[2]) - yy2 = np.minimum(bb_test[3], bb_gt[3]) - w = np.maximum(0., xx2 - xx1) - h = np.maximum(0., yy2 - yy1) - wh = w * h - o = wh / ((bb_test[2] - bb_test[0]) * (bb_test[3] - bb_test[1]) - + (bb_gt[2] - bb_gt[0]) * (bb_gt[3] - bb_gt[1]) - wh) - return(o) +def iou_batch(bb_test, bb_gt): + """ + From SORT: Computes IUO between two bboxes in the form [l,t,w,h] + """ + bb_gt = np.expand_dims(bb_gt, 0) + bb_test = np.expand_dims(bb_test, 1) + + xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0]) + yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1]) + xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2]) + yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + wh = w * h + o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1]) + + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh) + return(o) def convert_bbox_to_z(bbox): @@ -163,11 +159,8 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3): """ if(len(trackers)==0): return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int) - iou_matrix = np.zeros((len(detections),len(trackers)),dtype=np.float32) - for d,det in enumerate(detections): - for t,trk in enumerate(trackers): - iou_matrix[d,t] = iou(det,trk) + iou_matrix = iou_batch(detections, trackers) if min(iou_matrix.shape) > 0: a = (iou_matrix > iou_threshold).astype(np.int32) @@ -204,12 +197,13 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3): class Sort(object): - def __init__(self, max_age=1, min_hits=3): + def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3): """ Sets key parameters for SORT """ self.max_age = max_age self.min_hits = min_hits + self.iou_threshold = iou_threshold self.trackers = [] self.frame_count = 0 @@ -235,7 +229,7 @@ class Sort(object): trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) for t in reversed(to_del): self.trackers.pop(t) - matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks) + matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks, self.iou_threshold) # update matched trackers with assigned detections for m in matched: @@ -264,6 +258,13 @@ def parse_args(): parser.add_argument('--display', dest='display', help='Display online tracker output (slow) [False]',action='store_true') parser.add_argument("--seq_path", help="Path to detections.", type=str, default='data') parser.add_argument("--phase", help="Subdirectory in seq_path.", type=str, default='train') + parser.add_argument("--max_age", + help="Maximum number of frames to keep alive a track without associated detections.", + type=int, default=1) + parser.add_argument("--min_hits", + help="Minimum number of associated detections before track is initialised.", + type=int, default=3) + parser.add_argument("--iou_threshold", help="Minimum IOU for match.", type=float, default=0.3) args = parser.parse_args() return args @@ -287,7 +288,9 @@ if __name__ == '__main__': os.makedirs('output') pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt') for seq_dets_fn in glob.glob(pattern): - mot_tracker = Sort() #create instance of the SORT tracker + mot_tracker = Sort(max_age=args.max_age, + min_hits=args.min_hits, + iou_threshold=args.iou_threshold) #create instance of the SORT tracker seq_dets = np.loadtxt(seq_dets_fn, delimiter=',') seq = seq_dets_fn[pattern.find('*'):].split('/')[0]