Rudimentary tracker results

Ruben van de Ven 2023-10-12 22:35:08 +02:00
parent 7c06913d88
commit 23162da767
8 changed files with 163 additions and 53 deletions

poetry.lock (generated)

@@ -497,6 +497,22 @@ files = [
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
]
[[package]]
name = "deep-sort-realtime"
version = "1.3.2"
description = "A more realtime adaptation of Deep SORT"
optional = false
python-versions = "*"
files = [
{file = "deep-sort-realtime-1.3.2.tar.gz", hash = "sha256:32bab92f981a274fce3ff121f35894e5adab7ca00314c113c348de7bcb82d73e"},
{file = "deep_sort_realtime-1.3.2-py3-none-any.whl", hash = "sha256:a6e144c888fdfb27245d2a060acbe0d2f3088448defbc419f7a26bce063bdd6c"},
]
[package.dependencies]
numpy = "*"
opencv-python = "*"
scipy = "*"
[[package]]
name = "defusedxml"
version = "0.7.1"
@@ -2934,38 +2950,6 @@ files = [
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
[[package]]
name = "torch"
version = "1.12.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"},
{file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"},
{file = "torch-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:e9c8f4a311ac29fc7e8e955cfb7733deb5dbe1bdaabf5d4af2765695824b7e0d"},
{file = "torch-1.12.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:976c3f997cea38ee91a0dd3c3a42322785414748d1761ef926b789dfa97c6134"},
{file = "torch-1.12.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:68104e4715a55c4bb29a85c6a8d57d820e0757da363be1ba680fa8cc5be17b52"},
{file = "torch-1.12.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:743784ccea0dc8f2a3fe6a536bec8c4763bd82c1352f314937cb4008d4805de1"},
{file = "torch-1.12.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b5dbcca369800ce99ba7ae6dee3466607a66958afca3b740690d88168752abcf"},
{file = "torch-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f3b52a634e62821e747e872084ab32fbcb01b7fa7dbb7471b6218279f02a178a"},
{file = "torch-1.12.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:8a34a2fbbaa07c921e1b203f59d3d6e00ed379f2b384445773bd14e328a5b6c8"},
{file = "torch-1.12.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:42f639501928caabb9d1d55ddd17f07cd694de146686c24489ab8c615c2871f2"},
{file = "torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0b44601ec56f7dd44ad8afc00846051162ef9c26a8579dda0a02194327f2d55e"},
{file = "torch-1.12.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cd26d8c5640c3a28c526d41ccdca14cf1cbca0d0f2e14e8263a7ac17194ab1d2"},
{file = "torch-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:42e115dab26f60c29e298559dbec88444175528b729ae994ec4c65d56fe267dd"},
{file = "torch-1.12.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:a8320ba9ad87e80ca5a6a016e46ada4d1ba0c54626e135d99b2129a4541c509d"},
{file = "torch-1.12.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:03e31c37711db2cd201e02de5826de875529e45a55631d317aadce2f1ed45aa8"},
{file = "torch-1.12.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9b356aea223772cd754edb4d9ecf2a025909b8615a7668ac7d5130f86e7ec421"},
{file = "torch-1.12.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6cf6f54b43c0c30335428195589bd00e764a6d27f3b9ba637aaa8c11aaf93073"},
{file = "torch-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:f00c721f489089dc6364a01fd84906348fe02243d0af737f944fddb36003400d"},
{file = "torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:bfec2843daa654f04fda23ba823af03e7b6f7650a873cdb726752d0e3718dada"},
{file = "torch-1.12.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:69fe2cae7c39ccadd65a123793d30e0db881f1c1927945519c5c17323131437e"},
]
[package.dependencies]
typing-extensions = "*"
[[package]]
name = "torch"
version = "1.12.1+cu113"
@@ -2983,6 +2967,68 @@ typing-extensions = "*"
type = "url"
url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl"
[[package]]
name = "torchvision"
version = "0.13.1"
description = "image and video datasets and models for torch deep learning"
optional = false
python-versions = ">=3.7"
files = [
{file = "torchvision-0.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:19286a733c69dcbd417b86793df807bd227db5786ed787c17297741a9b0d0fc7"},
{file = "torchvision-0.13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:08f592ea61836ebeceb5c97f4d7a813b9d7dc651bbf7ce4401563ccfae6a21fc"},
{file = "torchvision-0.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:ef5fe3ec1848123cd0ec74c07658192b3147dcd38e507308c790d5943e87b88c"},
{file = "torchvision-0.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:099874088df104d54d8008f2a28539ca0117b512daed8bf3c2bbfa2b7ccb187a"},
{file = "torchvision-0.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:8e4d02e4d8a203e0c09c10dfb478214c224d080d31efc0dbf36d9c4051f7f3c6"},
{file = "torchvision-0.13.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5e631241bee3661de64f83616656224af2e3512eb2580da7c08e08b8c965a8ac"},
{file = "torchvision-0.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:899eec0b9f3b99b96d6f85b9aa58c002db41c672437677b553015b9135b3be7e"},
{file = "torchvision-0.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:83e9e2457f23110fd53b0177e1bc621518d6ea2108f570e853b768ce36b7c679"},
{file = "torchvision-0.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7552e80fa222252b8b217a951c85e172a710ea4cad0ae0c06fbb67addece7871"},
{file = "torchvision-0.13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f230a1a40ed70d51e463ce43df243ec520902f8725de2502e485efc5eea9d864"},
{file = "torchvision-0.13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e9a563894f9fa40692e24d1aa58c3ef040450017cfed3598ff9637f404f3fe3b"},
{file = "torchvision-0.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7cb789ceefe6dcd0dc8eeda37bfc45efb7cf34770eac9533861d51ca508eb5b3"},
{file = "torchvision-0.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:87c137f343197769a51333076e66bfcd576301d2cd8614b06657187c71b06c4f"},
{file = "torchvision-0.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:4d8bf321c4380854ef04613935fdd415dce29d1088a7ff99e06e113f0efe9203"},
{file = "torchvision-0.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0298bae3b09ac361866088434008d82b99d6458fe8888c8df90720ef4b347d44"},
{file = "torchvision-0.13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c5ed609c8bc88c575226400b2232e0309094477c82af38952e0373edef0003fd"},
{file = "torchvision-0.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3567fb3def829229ec217c1e38f08c5128ff7fb65854cac17ebac358ff7aa309"},
{file = "torchvision-0.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b167934a5943242da7b1e59318f911d2d253feeca0d13ad5d832b58eed943401"},
{file = "torchvision-0.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0e77706cc90462653620e336bb90daf03d7bf1b88c3a9a3037df8d111823a56e"},
]
[package.dependencies]
numpy = "*"
pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0"
requests = "*"
torch = "1.12.1"
typing-extensions = "*"
[package.extras]
scipy = ["scipy"]
[[package]]
name = "torchvision"
version = "0.13.1+cu113"
description = "image and video datasets and models for torch deep learning"
optional = false
python-versions = ">=3.7"
files = [
{file = "torchvision-0.13.1+cu113-cp310-cp310-linux_x86_64.whl", hash = "sha256:b471090622d074d19fadacbc91ec02ddc1972d2fa4ada7f1c8a087bca3a5bcb0"},
]
[package.dependencies]
numpy = "*"
pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0"
requests = "*"
torch = "1.12.1"
typing-extensions = "*"
[package.extras]
scipy = ["scipy"]
[package.source]
type = "url"
url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp310-cp310-linux_x86_64.whl"
[[package]]
name = "tornado"
version = "6.3.3"
@@ -3226,4 +3272,4 @@ test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10,<3.12,"
-content-hash = "839e0f10b4ecc9e489cb1b5aface622a5e129b69650546de2e04f87f1654360d"
+content-hash = "06852d16ba90b0438b84e1c51c084c6aa6a5c7a78a6af1e17b7d59e5304ed5a7"

pyproject.toml

@@ -9,6 +9,13 @@ readme = "README.md"
python = "^3.10,<3.12,"
trajectron-plus-plus = { path = "../Trajectron-plus-plus/", develop = true }
torchvision = [
{ version="0.13.1" },
# { url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp38-cp38-linux_x86_64.whl", markers = "python_version ~= '3.8' and sys_platform == 'linux'" },
{ url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp310-cp310-linux_x86_64.whl", markers = "python_version ~= '3.10' and sys_platform == 'linux'" },
]
deep-sort-realtime = "^1.3.2"
[build-system]
requires = ["poetry-core"]

trap/config.py

@@ -23,9 +23,10 @@ parser.add_argument(
)
# parser.add_argument('--foo')
-inference_parser = parser.add_argument_group('inference server')
-connection_parser = parser.add_argument_group('connection')
+inference_parser = parser.add_argument_group('Inference')
+connection_parser = parser.add_argument_group('Connection')
frame_emitter_parser = parser.add_argument_group('Frame emitter')
tracker_parser = parser.add_argument_group('Tracker')
inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference",
@@ -164,5 +165,12 @@ frame_emitter_parser.add_argument("--video-src",
help="source video to track from",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4')
#TODO: camera as source
#TODO: camera
# Tracker
tracker_parser.add_argument("--homography",
help="File with homography params",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')

trap/frame_emitter.py

@@ -1,9 +1,12 @@
from argparse import Namespace
import logging
import pickle
import time
import cv2
import zmq
logger = logging.getLogger('trap.frame_emitter')
class FrameEmitter:
'''
@@ -17,6 +20,7 @@ class FrameEmitter:
self.frame_sock = context.socket(zmq.PUB)
self.frame_sock.bind(config.zmq_frame_addr)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
logger.info(f"Connection socket {config.zmq_frame_addr}")
def emit_video(self):
video = cv2.VideoCapture(str(self.config.video_src))
@@ -33,7 +37,9 @@
ret, frame = video.read()
assert ret is not False # not really error proof...
-self.frame_sock.send(frame)
+# TODO: this is very dirty, need to find another way.
+# perhaps multiprocessing queue?
+self.frame_sock.send(pickle.dumps(frame))
# defer next loop
new_frame_time = time.time()


@@ -5,6 +5,7 @@ from trap.config import parser
from trap.frame_emitter import run_frame_emitter
from trap.prediction_server import InferenceServer, run_inference_server
from trap.socket_forwarder import run_ws_forwarder
from trap.tracker import run_tracker
logger = logging.getLogger("trap.plumbing")
@@ -25,12 +26,13 @@ def start():
# instantiating process with arguments
procs = [
-# Process(target=run_ws_forwarder, args=(args,)),
-Process(target=run_frame_emitter, args=(args,)),
+Process(target=run_ws_forwarder, args=(args,), name='forwarder'),
+Process(target=run_frame_emitter, args=(args,), name='frame_emitter'),
+Process(target=run_tracker, args=(args,), name='tracker'),
]
if not args.bypass_prediction:
procs.append(
-Process(target=run_inference_server, args=(args,)),
+Process(target=run_inference_server, args=(args,), name='inference'),
)
logger.info("start")

trap/prediction_server.py

@@ -21,7 +21,7 @@ import matplotlib.pyplot as plt
import zmq
-logger = logging.getLogger("trap.inference")
+logger = logging.getLogger("trap.prediction")
# if not torch.cuda.is_available() or self.config.device == 'cpu':
@@ -117,7 +117,7 @@ class InferenceServer:
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
-print(self.prediction_socket)
+# print(self.prediction_socket)
def run(self):
@@ -222,7 +222,7 @@ class InferenceServer:
data = self.trajectory_socket.recv_string()
trajectory_data = json.loads(data)
-logger.info(f"Receive {trajectory_data}")
+logger.debug(f"Receive {trajectory_data}")
# class FakeNode:
# def __init__(self, node_type: NodeType):
@@ -326,7 +326,7 @@ class InferenceServer:
futures_dict = futures_dict[ts_key]
response = {}
-print(histories_dict)
+logger.debug(f"{histories_dict=}")
for node in histories_dict:
history = histories_dict[node]

trap/tracker.py

@@ -1,4 +1,8 @@
from argparse import Namespace
import json
import logging
import pickle
import time
import numpy as np
import torch
import zmq
@@ -6,19 +10,28 @@ import cv2
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
from deep_sort_realtime.deepsort_tracker import DeepSort
from deep_sort_realtime.deep_sort.track import Track
Detection = [int, int, int, int, float, int]
Detections = [Detection]
logger = logging.getLogger("trap.tracker")
class Tracker:
def __init__(self, config: Namespace):
self.config = config
context = zmq.Context()
self.frame_sock = context.socket(zmq.SUB)
-self.frame_sock.bind(config.zmq_frame_addr)
+self.frame_sock.connect(config.zmq_frame_addr)
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
-# TODO: config device
+self.trajectory_socket = context.socket(zmq.PUB)
+self.trajectory_socket.bind(config.zmq_trajectory_addr)
+self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
+# # TODO: config device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
@@ -29,17 +42,44 @@ class Tracker:
# Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device)
# homography = list(source.glob('*img2world.txt'))[0]
self.H = np.loadtxt(self.config.homography, delimiter=',')
self.mot_tracker = DeepSort(max_age=5)
logger.debug("Set up tracker")
def track(self):
while True:
-frame = self.frame_sock.recv()
+frame = pickle.loads(self.frame_sock.recv())
detections = self.detect_persons(frame)
-tracks = self.mot_tracker.update_tracks(detections, frame=frame)
+tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)
TEMP_boxes = [t.to_ltwh() for t in tracks]
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
if len(TEMP_coords):
TEMP_proj_coords = cv2.perspectiveTransform(TEMP_coords,self.H)
else:
TEMP_proj_coords = []
# print(TEMP_proj_coords)
trajectories = {}
for i, coords in enumerate(TEMP_proj_coords):
tid = tracks[i].track_id
trajectories[tid] = {
"id": tid,
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
}
logger.debug(f"{trajectories}")
self.trajectory_socket.send_string(json.dumps(trajectories))
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
# TODO: provide a track object that actually keeps history (unlike tracker) # TODO: provide a track object that actually keeps history (unlike tracker)
#TODO calculate fps (also for other loops to see asynchonity)
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
def detect_persons(self, frame) -> Detections:
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
@@ -70,7 +110,7 @@ class Tracker:
return detections
@classmethod
-def detect_persons_deepsort_wrapper(detections):
+def detect_persons_deepsort_wrapper(cls, detections):
"""make detect_persons() compatible with
deep_sort_realtime tracker by going from ltrb to ltwh and
different nesting


@@ -30,7 +30,7 @@
<script>
// map the field to coordinates of our dummy tracker
-const field_range = { x: [-10, 10], y: [-10, 10] }
+const field_range = { x: [-30, 10], y: [-10, 10] }
// Create WebSocket connection.
const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
@@ -111,8 +111,9 @@
}
}
}
-console.log(tracker)
-trajectory_socket.send(JSON.stringify(tracker))
+// TODO: make variable in template
+//trajectory_socket.send(JSON.stringify(tracker))
setTimeout(appendAndSendPositions, 200)
}