Compare commits


10 commits

Author | SHA1 | Message | Date
Ruben van de Ven | 547d82b2fa | modern python support with lapx | 2025-08-06 15:06:28 +02:00
Tim Stokman | 97b1fd2180 | Update README.md | 2021-05-04 12:01:51 +02:00
Tim | b02e68c3b5 | Remove unnecessary dependency | 2021-05-01 03:24:50 +02:00
Tim | cb848c066c | Make it into installable package | 2021-05-01 03:21:30 +02:00
Alex Bewley | bce9f0d1fc | Use os.path.join and update video link. | 2020-11-28 20:29:29 +01:00
Alex Bewley | 3f548c04e7 | Correct box format in documentation for batch_iou. | 2020-10-25 19:20:42 +01:00
Alex Bewley | 7fc1ce2855 | Update to latest skimage. | 2020-08-05 22:59:38 +02:00
Alex Bewley | c1b8084812 | Remove numba dependency and expose hparams flags. | 2020-07-19 23:27:59 +02:00
Alex Bewley | 4c0bdb0935 | Make numba.jit an optional requirement. | 2020-04-23 22:39:44 +02:00
Alex Bewley | 2e8b22503e | Update filterpy version and minor clean-up. | 2020-04-23 22:26:15 +02:00
7 changed files with 81 additions and 135 deletions

.gitignore (vendored, 4 additions)

@@ -1,2 +1,6 @@
 output/
 mot_benchmark
+venv
+dist
+build
+**/*egg-info

README.md

@@ -2,7 +2,7 @@ SORT
 =====
 A simple online and realtime tracking algorithm for 2D multiple object tracking in video sequences.
-See an example [video here](https://motchallenge.net/movies/ETH-Linthescher-SORT.mp4).
+See an example [video here](https://alex.bewley.ai/misc/SORT-MOT17-06-FRCNN.webm).
 By Alex Bewley
@@ -18,6 +18,8 @@ For your convenience, this repo also contains *Faster* RCNN detections for the M
 **Also see:**
 A new and improved version of SORT with a Deep Association Metric implemented in tensorflow is available at [https://github.com/nwojke/deep_sort](https://github.com/nwojke/deep_sort) .
+This fork packages the SORT algorithm as a pip package (simple-online-realtime-tracking).
 ### License
 SORT is released under the GPL License (refer to the LICENSE file for details) to promote the open use of the tracker and future improvements. If you require a permissive license contact Alex (alex@bewley.ai).
@@ -36,37 +38,13 @@ If you find this repo useful in your research, please consider citing:
     doi={10.1109/ICIP.2016.7533003}
   }
-### Dependencies:
-To install required dependencies run:
-```
-$ pip install -r requirements.txt
-```
-### Demo:
-To run the tracker with the provided detections:
-```
-$ cd path/to/sort
-$ python sort.py
-```
-To display the results you need to:
-1. Download the [2D MOT 2015 benchmark dataset](https://motchallenge.net/data/2D_MOT_2015/#download)
-0. Create a symbolic link to the dataset
-```
-$ ln -s /path/to/MOT2015_challenge/data/2DMOT2015 mot_benchmark
-```
-0. Run the demo with the ```--display``` flag
-```
-$ python sort.py --display
-```
+### Installing:
+To install the package:
+```
+pip install simple-online-realtime-tracking==0.3
+```
 ### Main Results
 Using the [MOT challenge devkit](https://motchallenge.net/devkit/) the method produces the following results (as described in the paper).
@@ -81,12 +59,11 @@ Using the [MOT challenge devkit](https://motchallenge.net/devkit/) the method pr
 KITTI-17 | 67.1 | 92.3 | 0.26 | 9 1 8 0| 38 225 9 16| 60.2 72.3 61.3
 *Overall* | 49.5 | 77.5 | 1.24 | 234 48 111 75| 3311 11660 274 499| 34.0 73.3 35.1
 ### Using SORT in your own project
 Below is the gist of how to instantiate and update SORT. See the ['__main__'](https://github.com/abewley/sort/blob/master/sort.py#L239) section of [sort.py](https://github.com/abewley/sort/blob/master/sort.py#L239) for a complete example.
-from sort import *
+from sort import Sort
 #create instance of SORT
 mot_tracker = Sort()
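
Editor's note: to make the README gist above a bit more concrete, here is a hedged usage sketch of this fork's packaged tracker. It is illustrative only, not part of the diff; the detection layout is an assumption inferred from the sort.py changes further down (KalmanBoxTracker reads bbox[5], and update() appends trk.original_id to each returned row).

```python
import numpy as np
from sort import Sort

# parameters match the new Sort.__init__ signature introduced in this fork
mot_tracker = Sort(max_age=1, min_hits=3, iou_threshold=0.3)

# one frame of detections: [x1, y1, x2, y2, score, original_id] (last column assumed)
dets = np.array([
    [100., 80., 160., 220., 0.9, 0],
    [300., 95., 360., 230., 0.8, 1],
])

tracks = mot_tracker.update(dets)
# each returned row appears to be [x1, y1, x2, y2, track_id, original_id]
for x1, y1, x2, y2, track_id, original_id in tracks:
    print(int(track_id), int(original_id), x1, y1, x2, y2)
```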

pyproject.toml (new file, 6 additions)

@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+    "setuptools>=42",
+    "wheel"
+]
+build-backend = "setuptools.build_meta"

requirements.txt (deleted, 4 deletions)

@@ -1,4 +0,0 @@
-filterpy==1.4.1
-numba==0.38.1
-scikit-image==0.14.0
-lap==0.4.0

setup.cfg (new file, 22 additions)

@@ -0,0 +1,22 @@
+[metadata]
+name = simple-online-realtime-tracking
+version = 0.3
+author = Alex Bewley
+description = A simple online and realtime tracking algorithm for 2D multiple object tracking in video sequences
+long_description = file: README.md
+long_description_content_type = text/markdown
+url = https://github.com/timstokman/sort
+classifiers =
+    Programming Language :: Python :: 3
+    Operating System :: OS Independent
+    License :: OSI Approved :: GNU General Public License v3 (GPLv3)
+
+[options]
+package_dir =
+    = src
+packages = sort
+python_requires = >=3.8
+install_requires =
+    filterpy==1.4.5
+    lapx>=0.5.0
+    numpy>=1.18.5

src/sort/__init__.py (new file, 18 additions)

@@ -0,0 +1,18 @@
+"""
+SORT: A Simple, Online and Realtime Tracker
+Copyright (C) 2016-2020 Alex Bewley alex@bewley.ai
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+from .sort import Sort

sort.py → src/sort/sort.py

@@ -1,6 +1,6 @@
 """
 SORT: A Simple, Online and Realtime Tracker
-Copyright (C) 2016 Alex Bewley alex@dynamicdetection.com
+Copyright (C) 2016-2020 Alex Bewley alex@bewley.ai
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -15,16 +15,8 @@
 You should have received a copy of the GNU General Public License
 along with this program. If not, see <http://www.gnu.org/licenses/>.
 """
-from __future__ import print_function
-from numba import jit
 import os
 import numpy as np
-import matplotlib
-matplotlib.use('TkAgg')
-import matplotlib.pyplot as plt
-import matplotlib.patches as patches
-from skimage import io
 import glob
 import time
@@ -44,23 +36,25 @@ def linear_assignment(cost_matrix):
   x, y = linear_sum_assignment(cost_matrix)
   return np.array(list(zip(x, y)))
-@jit
-def iou(bb_test, bb_gt):
+def iou_batch(bb_test, bb_gt):
   """
-  Computes IUO between two bboxes in the form [x1,y1,x2,y2]
+  From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
   """
-  xx1 = np.maximum(bb_test[0], bb_gt[0])
-  yy1 = np.maximum(bb_test[1], bb_gt[1])
-  xx2 = np.minimum(bb_test[2], bb_gt[2])
-  yy2 = np.minimum(bb_test[3], bb_gt[3])
+  bb_gt = np.expand_dims(bb_gt, 0)
+  bb_test = np.expand_dims(bb_test, 1)
+
+  xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
+  yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
+  xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
+  yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
   w = np.maximum(0., xx2 - xx1)
   h = np.maximum(0., yy2 - yy1)
   wh = w * h
-  o = wh / ((bb_test[2] - bb_test[0]) * (bb_test[3] - bb_test[1])
-    + (bb_gt[2] - bb_gt[0]) * (bb_gt[3] - bb_gt[1]) - wh)
+  o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
+    + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
   return(o)
 def convert_bbox_to_z(bbox):
   """
@@ -76,6 +70,7 @@ def convert_bbox_to_z(bbox):
   r = w / float(h)
   return np.array([x, y, s, r]).reshape((4, 1))
+
 def convert_x_to_bbox(x,score=None):
   """
   Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
@@ -99,7 +94,7 @@ class KalmanBoxTracker(object):
     Initialises a tracker using initial bounding box.
     """
     #define constant velocity model
-    self.kf = KalmanFilter(dim_x=7, dim_z=4, compute_log_likelihood=False)
+    self.kf = KalmanFilter(dim_x=7, dim_z=4)
     self.kf.F = np.array([[1,0,0,0,1,0,0],[0,1,0,0,0,1,0],[0,0,1,0,0,0,1],[0,0,0,1,0,0,0], [0,0,0,0,1,0,0],[0,0,0,0,0,1,0],[0,0,0,0,0,0,1]])
     self.kf.H = np.array([[1,0,0,0,0,0,0],[0,1,0,0,0,0,0],[0,0,1,0,0,0,0],[0,0,0,1,0,0,0]])
@@ -117,6 +112,7 @@ class KalmanBoxTracker(object):
     self.hits = 0
     self.hit_streak = 0
     self.age = 0
+    self.original_id = bbox[5]
   def update(self,bbox):
     """
@@ -126,6 +122,7 @@ class KalmanBoxTracker(object):
     self.history = []
     self.hits += 1
     self.hit_streak += 1
+    self.original_id = bbox[5]
     self.kf.update(convert_bbox_to_z(bbox))
   def predict(self):
@@ -148,6 +145,7 @@ class KalmanBoxTracker(object):
     """
     return convert_x_to_bbox(self.kf.x)
+
 def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
   """
   Assigns detections to tracked object (both represented as bounding boxes)
@@ -156,11 +154,8 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
   """
   if(len(trackers)==0):
     return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)
-  iou_matrix = np.zeros((len(detections),len(trackers)),dtype=np.float32)
-  for d,det in enumerate(detections):
-    for t,trk in enumerate(trackers):
-      iou_matrix[d,t] = iou(det,trk)
+  iou_matrix = iou_batch(detections, trackers)
   if min(iou_matrix.shape) > 0:
     a = (iou_matrix > iou_threshold).astype(np.int32)
@@ -197,12 +192,13 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
 class Sort(object):
-  def __init__(self, max_age=1, min_hits=3):
+  def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
     """
     Sets key parameters for SORT
     """
     self.max_age = max_age
     self.min_hits = min_hits
+    self.iou_threshold = iou_threshold
     self.trackers = []
     self.frame_count = 0
@@ -228,15 +224,11 @@ class Sort(object):
     trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
     for t in reversed(to_del):
       self.trackers.pop(t)
-    matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks)
+    matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks, self.iou_threshold)
     # update matched trackers with assigned detections
     for m in matched:
       self.trackers[m[1]].update(dets[m[0], :])
-    #for t, trk in enumerate(self.trackers):
-    #  if(t not in unmatched_trks):
-    #    d = matched[np.where(matched[:,1]==t)[0],0]
-    #    trk.update(dets[d,:][0])
     # create and initialise new trackers for unmatched detections
     for i in unmatched_dets:
@@ -246,7 +238,7 @@ class Sort(object):
     for trk in reversed(self.trackers):
       d = trk.get_state()[0]
       if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
-        ret.append(np.concatenate((d,[trk.id+1])).reshape(1,-1)) # +1 as MOT benchmark requires positive
+        ret.append(np.concatenate((d,[trk.id+1],[trk.original_id])).reshape(1,-1)) # +1 as MOT benchmark requires positive
       i -= 1
     # remove dead tracklet
     if(trk.time_since_update > self.max_age):
@@ -254,72 +246,3 @@ class Sort(object):
     if(len(ret)>0):
       return np.concatenate(ret)
     return np.empty((0,5))
-def parse_args():
-  """Parse input arguments."""
-  parser = argparse.ArgumentParser(description='SORT demo')
-  parser.add_argument('--display', dest='display', help='Display online tracker output (slow) [False]',action='store_true')
-  parser.add_argument("--seq_path", help="Path to detections.", type=str, default='data')
-  parser.add_argument("--phase", help="Subdirectory in seq_path.", type=str, default='train')
-  args = parser.parse_args()
-  return args
-
-if __name__ == '__main__':
-  # all train
-  args = parse_args()
-  display = args.display
-  phase = args.phase
-  total_time = 0.0
-  total_frames = 0
-  colours = np.random.rand(32, 3) #used only for display
-  if(display):
-    if not os.path.exists('mot_benchmark'):
-      print('\n\tERROR: mot_benchmark link not found!\n\n    Create a symbolic link to the MOT benchmark\n    (https://motchallenge.net/data/2D_MOT_2015/#download). E.g.:\n\n    $ ln -s /path/to/MOT2015_challenge/2DMOT2015 mot_benchmark\n\n')
-      exit()
-    plt.ion()
-    fig = plt.figure()
-    ax1 = fig.add_subplot(111, aspect='equal')
-
-  if not os.path.exists('output'):
-    os.makedirs('output')
-  pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt')
-  for seq_dets_fn in glob.glob(pattern):
-    mot_tracker = Sort() #create instance of the SORT tracker
-    seq_dets = np.loadtxt(seq_dets_fn, delimiter=',')
-    seq = seq_dets_fn[pattern.find('*'):].split('/')[0]
-
-    with open('output/%s.txt'%(seq),'w') as out_file:
-      print("Processing %s."%(seq))
-      for frame in range(int(seq_dets[:,0].max())):
-        frame += 1 #detection and frame numbers begin at 1
-        dets = seq_dets[seq_dets[:, 0]==frame, 2:7]
-        dets[:, 2:4] += dets[:, 0:2] #convert to [x1,y1,w,h] to [x1,y1,x2,y2]
-        total_frames += 1
-
-        if(display):
-          fn = 'mot_benchmark/%s/%s/img1/%06d.jpg'%(phase, seq, frame)
-          im =io.imread(fn)
-          ax1.imshow(im)
-          plt.title(seq + ' Tracked Targets')
-
-        start_time = time.time()
-        trackers = mot_tracker.update(dets)
-        cycle_time = time.time() - start_time
-        total_time += cycle_time
-
-        for d in trackers:
-          print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1'%(frame,d[4],d[0],d[1],d[2]-d[0],d[3]-d[1]),file=out_file)
-          if(display):
-            d = d.astype(np.int32)
-            ax1.add_patch(patches.Rectangle((d[0],d[1]),d[2]-d[0],d[3]-d[1],fill=False,lw=3,ec=colours[d[4]%32,:]))
-            #ax1.set_adjustable('box-forced')
-
-        if(display):
-          fig.canvas.flush_events()
-          plt.draw()
-          ax1.cla()
-
-  print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time))
-
-  if(display):
-    print("Note: to get real runtime results run without the option: --display")