diff --git a/sort.py b/sort.py index db2f6dc..089ec1b 100644 --- a/sort.py +++ b/sort.py @@ -25,8 +25,7 @@ matplotlib.use('TkAgg') import matplotlib.pyplot as plt import matplotlib.patches as patches from skimage import io -#from sklearn.utils.linear_assignment_ import linear_assignment -import lap + import glob import time import argparse @@ -34,6 +33,19 @@ from filterpy.kalman import KalmanFilter np.random.seed(0) + +def linear_assignment(cost_matrix): + try: + import lap + _, x, y = lap.lapjv(cost_matrix, extend_cost=True) + return np.array([[y[i],i] for i in x if i >= 0]) # + except ImportError: + from scipy.optimize import linear_sum_assignment + x, y = linear_sum_assignment(cost_matrix) + return np.array(list(zip(x, y))) + + + @jit def iou(bb_test, bb_gt): """ @@ -155,14 +167,11 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3): if min(iou_matrix.shape) > 0: a = (iou_matrix > iou_threshold).astype(np.int32) if a.sum(1).max() == 1 and a.sum(0).max() == 1: - #matrix is #TODO(this doesnt provide much gains) matched_indices = np.stack(np.where(a), axis=1) else: - _, x, y = lap.lapjv(-iou_matrix, extend_cost=True) - matched_indices = np.array([[y[i],i] for i in x if i >= 0]) # + matched_indices = linear_assignment(-iou_matrix) else: matched_indices = np.empty(shape=(0,2)) - #matched_indices = linear_assignment(-iou_matrix) unmatched_detections = [] for d, det in enumerate(detections): @@ -271,6 +280,7 @@ if __name__ == '__main__': exit() plt.ion() fig = plt.figure() + ax1 = fig.add_subplot(111, aspect='equal') if not os.path.exists('output'): os.makedirs('output') @@ -290,7 +300,6 @@ if __name__ == '__main__': total_frames += 1 if(display): - ax1 = fig.add_subplot(111, aspect='equal') fn = 'mot_benchmark/%s/%s/img1/%06d.jpg'%(phase, seq, frame) im =io.imread(fn) ax1.imshow(im) @@ -313,7 +322,7 @@ if __name__ == '__main__': plt.draw() ax1.cla() - print("Total Tracking took: %.3f for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time)) + print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time)) if(display): print("Note: to get real runtime results run without the option: --display")