Add scipy fallback for linear assignment.
This commit is contained in:
parent
6e093f5be1
commit
83b7714602
1 changed files with 17 additions and 8 deletions
25
sort.py
25
sort.py
|
@ -25,8 +25,7 @@ matplotlib.use('TkAgg')
|
|||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
from skimage import io
|
||||
#from sklearn.utils.linear_assignment_ import linear_assignment
|
||||
import lap
|
||||
|
||||
import glob
|
||||
import time
|
||||
import argparse
|
||||
|
@ -34,6 +33,19 @@ from filterpy.kalman import KalmanFilter
|
|||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix):
|
||||
try:
|
||||
import lap
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
|
||||
return np.array([[y[i],i] for i in x if i >= 0]) #
|
||||
except ImportError:
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
x, y = linear_sum_assignment(cost_matrix)
|
||||
return np.array(list(zip(x, y)))
|
||||
|
||||
|
||||
|
||||
@jit
|
||||
def iou(bb_test, bb_gt):
|
||||
"""
|
||||
|
@ -155,14 +167,11 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
|
|||
if min(iou_matrix.shape) > 0:
|
||||
a = (iou_matrix > iou_threshold).astype(np.int32)
|
||||
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
|
||||
#matrix is #TODO(this doesnt provide much gains)
|
||||
matched_indices = np.stack(np.where(a), axis=1)
|
||||
else:
|
||||
_, x, y = lap.lapjv(-iou_matrix, extend_cost=True)
|
||||
matched_indices = np.array([[y[i],i] for i in x if i >= 0]) #
|
||||
matched_indices = linear_assignment(-iou_matrix)
|
||||
else:
|
||||
matched_indices = np.empty(shape=(0,2))
|
||||
#matched_indices = linear_assignment(-iou_matrix)
|
||||
|
||||
unmatched_detections = []
|
||||
for d, det in enumerate(detections):
|
||||
|
@ -271,6 +280,7 @@ if __name__ == '__main__':
|
|||
exit()
|
||||
plt.ion()
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(111, aspect='equal')
|
||||
|
||||
if not os.path.exists('output'):
|
||||
os.makedirs('output')
|
||||
|
@ -290,7 +300,6 @@ if __name__ == '__main__':
|
|||
total_frames += 1
|
||||
|
||||
if(display):
|
||||
ax1 = fig.add_subplot(111, aspect='equal')
|
||||
fn = 'mot_benchmark/%s/%s/img1/%06d.jpg'%(phase, seq, frame)
|
||||
im =io.imread(fn)
|
||||
ax1.imshow(im)
|
||||
|
@ -313,7 +322,7 @@ if __name__ == '__main__':
|
|||
plt.draw()
|
||||
ax1.cla()
|
||||
|
||||
print("Total Tracking took: %.3f for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time))
|
||||
print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time))
|
||||
|
||||
if(display):
|
||||
print("Note: to get real runtime results run without the option: --display")
|
||||
|
|
Loading…
Reference in a new issue