Add scipy fallback for linear assignment.

This commit is contained in:
Alex Bewley 2020-01-17 00:19:36 +01:00
parent 6e093f5be1
commit 83b7714602

25
sort.py
View file

@ -25,8 +25,7 @@ matplotlib.use('TkAgg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.patches as patches import matplotlib.patches as patches
from skimage import io from skimage import io
#from sklearn.utils.linear_assignment_ import linear_assignment
import lap
import glob import glob
import time import time
import argparse import argparse
@ -34,6 +33,19 @@ from filterpy.kalman import KalmanFilter
np.random.seed(0) np.random.seed(0)
def linear_assignment(cost_matrix):
try:
import lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
return np.array([[y[i],i] for i in x if i >= 0]) #
except ImportError:
from scipy.optimize import linear_sum_assignment
x, y = linear_sum_assignment(cost_matrix)
return np.array(list(zip(x, y)))
@jit @jit
def iou(bb_test, bb_gt): def iou(bb_test, bb_gt):
""" """
@ -155,14 +167,11 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
if min(iou_matrix.shape) > 0: if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32) a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1: if a.sum(1).max() == 1 and a.sum(0).max() == 1:
#matrix is #TODO(this doesnt provide much gains)
matched_indices = np.stack(np.where(a), axis=1) matched_indices = np.stack(np.where(a), axis=1)
else: else:
_, x, y = lap.lapjv(-iou_matrix, extend_cost=True) matched_indices = linear_assignment(-iou_matrix)
matched_indices = np.array([[y[i],i] for i in x if i >= 0]) #
else: else:
matched_indices = np.empty(shape=(0,2)) matched_indices = np.empty(shape=(0,2))
#matched_indices = linear_assignment(-iou_matrix)
unmatched_detections = [] unmatched_detections = []
for d, det in enumerate(detections): for d, det in enumerate(detections):
@ -271,6 +280,7 @@ if __name__ == '__main__':
exit() exit()
plt.ion() plt.ion()
fig = plt.figure() fig = plt.figure()
ax1 = fig.add_subplot(111, aspect='equal')
if not os.path.exists('output'): if not os.path.exists('output'):
os.makedirs('output') os.makedirs('output')
@ -290,7 +300,6 @@ if __name__ == '__main__':
total_frames += 1 total_frames += 1
if(display): if(display):
ax1 = fig.add_subplot(111, aspect='equal')
fn = 'mot_benchmark/%s/%s/img1/%06d.jpg'%(phase, seq, frame) fn = 'mot_benchmark/%s/%s/img1/%06d.jpg'%(phase, seq, frame)
im =io.imread(fn) im =io.imread(fn)
ax1.imshow(im) ax1.imshow(im)
@ -313,7 +322,7 @@ if __name__ == '__main__':
plt.draw() plt.draw()
ax1.cla() ax1.cla()
print("Total Tracking took: %.3f for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time)) print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time))
if(display): if(display):
print("Note: to get real runtime results run without the option: --display") print("Note: to get real runtime results run without the option: --display")