Tweak tracker to get better tracks

commit 389da6701f (parent abc80727da)
9 changed files with 578 additions and 96 deletions
custom_bytetrack.yaml (new file, 11 additions)
@@ -0,0 +1,11 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack
+
+tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
+track_high_thresh: 0.05 # threshold for the first association
+track_low_thresh: 0.01 # threshold for the second association
+new_track_thresh: 0.1 # threshold for init new track if the detection does not match any tracks
+track_buffer: 35 # buffer to calculate the time when to remove tracks
+match_thresh: 0.9 # threshold for matching tracks
+fuse_score: True # Whether to fuse confidence scores with the iou distances before matching
+# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
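The lowered thresholds are the substance of the tweak: the first and second association passes now accept detections down to 0.05 and 0.01 confidence, and a new track is started at 0.1, all well below the stock bytetrack.yaml defaults, so hesitant detections can still seed and extend tracks. A minimal sketch of how Ultralytics picks this file up, mirroring the `_yolov8_track()` change further down in this commit (the video path is a placeholder):

```python
# Sketch: Ultralytics loads the tracker config by name from the working
# directory. The model path is the one this repo uses; 'walk.mp4' is invented.
from ultralytics import YOLO

model = YOLO('EXPERIMENTS/yolov8x.pt')
for result in model.track('walk.mp4', stream=True, persist=True,
                          tracker="custom_bytetrack.yaml",
                          classes=[0], verbose=False):
    if result.boxes is None or result.boxes.id is None:
        continue  # frame without tracked boxes (ultralytics#5968)
    print(result.boxes.id.int().tolist())  # persistent track ids
```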
poetry.lock (generated, 87 changes)
@@ -219,6 +219,25 @@ webencodings = "*"
 [package.extras]
 css = ["tinycss2 (>=1.1.0,<1.3)"]
 
+[[package]]
+name = "bytetracker"
+version = "0.3.2"
+description = "Packaged version of the ByteTrack repository"
+optional = false
+python-versions = ">=3.5"
+files = []
+develop = false
+
+[package.dependencies]
+lapx = ">=0.5.8"
+scipy = ">=1.9.3"
+
+[package.source]
+type = "git"
+url = "https://github.com/rubenvandeven/bytetrack-pip"
+reference = "HEAD"
+resolved_reference = "7053b946af8641581999b70230ac6260d37365ae"
+
 [[package]]
 name = "cachetools"
 version = "5.3.2"
@@ -1366,6 +1385,72 @@ files = [
     {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
 ]
 
+[[package]]
+name = "lapx"
+version = "0.5.11"
+description = "Linear Assignment Problem solver (LAPJV/LAPMOD)."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "lapx-0.5.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ad2e100a81387e958cbb2b25b09f5f4b4c7af1ba39313b4bbda9965ee85b43a2"},
+    {file = "lapx-0.5.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:523baec5c6a554946c843877802fbefe2230179854134e6b0f281e0f47f5b342"},
+    {file = "lapx-0.5.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6eba714681ea77c834a740eaa43640a3072341a7418e95fb6d5aa4380b6b2069"},
+    {file = "lapx-0.5.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:797cda3b06e835fa6ee3e245289c03ec29d110e6dd8e333db71483f5bbd32129"},
+    {file = "lapx-0.5.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:29b1447d32e00a89afa90627c7e806bd7eb8e21e4d559749f4a0fc4e47989f64"},
+    {file = "lapx-0.5.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:71caffa000782ab265f76e47aa018368b3738a5be5d62581a870dd6e68417703"},
+    {file = "lapx-0.5.11-cp310-cp310-win_amd64.whl", hash = "sha256:a3c1c20c7d80fa7b6eca0ea9e10966c93ccdaf4d5286e677b199bf021a889b18"},
+    {file = "lapx-0.5.11-cp310-cp310-win_arm64.whl", hash = "sha256:995aea6268f0a519e536f009f44144c4f4066e0724d19239fbcc1c1ab082d7c0"},
+    {file = "lapx-0.5.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a665dbc34f04fe21cdb798be1c003fa0d07b0e27e9020487e7626714bad4b8a"},
+    {file = "lapx-0.5.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5262a868f8802e368ecb470444c830e95b960a2a3763764dd3370680a466684e"},
+    {file = "lapx-0.5.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824a8c50e191343096d48164f2a0e9a5d466e8d20dd8e3eff1a3c1082b4d2b2"},
+    {file = "lapx-0.5.11-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d648b706b9a22255a028c72f4849c97ba0d754e03d74009ff447b26dbbd9bb59"},
+    {file = "lapx-0.5.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f6a0b789023d80b0f5ba3f20d83c9601d02e03abe8ae209ada3a77f304d91fff"},
+    {file = "lapx-0.5.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:44b3c36d52db1eea6eb0c46795440adddbdfe39c22ddb0f2a86db40ab963a798"},
+    {file = "lapx-0.5.11-cp311-cp311-win_amd64.whl", hash = "sha256:2e21c0162b426034ff545cedb86713b642f1e7335fda43605b330ca28a107d13"},
+    {file = "lapx-0.5.11-cp311-cp311-win_arm64.whl", hash = "sha256:2ea0c5dbf62de0612337c4c0a3f1b5ac8cc4fabfb9f68fd1c76612e2d873a28c"},
+    {file = "lapx-0.5.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:18c0c2e7f7ca76527468d98b99c54cf339ea040512392de6d20d8582235b43bc"},
+    {file = "lapx-0.5.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:20ab21b4a45bf975b890ba0364bc354652e3ebb548fb69f23cca4c337ce0e72b"},
+    {file = "lapx-0.5.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afc0233eaef80c04f88449f1bfbe059bfb5556458bc46de54d080e3236db6588"},
+    {file = "lapx-0.5.11-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d48adb34670c1c548cc39c40831e44dd57d7fe640320d86d61e3d2bf179f1f79"},
+    {file = "lapx-0.5.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:28fb1d9f076aff6b98abbc1287aa453d6fd3be0b0a5039adb2a822e2246a2bdd"},
+    {file = "lapx-0.5.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:70abdb835fcfad72856f3fa27d8233b9b267a9efe534a7fa8ede456b876a696d"},
+    {file = "lapx-0.5.11-cp312-cp312-win_amd64.whl", hash = "sha256:905fb018952b7b6ea9ef5ac8f5600e525c2545a679d11951bfc0d7e861efbe31"},
+    {file = "lapx-0.5.11-cp312-cp312-win_arm64.whl", hash = "sha256:02343669611038ec2826c4110d953235397d25e5eff01a5d2cbd9986c3492691"},
+    {file = "lapx-0.5.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d517ce6b0d17af31c71d9230b03f2c09cb7a701d20b3ffbe02c4366ed91c3b85"},
+    {file = "lapx-0.5.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6078aa84f768585c6121fc1d8767d6136d9af34ccd48525174ee488c86a59964"},
+    {file = "lapx-0.5.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:206382e2e942e1d968cb194149d0693a293052c16d0016505788b795818bab21"},
+    {file = "lapx-0.5.11-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:986f6eaa5a21d5b90869d54a0b7e11e9e532cd8938b3980cbd43d43154d96ac1"},
+    {file = "lapx-0.5.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:394555e2245cd6aa2ad79cea58c8a5cc73f6c6f79b85f497020e6a978c346329"},
+    {file = "lapx-0.5.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:dcd1a608c19e14d6d7fd47c885c5f3c68ce4e08f3e8ec2eecbe32fc026b0d1de"},
+    {file = "lapx-0.5.11-cp313-cp313-win_amd64.whl", hash = "sha256:37b6e5f4f04c477a49f7d0780fbe76513c2d5e183bcf5005396c96bfd3de15d6"},
+    {file = "lapx-0.5.11-cp313-cp313-win_arm64.whl", hash = "sha256:c6c84a46f94829c6d992cce7fe747bacb971bca8cb9e77ea6ff80dfbc4fea6e2"},
+    {file = "lapx-0.5.11-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cb7175a5527a46cd6b9212da766539f256790e93747ba5503bd0507e3fd19e3c"},
+    {file = "lapx-0.5.11-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33769788e57b5a7fa0951a884c7e5f8a381792427357976041c1d4c3ae75ddcd"},
+    {file = "lapx-0.5.11-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc015da5a657dbb8822fb47f4702f5bf0b67bafff6f5aed98dad3a6204572e40"},
+    {file = "lapx-0.5.11-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:9b322e6e0685340f5a10ada401065165fa73e84c2db93ba17945e8e119bd17f5"},
+    {file = "lapx-0.5.11-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:a64bc3da09c5925efaff59d20dcfbc3febac64fd1fcc263604f7d4ccdd0e2a75"},
+    {file = "lapx-0.5.11-cp37-cp37m-win_amd64.whl", hash = "sha256:e6d98d31bdf7131a0ec9967068885c86357cc77cf7883d1a1335a48b24e537bb"},
+    {file = "lapx-0.5.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c381563bc713a6fc8a281698893a8115abb34363929105142d66e592252af196"},
+    {file = "lapx-0.5.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c4e10ec35b539984684ef9bfdbd0f763a441225e8a9cff5d7081b24795dd419"},
+    {file = "lapx-0.5.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6583c3a5a47dbfb360d312e4bb3bde509d519886da74161a46ad77653fd18dcb"},
+    {file = "lapx-0.5.11-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3963a87d44bc174eee423239dff952a922b3a8e30fbc514c00fab361f464c74"},
+    {file = "lapx-0.5.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:5e8bf4b68f45e378ce7fc68dd407ea210926881e8f25a08dc18beb4a4a7cced0"},
+    {file = "lapx-0.5.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:a8ee4578a898148110dd8ee877370ee99b6309d7b600bde60e204efd93858c0d"},
+    {file = "lapx-0.5.11-cp38-cp38-win_amd64.whl", hash = "sha256:46b8373f25c1ea85b236fc20183b077efd33112de57f57feccc61f3541f0d8f0"},
+    {file = "lapx-0.5.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f3c88ff301e7cf9a22a7e276e13a9154f0813a2c3f9c3619f3785def851b5f4c"},
+    {file = "lapx-0.5.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d07df31b37c92643b9fb6aceeafb53a13375d95dd9cbd2454d83f03941d7e137"},
+    {file = "lapx-0.5.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cbf42869528ac80d87e1d9019d594a619fd77c26fba0eb8e5242f6657197a57"},
+    {file = "lapx-0.5.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5f38aa911b4773b44d2da0046aad6ded8f08853cbc1ec60a2363b46371e16f8"},
+    {file = "lapx-0.5.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d04e5bd530fb04a22004c3ba1a31033c4e3deffe18fb43778fc234a627a55397"},
+    {file = "lapx-0.5.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:845ca48d0c313113f7ae53420d1e5af87c01355fe14e5624c2a934e1004ef43b"},
+    {file = "lapx-0.5.11-cp39-cp39-win_amd64.whl", hash = "sha256:0eb81393d0edd089b61de1c3c25895b8ed39bc8d91c766a6e06b761194e79894"},
+    {file = "lapx-0.5.11-cp39-cp39-win_arm64.whl", hash = "sha256:e12b3e0d9943e92a0cd3af7a6fa5fb8fab3aa018897beda6d02647d9ce708d5a"},
+    {file = "lapx-0.5.11.tar.gz", hash = "sha256:d925d4a11f436ef0f9e9684378a44e4375aa9c868b22e2f51e6ff15a3362685f"},
+]
+
+[package.dependencies]
+numpy = ">=1.21.6"
+
 [[package]]
 name = "markdown"
 version = "3.5.1"
@@ -3640,4 +3725,4 @@ watchdog = ["watchdog (>=2.3)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10,<3.12,"
-content-hash = "65e5518fd7bc216a72206768f621251e8f2ade9ec07ef0af1d3926492fc6fd70"
+content-hash = "bf4feafd4afa6ceb39a1c599e3e7cdc84afbe11ab1672b49e5de99ad44568b08"
pyproject.toml
@@ -8,6 +8,7 @@ readme = "README.md"
 [tool.poetry.scripts]
 trapserv = "trap.plumber:start"
 tracker = "trap.tools:tracker_preprocess"
+compare = "trap.tools:tracker_compare"
 process_data = "trap.process_data:main"
 
 
@@ -37,6 +38,7 @@ pyglet = "^2.0.15"
 pyglet-cornerpin = "^0.3.0"
 opencv-python = {file="./opencv_python-4.10.0.84-cp310-cp310-linux_x86_64.whl"}
 setproctitle = "^1.3.3"
+bytetracker = { git = "https://github.com/rubenvandeven/bytetrack-pip" }
 
 [build-system]
 requires = ["poetry-core"]
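Both dependency changes can be smoke-tested without running the full pipeline. A sketch, with the constructor arguments copied from `TrackerWrapper.init_type()` in trap/tracker.py below, and an invented detection in the `[x1,y1,x2,y2,score,label]` layout that `_resnet_detect_persons()` produces:

```python
# Sketch: verify the git-sourced bytetracker package after `poetry install`.
import numpy as np
from bytetracker import BYTETracker

tracker = BYTETracker(track_thresh=0.1, match_thresh=0.2, frame_rate=12)
dets = np.array([[100., 120., 160., 300., 0.35, 1.]])  # one fake person box
tracker.update(dets)
print([t.track_id for t in tracker.tracked_stracks])
```

The new script entry becomes available as `poetry run compare` once installed.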
trap/animation_renderer.py
@@ -30,6 +30,8 @@ from trap.preview_renderer import DrawnTrack, PROJECTION_IMG, PROJECTION_MAP
 
 logger = logging.getLogger("trap.renderer")
 
+COLOR_PRIMARY = (0,0,0,255)
+
 class AnimationRenderer:
     def __init__(self, config: Namespace, is_running: BaseEvent):
         self.config = config
@@ -96,7 +98,7 @@ class AnimationRenderer:
         self.window.push_handlers(self.pins)
 
         pyglet.gl.glClearColor(255,255,255,255)
-        self.fps_display = pyglet.window.FPSDisplay(window=self.window, color=(255,255,255,255))
+        self.fps_display = pyglet.window.FPSDisplay(window=self.window, color=COLOR_PRIMARY)
         self.fps_display.label.x = self.window.width - 50
         self.fps_display.label.y = self.window.height - 17
         self.fps_display.label.bold = False
@@ -117,11 +119,11 @@ class AnimationRenderer:
 
         if self.config.render_debug_shapes:
             self.debug_lines = [
-                pyglet.shapes.Line(1370, self.config.camera.h-360, 1380, 670-360, 2, (255,255,255,255), batch=self.batch_overlay),#v
+                pyglet.shapes.Line(1370, self.config.camera.h-360, 1380, 670-360, 2, COLOR_PRIMARY, batch=self.batch_overlay),#v
-                pyglet.shapes.Line(0, 660-360, 1380, 670-360, 2, (255,255,255,255), batch=self.batch_overlay), #h
+                pyglet.shapes.Line(0, 660-360, 1380, 670-360, 2, COLOR_PRIMARY, batch=self.batch_overlay), #h
-                pyglet.shapes.Line(1140, 760-360, 1140, 675-360, 2, (255,255,255,255), batch=self.batch_overlay), #h
+                pyglet.shapes.Line(1140, 760-360, 1140, 675-360, 2, COLOR_PRIMARY, batch=self.batch_overlay), #h
-                pyglet.shapes.Line(540, 760-360,540, 675-360, 2, (255,255,255,255), batch=self.batch_overlay), #v
+                pyglet.shapes.Line(540, 760-360,540, 675-360, 2, COLOR_PRIMARY, batch=self.batch_overlay), #v
-                pyglet.shapes.Line(0, 770-360, 1380, 770-360, 2, (255,255,255,255), batch=self.batch_overlay), #h
+                pyglet.shapes.Line(0, 770-360, 1380, 770-360, 2, COLOR_PRIMARY, batch=self.batch_overlay), #h
 
             ]
 
@@ -197,7 +199,7 @@ class AnimationRenderer:
         self.gradientLine = GradientLine
 
     def init_labels(self):
-        base_color = (255,)*4
+        base_color = COLOR_PRIMARY
         color_predictor = (255,255,0, 255)
         color_info = (255,0, 255, 255)
         color_tracker = (0,255, 255, 255)
@@ -278,7 +280,7 @@ class AnimationRenderer:
             self.video_sprite = pyglet.sprite.Sprite(img=img, batch=self.batch_bg)
             # transform to flipped coordinate system for pyglet
             self.video_sprite.y = self.window.height - self.video_sprite.height
-            self.video_sprite.opacity = 10
+            self.video_sprite.opacity = 70
         except zmq.ZMQError as e:
             # idx = frame.index if frame else "NONE"
             # logger.debug(f"reuse video frame {idx}")
trap/config.py
@@ -4,7 +4,7 @@ import types
 import numpy as np
 import json
 
-from trap.tracker import DETECTORS
+from trap.tracker import DETECTORS, TRACKER_BYTETRACK, TRACKERS
 from trap.frame_emitter import Camera
 
 from pyparsing import Optional
@@ -294,6 +294,11 @@ tracker_parser.add_argument("--detector",
                     help="Specify the detector to use",
                     type=str,
                     choices=DETECTORS)
+tracker_parser.add_argument("--tracker",
+                    help="Specify the tracker to use",
+                    type=str,
+                    default=TRACKER_BYTETRACK,
+                    choices=TRACKERS)
 tracker_parser.add_argument("--smooth-tracks",
                     help="Smooth the tracker tracks before sending them to the predictor",
                     action='store_true')
trap/frame_emitter.py
@@ -8,13 +8,15 @@ from pathlib import Path
 import pickle
 import sys
 import time
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional
 import numpy as np
 import cv2
 import zmq
 import os
 from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
 from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
+from bytetracker.byte_tracker import STrack as ByteTrackTrack
+from bytetracker.basetrack import TrackState as ByteTrackTrackState
 
 from urllib.parse import urlparse
 
@@ -51,6 +53,17 @@ class DetectionState(IntFlag):
             return cls.Confirmed
         raise RuntimeError("Should not run into Deleted entries here")
+
+    @classmethod
+    def from_bytetrack_track(cls, track: ByteTrackTrack):
+        if track.state == ByteTrackTrackState.New:
+            return cls.Tentative
+        if track.state == ByteTrackTrackState.Lost:
+            return cls.Lost
+        # if track.time_since_update > 0:
+        if track.state == ByteTrackTrackState.Tracked:
+            return cls.Confirmed
+        raise RuntimeError("Should not run into Deleted entries here")
 
 class Camera:
     def __init__(self, mtx, dist, w, h, H):
         self.mtx = mtx
@@ -71,13 +84,19 @@ class Detection:
     conf: float # object detector probablity
     state: DetectionState
     frame_nr: int
+    det_class: str
 
     def get_foot_coords(self) -> list[tuple[float, float]]:
         return [self.l + 0.5 * self.w, self.t+self.h]
 
     @classmethod
-    def from_deepsort(cls, dstrack: DeepsortTrack):
+    def from_deepsort(cls, dstrack: DeepsortTrack, frame_nr: int):
-        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf, DetectionState.from_deepsort_track(dstrack))
+        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf, DetectionState.from_deepsort_track(dstrack), frame_nr, dstrack.det_class)
+
 
+    @classmethod
+    def from_bytetrack(cls, bstrack: ByteTrackTrack, frame_nr: int):
+        return cls(bstrack.track_id, *bstrack.tlwh, bstrack.score, DetectionState.from_bytetrack_track(bstrack), frame_nr, bstrack.cls)
+
     def get_scaled(self, scale: float = 1):
         if scale == 1:
@@ -90,7 +109,9 @@ class Detection:
             self.w*scale,
             self.h*scale,
             self.conf,
-            self.state)
+            self.state,
+            self.frame_nr,
+            self.det_class)
 
     def to_ltwh(self):
         return (int(self.l), int(self.t), int(self.w), int(self.h))
@@ -107,7 +128,7 @@ class Track:
     and acceleration.
     """
     track_id: str = None
-    history: [Detection] = field(default_factory=lambda: [])
+    history: List[Detection] = field(default_factory=lambda: [])
     predictor_history: Optional[list] = None # in image space
     predictions: Optional[list] = None
 
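With `frame_nr` and `det_class` added, every `Detection` now carries nine positional fields, and both factory methods fill them in the same order. A sketch with invented values (state mapping per `from_bytetrack_track()` above):

```python
# Sketch: the field order every Detection constructor call must now satisfy.
det = Detection(
    track_id=1,
    l=100, t=120, w=60, h=180,       # left, top, width, height (image space)
    conf=0.35,                        # detector score (ByteTrack: STrack.score)
    state=DetectionState.Confirmed,   # ByteTrackTrackState.Tracked maps here
    frame_nr=42,
    det_class=0,                      # class id; 0 is 'person' for YOLO/COCO
)
print(det.get_foot_coords())          # [130.0, 300] - bottom-centre of the box
```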
trap/preview_renderer.py
@@ -156,7 +156,7 @@ class DrawnTrack:
             if ci >= len(self.shapes):
                 # TODO: add color2
                 line = self.renderer.gradientLine(x, y, x2, y2, 3, color, color, batch=self.renderer.batch_anim)
-                line = pyglet.shapes.Arc(x2, y2, 10, thickness=3, color=color, batch=self.renderer.batch_anim)
+                line = pyglet.shapes.Arc(x2, y2, 10, thickness=2, color=color, batch=self.renderer.batch_anim)
                 line.opacity = 20
                 self.shapes.append(line)
 
@@ -164,7 +164,7 @@ class DrawnTrack:
                 line = self.shapes[ci-1]
                 line.x, line.y = x, y
                 line.x2, line.y2 = x2, y2
-                line.radius = int(exponentialDecay(line.radius, 2, 3, dt))
+                line.radius = int(exponentialDecay(line.radius, 1.5, 3, dt))
                 line.color = color
                 line.opacity = int(exponentialDecay(line.opacity, 180, 8, dt))
 
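Both tweaks feed the same `exponentialDecay()` helper. Its definition is not part of this diff; assuming the usual frame-rate-independent easing form, the arc radius now settles toward 1.5 px instead of 2:

```python
import math

# Assumed shape of the helper used above (not shown in this diff):
# ease current value a toward target b at rate `decay`, independent of dt.
def exponentialDecay(a: float, b: float, decay: float, dt: float) -> float:
    return b + (a - b) * math.exp(-decay * dt)

r = 10.0
for _ in range(5):
    r = exponentialDecay(r, 1.5, 3, dt=1 / 60)  # new target radius: 1.5
    print(round(r, 3))  # 9.585, 9.191, ... easing toward 1.5
```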
trap/tools.py (325 changes)
@@ -1,66 +1,95 @@
+from argparse import Namespace
+from pathlib import Path
+import pickle
+from tempfile import mktemp
+
+import numpy as np
+import pandas as pd
+import trap.tracker
 from trap.config import parser
-from trap.frame_emitter import video_src_from_config, Frame
+from trap.frame_emitter import Detection, DetectionState, video_src_from_config, Frame
-from trap.tracker import DETECTOR_YOLOv8, _yolov8_track, Track, TrainingDataWriter
+from trap.tracker import DETECTOR_YOLOv8, Smoother, _yolov8_track, Track, TrainingDataWriter, Tracker
 from collections import defaultdict
 
 import logging
 import cv2
-from typing import List, Iterable
+from typing import List, Iterable, Optional
 
 from ultralytics import YOLO
 from ultralytics.engine.results import Results as YOLOResult
 import tqdm
 
-config = parser.parse_args()
 
 logger = logging.getLogger('tools')
 
-def tracker_preprocess():
-    video_srcs = video_src_from_config(config)
-    if not hasattr(config, "H"):
-        print("Set homography file with --homography param")
-        return
-
-    if config.detector != DETECTOR_YOLOv8:
-        print("Only YOLO for now...")
-        return
-
-    model = YOLO('EXPERIMENTS/yolov8x.pt')
-
-    with TrainingDataWriter(config.save_for_training) as writer:
-        for video_nr, video_path in enumerate(video_srcs):
+class FrameGenerator():
+    def __init__(self, config):
+        self.video_srcs = video_src_from_config(config)
+        self.config = config
+        if not hasattr(config, "H"):
+            raise RuntimeError("Set homography file with --homography param")
+
+        # store current position
+        self.video_path = None
+        self.video_nr = None
+        self.frame_count = None
+        self.frame_idx = None
+
+    def __iter__(self):
+        n = 0
+        for video_nr, video_path in enumerate(self.video_srcs):
+            self.video_path = video_path
+            self.video_nr = video_nr
             logger.info(f"Play from '{str(video_path)}'")
             video = cv2.VideoCapture(str(video_path))
             fps = video.get(cv2.CAP_PROP_FPS)
-            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
+            self.frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
-            i = 0
+            self.frame_idx = 0
-            if config.video_offset:
+            if self.config.video_offset:
-                logger.info(f"Start at frame {config.video_offset}")
+                logger.info(f"Start at frame {self.config.video_offset}")
-                video.set(cv2.CAP_PROP_POS_FRAMES, config.video_offset)
+                video.set(cv2.CAP_PROP_POS_FRAMES, self.config.video_offset)
-                i = config.video_offset
+                self.frame_idx = self.config.video_offset
 
-            bar = tqdm.tqdm()
-            tracks = defaultdict(lambda: Track())
-
-            total = 0
-
             while True:
-                bar.update()
                 ret, img = video.read()
-                i+=1
+                self.frame_idx+=1
+                n+=1
 
                 # seek to 0 if video has finished. Infinite loop
                 if not ret:
                     # now loading multiple files
                     break
 
-                frame = Frame(index=bar.n, img=img, H=config.H, camera=config.camera)
+                frame = Frame(index=n, img=img, H=self.config.H, camera=self.config.camera)
+                yield frame
+
+
+def tracker_preprocess():
+    config = parser.parse_args()
+
+    tracker = Tracker(config)
+    # model = YOLO('EXPERIMENTS/yolov8x.pt')
+
+    with TrainingDataWriter(config.save_for_training) as writer:
+        bar = tqdm.tqdm()
+        tracks = defaultdict(lambda: Track())
+
+        total = 0
+        frames = FrameGenerator(config)
+        for frame in frames:
+            bar.update()
+
-            detections = _yolov8_track(frame, model, classes=[0])
+            detections = tracker.track_frame(frame)
             total += len(detections)
             # detections = _yolov8_track(frame, model, imgsz=1440, classes=[0])
 
-            bar.set_description(f"[{video_nr}/{len(video_srcs)}] [{i}/{frame_count}] {str(video_path)} -- Detections {len(detections)}: {[d.track_id for d in detections]} (so far {total})")
+            bar.set_description(f"[{frames.video_nr}/{len(frames.video_srcs)}] [{frames.frame_idx}/{frames.frame_count}] {str(frames.video_path)} -- Detections {len(detections)}: {[d.track_id for d in detections]} (so far {total})")
 
             for detection in detections:
                 track = tracks[detection.track_id]
@@ -72,4 +101,232 @@ def tracker_preprocess():
 
             writer.add(frame, active_tracks.values())
 
     logger.info("Done!")
+
+
+bgr_colors = [
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (0, 255, 255),
+]
+
+def detection_color(detection: Detection, i):
+    return bgr_colors[i % len(bgr_colors)] if detection.state != DetectionState.Lost else (100,100,100)
+
+def to_point(coord):
+    return (int(coord[0]), int(coord[1]))
+
+def tracker_compare():
+    config = parser.parse_args()
+
+    trackers: List[Tracker] = []
+    # TODO, support all tracker.DETECTORS
+    for tracker_id in [
+        trap.tracker.DETECTOR_YOLOv8,
+        # trap.tracker.DETECTOR_MASKRCNN,
+        # trap.tracker.DETECTOR_RETINANET,
+        trap.tracker.DETECTOR_FASTERRCNN,
+    ]:
+        tracker_config = Namespace(**vars(config))
+        tracker_config.detector = tracker_id
+        trackers.append(Tracker(tracker_config))
+
+    frames = FrameGenerator(config)
+    bar = tqdm.tqdm(frames)
+    cv2.namedWindow("frame", cv2.WND_PROP_FULLSCREEN)
+    cv2.setWindowProperty("frame",cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
+
+    for frame in bar:
+
+        # frame.img = cv2.undistort(frame.img, config.camera.mtx, config.camera.dist, None, config.camera.newcameramtx) # try to undistort for better detections, seems not to matter at all
+        trackers_detections = [(t, t.track_frame(frame)) for t in trackers]
+
+        for i, tracker in enumerate(trackers):
+            cv2.putText(frame.img, tracker.config.detector, (10,30*(i+1)), cv2.FONT_HERSHEY_DUPLEX, 1, color=bgr_colors[i % len(bgr_colors)])
+
+        for i, (tracker, detections) in enumerate(trackers_detections):
+            for track_id in tracker.tracks:
+                history = tracker.tracks[track_id].history
+                cv2.putText(frame.img, f"{track_id}", to_point(history[0].get_foot_coords()), cv2.FONT_HERSHEY_DUPLEX, 1, color=bgr_colors[i % len(bgr_colors)])
+                for j in range(len(history)-1):
+                    a = history[j]
+                    b = history[j+1]
+                    color = detection_color(b, i)
+                    cv2.line(frame.img, to_point(a.get_foot_coords()), to_point(b.get_foot_coords()), color, 1)
+            for detection in detections:
+                color = detection_color(detection, i)
+                l, t, r, b = detection.to_ltrb()
+                cv2.rectangle(frame.img, (l, t), (r,b), color)
+                cv2.putText(frame.img, f"{detection.track_id}", (l, b+10), cv2.FONT_HERSHEY_DUPLEX, 1, color=color)
+                conf = f"{detection.conf:.3f}" if detection.conf is not None else "None"
+                cv2.putText(frame.img, f"{detection.det_class} - {conf}", (l, t), cv2.FONT_HERSHEY_DUPLEX, .7, color=color)
+        cv2.imshow('frame',cv2.resize(frame.img, (1920, 1080)))
+        cv2.waitKey(1)
+
+        bar.set_description(f"[{frames.video_nr}/{len(frames.video_srcs)}] [{frames.frame_idx}/{frames.frame_count}] {str(frames.video_path)}")
+
+
+def interpolate_missing_frames(data: pd.DataFrame):
+    missing=0
+    old_size=len(data)
+    # slow way to append missing steps to the dataset
+    for ind, row in tqdm.tqdm(data.iterrows()):
+        if row['diff'] > 1:
+            for s in range(1, int(row['diff'])):
+                # add as many entries as missing
+                missing += 1
+                data.loc[len(data)] = [row['frame_id']-s, row['track_id'], np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 1, 1]
+                # new_frame = [data.loc[ind-1]['frame_id']+s, row['track_id'], np.nan, np.nan, np.nan, np.nan, np.nan]
+                # data.loc[len(data)] = new_frame
+
+    logger.info(f'was:{old_size} added:{missing}, new length: {len(data)}')
+
+    # now sort, so that the added data is in the right place
+    data.sort_values(by=['track_id', 'frame_id'], inplace=True)
+
+    df=data.copy()
+    df = df.groupby('track_id').apply(lambda group: group.interpolate(method='linear'))
+    df.reset_index(drop=True, inplace=True)
+
+    # update diff, should now be 1 | NaN
+    data['diff'] = data.groupby(['track_id'])['frame_id'].diff()
+
+    # data = df
+    return df
+
+def smooth(data: pd.DataFrame):
+
+    df=data.copy()
+    if 'x_raw' not in df:
+        df['x_raw'] = df['x']
+    if 'y_raw' not in df:
+        df['y_raw'] = df['y']
+
+    print("Running smoother")
+    # print(df)
+    # from tsmoothie.smoother import KalmanSmoother, ConvolutionSmoother
+    smoother = Smoother(convolution=False)
+    def smoothing(data):
+        # smoother = ConvolutionSmoother(window_len=SMOOTHING_WINDOW, window_type='ones', copy=None)
+        return smoother.smooth(data).tolist()
+        # df=df.assign(smooth_data=smoother.smooth_data[0])
+        # return smoother.smooth_data[0].tolist()
+
+    # operate smoothing per axis
+    print("smooth x")
+    df['x'] = df.groupby('track_id')['x_raw'].transform(smoothing)
+    print("smooth y")
+    df['y'] = df.groupby('track_id')['y_raw'].transform(smoothing)
+
+    return df
+
+def load_tracks_from_csv(file: Path, fps: float, grid_size: Optional[int] = None, sample: Optional[int] = None):
+    cache_file = Path('/tmp/load_tracks-smooth-' + file.name)
+    if cache_file.exists():
+        data = pd.read_pickle(cache_file)
+
+    else:
+        # grid_size is in points per meter
+        # sample: sample to every n-th point. Thus sample=5 converts 12fps to 2.4fps, and 4 to 3fps
+        data = pd.read_csv(file, delimiter="\t", index_col=False, header=None)
+        # l,t,w,h: image space (pixels)
+        # x,y: world space (meters or cm depending on homography)
+        data.columns = ['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state']
+        data['frame_id'] = pd.to_numeric(data['frame_id'], downcast='integer')
+        data['frame_id'] = data['frame_id'] // 10 # compatibility with Trajectron++
+
+        data.sort_values(by=['track_id', 'frame_id'],inplace=True)
+
+        data.set_index(['track_id', 'frame_id'])
+
+        # cm to meter
+        data['x'] = data['x']/100
+        data['y'] = data['y']/100
+
+        if grid_size is not None:
+            data['x'] = (data['x']*grid_size).round() / grid_size
+            data['y'] = (data['y']*grid_size).round() / grid_size
+
+        data['diff'] = data.groupby(['track_id'])['frame_id'].diff() #.fillna(0)
+        data['diff'] = pd.to_numeric(data['diff'], downcast='integer')
+
+        data = interpolate_missing_frames(data)
+
+        data = smooth(data)
+        data.to_pickle(cache_file)
+
+    if sample is not None:
+        print(f"Sampling 1/{sample}, of {data.shape[0]} items")
+        data["idx_in_track"] = data.groupby(['track_id']).cumcount() # create index in group
+        groups = data.groupby(['track_id'])
+        # print(groups, data)
+        # selection = groups['idx_in_track'].apply(lambda x: x % sample == 0)
+        # print(selection)
+        selection = data["idx_in_track"].apply(lambda x: x % sample == 0)
+        # data = data[selection]
+        data = data.loc[selection].copy() # avoid errors
+
+        # # convert from e.g. 12Hz, to 2.4Hz (1/5)
+        # sampled_groups = []
+        # for name, group in data.groupby('track_id'):
+        #     sampled_groups.append(group.iloc[::sample])
+        # print(f"Sampled {len(sampled_groups)} groups")
+        # data = pd.concat(sampled_groups, axis=1).T
+        print(f"Done sampling kept {data.shape[0]} items")
+
+    # String to int
+    data['track_id'] = pd.to_numeric(data['track_id'], downcast='integer')
+
+    # redo diff after possible sampling:
+    data['diff'] = data.groupby(['track_id'])['frame_id'].diff()
+    # timestep to seconds
+    data['dt'] = data['diff'] * (1/fps)
+
+    # "Deriving displacement, velocity and acceleration from x and y"
+    data['dx'] = data.groupby(['track_id'])['x'].diff()
+    data['dy'] = data.groupby(['track_id'])['y'].diff()
+    data['vx'] = data['dx'].div(data['dt'], axis=0)
+    data['vy'] = data['dy'].div(data['dt'], axis=0)
+
+    data['ax'] = data.groupby(['track_id'])['vx'].diff().div(data['dt'], axis=0)
+    data['ay'] = data.groupby(['track_id'])['vy'].diff().div(data['dt'], axis=0)
+
+    # then we need the velocity itself
+    data['v'] = np.sqrt(data['vx'].pow(2) + data['vy'].pow(2))
+    # and derive acceleration
+    data['a'] = data.groupby(['track_id'])['v'].diff().div(data['dt'], axis=0)
+
+    # we can calculate heading based on the velocity components
+    data['heading'] = (np.arctan2(data['vy'], data['vx']) * 180 / np.pi) % 360
+
+    # and derive it to get the rate of change of the heading
+    data['d_heading'] = data.groupby(['track_id'])['heading'].diff().div(data['dt'], axis=0)
+
+    # we can backfill the derived parameters (v and a), assuming they were constant when entering the frame
+    # so that our model can make estimations, based on these assumed values
+    group = data.groupby(['track_id'])
+    for field in ['dx', 'dy', 'vx', 'vy', 'ax', 'ay', 'v', 'a', 'heading', 'd_heading']:
+        data[field] = group[field].bfill()
+
+    data.set_index(['track_id', 'frame_id'], inplace=True) # use for quick access
+    return data
+
+
+def filter_short_tracks(data: pd.DataFrame, n):
+    return data.groupby(['track_id']).filter(lambda group: len(group) >= n) # a length of 3 is necessary to have all relevant derivatives of position
+    # print(filtered_data.shape[0], "items in filtered set, out of", data.shape[0], "in total set")
+
+def normalise_position(data: pd.DataFrame):
+    mu = data[['x','y']].mean(axis=0)
+    std = data[['x','y']].std(axis=0)
+
+    data[['x_norm','y_norm']] = (data[['x','y']] - mu) / std
+    return data, mu, std
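The derivative chain in `load_tracks_from_csv()` is plain finite differences scaled by the (possibly irregular) timestep. A toy check of the arithmetic on an invented constant-velocity track:

```python
import numpy as np
import pandas as pd

# invented track: +0.5 m per frame along x, recorded at 12 fps
df = pd.DataFrame({'track_id': [1, 1, 1, 1], 'frame_id': [0, 1, 2, 3],
                   'x': [0.0, 0.5, 1.0, 1.5], 'y': [0.0, 0.0, 0.0, 0.0]})
df['dt'] = df.groupby('track_id')['frame_id'].diff() * (1 / 12)
df['vx'] = df.groupby('track_id')['x'].diff() / df['dt']  # 6.0 m/s from row 1
df['heading'] = (np.arctan2(df['y'].diff(), df['x'].diff()) * 180 / np.pi) % 360
print(df[['vx', 'heading']])  # first row stays NaN until the bfill() above
```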
trap/tracker.py (189 changes)
@@ -11,10 +11,11 @@ import time
 from typing import Optional, List
 import numpy as np
 import torch
+import torchvision
 import zmq
 import cv2
 
-from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights, FasterRCNN_ResNet50_FPN_V2_Weights, fasterrcnn_resnet50_fpn_v2
 from deep_sort_realtime.deepsort_tracker import DeepSort
 from torchvision.models import ResNet50_Weights
 from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
@@ -23,7 +24,7 @@ from ultralytics import YOLO
 from ultralytics.engine.results import Results as YOLOResult
 
 from trap.frame_emitter import DetectionState, Frame, Detection, Track
+from bytetracker import BYTETracker
 
 from tsmoothie.smoother import KalmanSmoother, ConvolutionSmoother
 import tsmoothie.smoother
@@ -45,17 +46,28 @@ DETECTOR_MASKRCNN = 'maskrcnn'
 DETECTOR_FASTERRCNN = 'fasterrcnn'
 DETECTOR_YOLOv8 = 'ultralytics'
 
-DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
+TRACKER_DEEPSORT = 'deepsort'
+TRACKER_BYTETRACK = 'bytetrack'
+
+DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
+TRACKERS =[TRACKER_DEEPSORT, TRACKER_BYTETRACK]
+
+TRACKER_CONFIDENCE_MINIMUM = .2
+TRACKER_BYTETRACK_MINIMUM = .1 # bytetrack can track items with lower threshold
+NON_MAXIMUM_SUPRESSION = 1
+RCNN_SCALE = .4 # seems to have no impact on detections in the corners
+
 def _yolov8_track(frame: Frame, model: YOLO, **kwargs) -> List[Detection]:
-    results: List[YOLOResult] = list(model.track(frame.img, persist=True, tracker="bytetrack.yaml", verbose=False, **kwargs))
+    results: List[YOLOResult] = list(model.track(frame.img, persist=True, tracker="custom_bytetrack.yaml", verbose=False, **kwargs))
     if results[0].boxes is None or results[0].boxes.id is None:
         # work around https://github.com/ultralytics/ultralytics/issues/5968
         return []
-    return [Detection(track_id, bbox[0]-.5*bbox[2], bbox[1]-.5*bbox[3], bbox[2], bbox[3], 1, DetectionState.Confirmed, frame.index) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
+    boxes = results[0].boxes.xywh.cpu()
+    track_ids = results[0].boxes.id.int().cpu().tolist()
+    classes = results[0].boxes.cls.int().cpu().tolist()
+    return [Detection(track_id, bbox[0]-.5*bbox[2], bbox[1]-.5*bbox[3], bbox[2], bbox[3], 1, DetectionState.Confirmed, frame.index, class_id) for bbox, track_id, class_id in zip(boxes, track_ids, classes)]
 
 class Multifile():
     def __init__(self, srcs: List[Path]):
@@ -79,7 +91,7 @@ class Multifile():
 
 
 class TrainingDataWriter:
-    def __init__(self, training_path = Optional[Path]):
+    def __init__(self, training_path: Optional[Path]):
         if training_path is None:
             self.path = None
             return
@@ -151,6 +163,9 @@ class TrainingDataWriter:
 
         logger.info(f"Splitting gathered data from {sources.name}")
         # for source_file in source_files:
+
+        tracks_file = self.path / 'tracks.json'
+        tracks = defaultdict(lambda: [])
 
         for name, line_nrs in lines.items():
             dir_path = self.path / name
@@ -178,25 +193,59 @@ class TrainingDataWriter:
 
                     parts[1] = str(track_id)
                     target_fp.write("\t".join(parts))
+                    tracks[track_id].append(parts)
+
+        with tracks_file.open('w') as fp:
+            json.dump(tracks, fp)
+
+
+class TrackerWrapper():
+    def __init__(self, tracker):
+        self.tracker = tracker
+
+    def track_detections(self, detections, img: cv2.Mat, frame_idx: int):
+        raise RuntimeError("Not implemented")
+
+    @classmethod
+    def init_type(cls, tracker_type: str):
+        if tracker_type == TRACKER_BYTETRACK:
+            return ByteTrackWrapper(BYTETracker(track_thresh=TRACKER_BYTETRACK_MINIMUM, match_thresh=TRACKER_CONFIDENCE_MINIMUM, frame_rate=12)) # TODO)) Framerate from emitter
+        else:
+            return DeepSortWrapper(DeepSort(n_init=5, max_age=30, nms_max_overlap=NON_MAXIMUM_SUPRESSION,
+                                            embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
+                                            ))
+
+class DeepSortWrapper(TrackerWrapper):
+    def track_detections(self, detections, img: cv2.Mat, frame_idx: int):
+        detections = Tracker.detect_persons_deepsort_wrapper(detections)
+        tracks: List[DeepsortTrack] = self.tracker.update_tracks(detections, frame=img)
+        active_tracks = [t for t in tracks if t.is_confirmed()]
+        return [Detection.from_deepsort(t, frame_idx) for t in active_tracks]
+        # raw_detections, embeds=None, frame=None, today=None, others=None, instance_masks=Non
+
+
+class ByteTrackWrapper(TrackerWrapper):
+    def __init__(self, tracker: BYTETracker):
+        self.tracker = tracker
+
+    def track_detections(self, detections: tuple[list[float,float,float,float], float, float], img: cv2.Mat, frame_idx: int):
+        # detections
+        if detections.shape[0] == 0:
+            detections = np.ndarray((0,0)) # needs to be 2-D
+
+        _ = self.tracker.update(detections)
+        active_tracks = [track for track in self.tracker.tracked_stracks if track.is_activated]
+        active_tracks = [track for track in active_tracks if track.start_frame < (self.tracker.frame_id - 5)]
+        return [Detection.from_bytetrack(track, frame_idx) for track in active_tracks]
+
+
 class Tracker:
-    def __init__(self, config: Namespace, is_running: Event):
+    def __init__(self, config: Namespace):
         self.config = config
-        self.is_running = is_running
-
-        context = zmq.Context()
-        self.frame_sock = context.socket(zmq.SUB)
-        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
-        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
-        self.frame_sock.connect(config.zmq_frame_addr)
-
-        self.trajectory_socket = context.socket(zmq.PUB)
-        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
-        self.trajectory_socket.bind(config.zmq_trajectory_addr)
-
-
         # # TODO: config device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -206,30 +255,39 @@
 
         logger.debug(f"Load tracker: {self.config.detector}")
 
+        conf = TRACKER_BYTETRACK_MINIMUM if self.config.tracker == TRACKER_BYTETRACK else TRACKER_CONFIDENCE_MINIMUM
         if self.config.detector == DETECTOR_RETINANET:
            # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
            # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
-            weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+            weights = KeypointRCNN_ResNet50_FPN_Weights.COCO_V1
-            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
+            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=conf)
             self.model.to(self.device)
             # Put the model in inference mode
             self.model.eval()
             # Get the transforms for the model's weights
             self.preprocess = weights.transforms().to(self.device)
-            self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9,
+            self.mot_tracker = TrackerWrapper.init_type(self.config.tracker)
-                                        # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
-                                        )
+        elif self.config.detector == DETECTOR_FASTERRCNN:
+            # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
+            # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
+            weights = FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
+            self.model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=conf)
+            self.model.to(self.device)
+            # Put the model in inference mode
+            self.model.eval()
+            # Get the transforms for the model's weights
+            self.preprocess = weights.transforms().to(self.device)
+            self.mot_tracker = TrackerWrapper.init_type(self.config.tracker)
         elif self.config.detector == DETECTOR_MASKRCNN:
             weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
-            self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
+            self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=conf) # if we use ByteTrack we can work with low probability!
             self.model.to(self.device)
             # Put the model in inference mode
             self.model.eval()
             # Get the transforms for the model's weights
             self.preprocess = weights.transforms().to(self.device)
-            self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9,
+            # self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.2, max_age=15, nms_max_overlap=0.9,
-                                        # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
+            self.mot_tracker = TrackerWrapper.init_type(self.config.tracker)
-                                        )
         elif self.config.detector == DETECTOR_YOLOv8:
             self.model = YOLO('EXPERIMENTS/yolov8x.pt')
         else:
@@ -249,9 +307,40 @@
 
 
         logger.debug("Set up tracker")
 
+    def track_frame(self, frame: Frame):
+        if self.config.detector == DETECTOR_YOLOv8:
+            detections: List[Detection] = _yolov8_track(frame, self.model, classes=[0], imgsz=[1152, 640])
+        else:
+            detections: List[Detection] = self._resnet_track(frame, scale = RCNN_SCALE)
+
+        for detection in detections:
+            track = self.tracks[detection.track_id]
+            track.track_id = detection.track_id # for new tracks
+
+            track.history.append(detection) # add to history
+
+        return detections
+
+
-    def track(self):
+    def track(self, is_running: Event):
+        """
+        Live tracking of frames coming in over zmq
+        """
+        self.is_running = is_running
+
+        context = zmq.Context()
+        self.frame_sock = context.socket(zmq.SUB)
+        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
+        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
+        self.frame_sock.connect(self.config.zmq_frame_addr)
+
+        self.trajectory_socket = context.socket(zmq.PUB)
+        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
+        self.trajectory_socket.bind(self.config.zmq_trajectory_addr)
+
         prev_run_time = 0
 
         # training_fp = None
@@ -312,19 +401,16 @@
             # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
 
-            if self.config.detector == DETECTOR_YOLOv8:
+            detections: List[Detection] = self.track_frame(frame)
-                detections: [Detection] = _yolov8_track(frame, self.model, classes=[0], imgsz=[1152, 640])
-            else :
-                detections: [Detection] = self._resnet_track(frame.img, scale = 1)
 
             # Store detections into tracklets
             projected_coordinates = []
-            for detection in detections:
+            # now in track_frame()
-                track = self.tracks[detection.track_id]
+            # for detection in detections:
-                track.track_id = detection.track_id # for new tracks
+            #     track = self.tracks[detection.track_id]
+            #     track.track_id = detection.track_id # for new tracks
 
-                track.history.append(detection) # add to history
+            #     track.history.append(detection) # add to history
             # projected_coordinates.append(track.get_projected_history(self.H)) # then get full history
 
             # TODO: hadle occlusions, and dissappearance
@@ -383,15 +469,17 @@
             logger.info('Stopping')
 
-    def _resnet_track(self, img, scale: float = 1) -> [Detection]:
+    def _resnet_track(self, frame: Frame, scale: float = 1) -> List[Detection]:
+        img = frame.img
         if scale != 1:
             dsize = (int(img.shape[1] * scale), int(img.shape[0] * scale))
             img = cv2.resize(img, dsize)
         detections = self._resnet_detect_persons(img)
-        tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
+        tracks: List[Detection] = self.mot_tracker.track_detections(detections, img, frame.index)
-        return [Detection.from_deepsort(t).get_scaled(1/scale) for t in tracks]
+        # active_tracks = [t for t in tracks if t.is_confirmed()]
+        return [d.get_scaled(1/scale) for d in tracks]
 
-    def _resnet_detect_persons(self, frame) -> [Detection]:
+    def _resnet_detect_persons(self, frame) -> List[Detection]:
         t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
         # change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
         t = t.permute(2, 0, 1)
@@ -406,8 +494,19 @@
         mask = prediction['labels'] == 1 # if we want more than one label: np.isin(prediction['labels'], [1,86])
 
         scores = prediction['scores'][mask]
+        # print(scores, prediction['labels'])
         labels = prediction['labels'][mask]
         boxes = prediction['boxes'][mask]
+        # print(prediction['scores'])
+
+        if NON_MAXIMUM_SUPRESSION < 1:
+            nms_mask = torch.zeros(scores.shape[0]).bool()
+            nms_keep_ids = torchvision.ops.nms(boxes, scores, NON_MAXIMUM_SUPRESSION)
+            nms_mask[nms_keep_ids] = True
+            print(scores.shape[0], nms_keep_ids, nms_mask)
+            scores = scores[nms_mask]
+            labels = labels[nms_mask]
+            boxes = boxes[nms_mask]
+
         # TODO: introduce confidence and NMS supression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb
         # (which I _think_ we better do after filtering)
@@ -415,7 +514,7 @@
 
         # dets - a numpy array of detections in the format [[x1,y1,x2,y2,score, label],[x1,y1,x2,y2,score, label],...]
         detections = np.array([np.append(bbox, [score, label]) for bbox, score, label in zip(boxes.cpu(), scores.cpu(), labels.cpu())])
-        detections = self.detect_persons_deepsort_wrapper(detections)
 
         return detections
 
@@ -429,8 +528,8 @@
 
 
 def run_tracker(config: Namespace, is_running: Event):
-    router = Tracker(config, is_running)
+    router = Tracker(config)
-    router.track()
+    router.track(is_running)
 
 
 
@@ -467,7 +566,7 @@ class Smoother:
         ws = self.smoother.smooth_data[0]
         self.smoother.smooth(hs)
         hs = self.smoother.smooth_data[0]
-        new_history = [Detection(d.track_id, l, t, w, h, d.conf, d.state, d.frame_nr) for l, t, w, h, d in zip(ls,ts,ws,hs, track.history)]
+        new_history = [Detection(d.track_id, l, t, w, h, d.conf, d.state, d.frame_nr, d.det_class) for l, t, w, h, d in zip(ls,ts,ws,hs, track.history)]
         new_track = Track(track.track_id, new_history, track.predictor_history, track.predictions)
         new_tracks.append(new_track)
         frame.tracks = {t.track_id: t for t in new_tracks}
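Two of the new pieces above are easy to exercise in isolation. First, the `TrackerWrapper` seam that makes DeepSORT and ByteTrack interchangeable behind one `track_detections()` call; a sketch with a black stand-in frame and an invented detection (note that `ByteTrackWrapper` only reports tracks older than five frames, so expect output after a short warm-up):

```python
import numpy as np
from trap.tracker import TRACKER_BYTETRACK, TrackerWrapper

wrapper = TrackerWrapper.init_type(TRACKER_BYTETRACK)  # or TRACKER_DEEPSORT
img = np.zeros((1080, 1920, 3), dtype=np.uint8)        # stand-in frame
dets = np.array([[100., 120., 160., 300., 0.35, 1.]])  # [x1,y1,x2,y2,score,label]
for frame_idx in range(10):
    for d in wrapper.track_detections(dets, img, frame_idx):
        print(frame_idx, d.track_id, d.det_class, d.state)
```

Second, the NMS branch is currently dormant (`NON_MAXIMUM_SUPRESSION = 1` disables it); lowering the constant enables standard IoU-based suppression via `torchvision.ops.nms`, which behaves like this on invented boxes:

```python
import torch
import torchvision

# two heavily overlapping boxes plus one separate box, xyxy format
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): the lower-scoring duplicate is dropped
```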