Run as separate processes for simple auto restart of individual nodes
This commit is contained in:
parent
9cc08f09af
commit
907d5a2157
14 changed files with 612 additions and 536 deletions
|
@ -21,6 +21,7 @@ These are roughly the steps to go from datagathering to training
|
|||
* `uv run process_data --src-dir EXPERIMENTS/raw/NAME --dst-dir EXPERIMENTS/trajectron-data/ --name NAME --smooth-tracks --camera-fps 12 --homography ../DATASETS/NAME/homography.json --calibration ../DATASETS/NAME/calibration.json --filter-displacement 2 --map-img-path ../DATASETS/NAME/map.png`
|
||||
5. Train Trajectron model `uv run trajectron_train --eval_every 10 --vis_every 1 --train_data_dict NAME_train.pkl --eval_data_dict NAME_val.pkl --offline_scene_graph no --preprocess_workers 8 --log_dir EXPERIMENTS/models --log_tag _NAME --train_epochs 100 --conf EXPERIMENTS/config.json --batch_size 256 --data_dir EXPERIMENTS/trajectron-data `
|
||||
6. The run!
|
||||
* On a video file (you can use a wildcard) `DISPLAY=:1 uv run trapserv --remote-log-addr 100.69.123.91 --eval_device cuda:0 --detector ultralytics --homography ../DATASETS/NAME/homography.json --eval_data_dict EXPERIMENTS/trajectron-data/hof2s-m_test.pkl --video-src ../DATASETS/NAME/*.mp4 --model_dir EXPERIMENTS/models/models_DATE_NAME/--smooth-predictions --smooth-tracks --num-samples 3 --render-window --calibration ../DATASETS/NAME/calibration.json` (the DISPLAY environment variable is used here to running over SSH connection and display on local monitor)
|
||||
* `uv run supervisord`
|
||||
<!-- * On a video file (you can use a wildcard) `DISPLAY=:1 uv run trapserv --remote-log-addr 100.69.123.91 --eval_device cuda:0 --detector ultralytics --homography ../DATASETS/NAME/homography.json --eval_data_dict EXPERIMENTS/trajectron-data/hof2s-m_test.pkl --video-src ../DATASETS/NAME/*.mp4 --model_dir EXPERIMENTS/models/models_DATE_NAME/--smooth-predictions --smooth-tracks --num-samples 3 --render-window --calibration ../DATASETS/NAME/calibration.json` (the DISPLAY environment variable is used here to running over SSH connection and display on local monitor)
|
||||
* or on the RTSP stream. Which uses gstreamer to substantially reduce latency compared to the default ffmpeg bindings in OpenCV.
|
||||
* To just have a single trajectory pulled from distribution use `--full-dist`. Also try `--z_mode`.
|
||||
* To just have a single trajectory pulled from distribution use `--full-dist`. Also try `--z_mode`. -->
|
||||
|
|
|
@ -33,6 +33,7 @@ dependencies = [
|
|||
"python-statemachine>=2.5.0",
|
||||
"facenet-pytorch>=2.5.3",
|
||||
"simplification>=0.7.12",
|
||||
"supervisor>=4.2.5",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
@ -43,8 +44,13 @@ compare = "trap.tools:tracker_compare"
|
|||
process_data = "trap.process_data:main"
|
||||
blacklist = "trap.tools:blacklist_tracks"
|
||||
rewrite_tracks = "trap.tools:rewrite_raw_track_files"
|
||||
live_video_source = "trap.frame_emitter:run"
|
||||
live_tracker = "trap.tracker:run"
|
||||
|
||||
trap_video_source = "trap.frame_emitter:FrameEmitter.parse_and_start"
|
||||
trap_tracker = "trap.tracker:Tracker.parse_and_start"
|
||||
trap_stage = "trap.stage:Stage.parse_and_start"
|
||||
trap_prediction = "trap.prediction_server:PredictionServer.parse_and_start"
|
||||
trap_render_cv = "trap.cv_renderer:CvRenderer.parse_and_start"
|
||||
trap_monitor = "trap.monitor:Monitor.parse_and_start" # migrate timer
|
||||
|
||||
[tool.uv]
|
||||
|
||||
|
|
48
supervisord.conf
Normal file
48
supervisord.conf
Normal file
|
@ -0,0 +1,48 @@
|
|||
[inet_http_server]
|
||||
port = *:8293
|
||||
# username = user
|
||||
# password = 123
|
||||
|
||||
[supervisord]
|
||||
nodaemon = True
|
||||
|
||||
|
||||
; The rpcinterface:supervisor section must remain in the config file for
|
||||
; RPC (supervisorctl/web interface) to work. Additional interfaces may be
|
||||
; added by defining them in separate [rpcinterface:x] sections.
|
||||
[rpcinterface:supervisor]
|
||||
supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
|
||||
|
||||
[supervisorctl]
|
||||
serverurl = http://localhost:8293
|
||||
|
||||
[program:monitor]
|
||||
command=uv run trap_monitor
|
||||
numprocs=1
|
||||
directory=%(here)s
|
||||
|
||||
[program:video]
|
||||
command=uv run trap_video_source --homography ../DATASETS/hof3/homography.json --video-src ../DATASETS/hof3/hof3-cam-demo-twoperson.mp4 --calibration ../DATASETS/hof3/calibration.json --video-loop
|
||||
# command=uv run trap_video_source --homography ../DATASETS/hof3-cam-baumer/homography.json --video-src gige://../DATASETS/hof3-cam-baumer/gige_config.json --calibration ../DATASETS/hof3-cam-baumer/calibration.json
|
||||
directory=%(here)s
|
||||
directory=%(here)s
|
||||
|
||||
[program:tracker]
|
||||
command=uv run trap_tracker
|
||||
directory=%(here)s
|
||||
|
||||
[program:stage]
|
||||
command=uv run trap_stage
|
||||
directory=%(here)s
|
||||
|
||||
[program:predictor]
|
||||
command=uv run trap_prediction --eval_device cuda:0 --model_dir EXPERIMENTS/models/models_20241229_21_35_13_hof3-m2-ud-split-conv12-f2.0-map-2024-12-29/ --num-samples 1 --map_encoding --eval_data_dict EXPERIMENTS/trajectron-data/hof3-m2-ud-split-nostep-conv12-f2.0-map-2024-12-29_val.pkl --prediction-horizon 120 --gmm-mode True --z-mode
|
||||
directory=%(here)s
|
||||
|
||||
[program:render_cv]
|
||||
command=uv run trap_render_cv
|
||||
directory=%(here)s
|
||||
environment=DISPLAY=":0"
|
||||
autostart=false
|
||||
; can be long to quit if rendering to video file
|
||||
stopwaitsecs=60
|
17
trap/base.py
17
trap/base.py
|
@ -2,12 +2,14 @@ from __future__ import annotations
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from enum import IntFlag
|
||||
from itertools import cycle
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
import types
|
||||
from typing import Iterable, Optional, Tuple, Union, List
|
||||
import cv2
|
||||
from dataclasses import dataclass, field
|
||||
|
@ -617,6 +619,7 @@ class Frame:
|
|||
H: Optional[np.array] = None
|
||||
camera: Optional[Camera] = None
|
||||
maps: Optional[List[cv2.Mat]] = None
|
||||
log: dict = field(default_factory=lambda: {}) # settings used during processing. All intermediate nodes can store their config here
|
||||
|
||||
def aslist(self) -> List[dict]:
|
||||
return { t.track_id:
|
||||
|
@ -724,3 +727,17 @@ class CameraAction(argparse.Action):
|
|||
# camera = Camera(np.array(data['camera_matrix']), np.array(data['dist_coeff']), data['dim']['width'], data['dim']['height'], namespace.H, namespace.camera_fps)
|
||||
|
||||
setattr(namespace, 'camera', camera)
|
||||
|
||||
class LambdaParser(argparse.ArgumentParser):
|
||||
"""Execute lambda functions
|
||||
"""
|
||||
def parse_args(self, args=None, namespace=None):
|
||||
args = super().parse_args(args, namespace)
|
||||
|
||||
for key in vars(args):
|
||||
f = args.__dict__[key]
|
||||
if type(f) == types.LambdaType:
|
||||
print(f'Getting default value for {key}')
|
||||
args.__dict__[key] = f()
|
||||
|
||||
return args
|
|
@ -6,23 +6,10 @@ import json
|
|||
|
||||
from trap.tracker import DETECTORS, TRACKER_BYTETRACK, TRACKERS
|
||||
from trap.frame_emitter import Camera
|
||||
from trap.base import CameraAction, HomographyAction
|
||||
from trap.base import CameraAction, HomographyAction, LambdaParser
|
||||
from pyparsing import Optional
|
||||
from trap.frame_emitter import UrlOrPath
|
||||
|
||||
class LambdaParser(argparse.ArgumentParser):
|
||||
"""Execute lambda functions
|
||||
"""
|
||||
def parse_args(self, args=None, namespace=None):
|
||||
args = super().parse_args(args, namespace)
|
||||
|
||||
for key in vars(args):
|
||||
f = args.__dict__[key]
|
||||
if type(f) == types.LambdaType:
|
||||
print(f'Getting default value for {key}')
|
||||
args.__dict__[key] = f()
|
||||
|
||||
return args
|
||||
|
||||
parser = LambdaParser()
|
||||
# parser.parse_args()
|
||||
|
@ -261,7 +248,6 @@ frame_emitter_parser.add_argument("--video-loop",
|
|||
#TODO: camera as source
|
||||
|
||||
|
||||
# Tracker
|
||||
|
||||
tracker_parser.add_argument("--camera-fps",
|
||||
help="Camera FPS",
|
||||
|
@ -277,6 +263,8 @@ tracker_parser.add_argument("--calibration",
|
|||
# type=Path,
|
||||
default=None,
|
||||
action=CameraAction)
|
||||
|
||||
# Tracker
|
||||
tracker_parser.add_argument("--save-for-training",
|
||||
help="Specify the path in which to save",
|
||||
type=Path,
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import collections
|
||||
from gc import is_finalized
|
||||
import logging
|
||||
import statistics
|
||||
import threading
|
||||
import time
|
||||
from typing import MutableSequence
|
||||
import zmq
|
||||
|
@ -8,20 +10,59 @@ import zmq
|
|||
logger = logging.getLogger('counter')
|
||||
|
||||
class CounterSender:
|
||||
def __init__(self, address = "ipc:///tmp/trap-counters"):
|
||||
def __init__(self, address = "ipc:///tmp/trap-counters2"):
|
||||
# self.name = name
|
||||
self.context = zmq.Context()
|
||||
self.sock = self.context.socket(zmq.PUB)
|
||||
# self.sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
self.sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
# self.sock.sndhwm = 1
|
||||
self.sock.bind(address)
|
||||
self.sock.connect(address)
|
||||
|
||||
def set(self, name:str, value:float):
|
||||
try:
|
||||
self.sock.send_multipart([name.encode('utf8'), str(value).encode("utf8")], flags=zmq.NOBLOCK)
|
||||
# we cannot use send_multipart in combination with conflate
|
||||
self.sock.send_pyobj([name, value], flags=zmq.NOBLOCK)
|
||||
except zmq.ZMQError as e:
|
||||
logger.warning(f"No space in que to count {name} as {value}")
|
||||
|
||||
class CounterFpsSender():
|
||||
def __init__(self, name:str , sender: CounterSender):
|
||||
self.name = name
|
||||
self.sender = sender
|
||||
self.tocs: MutableSequence[(float, int)] = collections.deque(maxlen=5)
|
||||
self.iterations: int = 0
|
||||
# threading.Event.wait()
|
||||
# TODO thread to daeomic loop so it automatically stops
|
||||
self.thread = threading.Thread(target=self.interval, daemon=True)
|
||||
self.is_finished = threading.Event()
|
||||
|
||||
def tick(self):
|
||||
self.iterations += 1
|
||||
self.snapshot()
|
||||
|
||||
def snapshot(self):
|
||||
self.tocs.append((time.perf_counter(), self.iterations))
|
||||
self.sender.set(self.name, self.fps)
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
if len(self.tocs) < 2:
|
||||
return 0
|
||||
dt = self.tocs[-1][0] - self.tocs[0][0]
|
||||
di = self.tocs[-1][1] - self.tocs[0][1]
|
||||
return di/dt
|
||||
|
||||
def interval(self):
|
||||
while True:
|
||||
self.is_finished.wait(.5)
|
||||
if self.is_finished.is_set():
|
||||
break
|
||||
|
||||
self.snapshot()
|
||||
# timer = threading.Timer(.5, self.interval)
|
||||
# timer.start()
|
||||
|
||||
|
||||
class CounterLog():
|
||||
def __init__(self, history = 20):
|
||||
self.history: MutableSequence[(float, float)] = collections.deque(maxlen=history)
|
||||
|
@ -30,29 +71,37 @@ class CounterLog():
|
|||
self.history.append((time.perf_counter(), value))
|
||||
|
||||
def value(self):
|
||||
if not len(self.history):
|
||||
return None
|
||||
return self.history[-1][1]
|
||||
|
||||
def has_value(self):
|
||||
if not len(self.history):
|
||||
return False
|
||||
if (time.perf_counter() - self.history[-1][0]) > 4:
|
||||
# no update in 4s: very slow. Dead thread?
|
||||
return False
|
||||
return True
|
||||
|
||||
def avg(self):
|
||||
if not len(self.history):
|
||||
return 0.
|
||||
return statistics.fmean([h[1] for h in self.history])
|
||||
|
||||
class CounterListerner():
|
||||
def __init__(self, address = "ipc:///tmp/trap-counters"):
|
||||
def __init__(self, address = "ipc:///tmp/trap-counters2"):
|
||||
self.context = zmq.Context()
|
||||
self.sock = self.context.socket(zmq.SUB)
|
||||
self.sock.connect(address)
|
||||
self.sock.bind(address)
|
||||
self.sock.subscribe( b'')
|
||||
self.values: collections.defaultdict[str, CounterLog] = collections.defaultdict(lambda: CounterLog())
|
||||
|
||||
def snapshot(self):
|
||||
messages = []
|
||||
while self.sock.poll(0) == zmq.POLLIN:
|
||||
name, value = self.sock.recv_multipart()
|
||||
name, value = name.decode('utf8'),float(value.decode('utf8'))
|
||||
self.values[name].add(value)
|
||||
msg = self.sock.recv_pyobj()
|
||||
# print(msg)
|
||||
name, value = msg
|
||||
# name, value = name.decode('utf8'),float(value.decode('utf8'))
|
||||
self.values[name].add(float(value))
|
||||
|
||||
|
||||
def get_latest(self):
|
||||
|
@ -60,7 +109,7 @@ class CounterListerner():
|
|||
return self.values
|
||||
|
||||
def to_string(self):
|
||||
strs = [f"{k}: {v.value()} ({v.avg()})" for (k,v) in self.values.items()]
|
||||
strs = [(f"{k}: {v.value():.2f} ({v.avg():.2f})" if v.has_value() else f"{k}: --") for (k,v) in self.values.items()]
|
||||
return " ".join(strs)
|
||||
|
||||
|
||||
|
|
|
@ -1,74 +1,44 @@
|
|||
# used for "Forward Referencing of type annotations"
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import ffmpeg
|
||||
from argparse import Namespace
|
||||
import datetime
|
||||
import logging
|
||||
from multiprocessing import Event
|
||||
import time
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from multiprocessing.synchronize import Event as BaseEvent
|
||||
from typing import Dict
|
||||
|
||||
import cv2
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import json
|
||||
import pyglet
|
||||
import pyglet.event
|
||||
import zmq
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import math
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
|
||||
from pyglet import shapes
|
||||
from PIL import Image
|
||||
|
||||
from trap.counter import CounterListerner
|
||||
from trap.frame_emitter import DetectionState, Frame, Track, Camera
|
||||
from trap.frame_emitter import Frame, Track
|
||||
from trap.node import Node
|
||||
from trap.preview_renderer import FrameWriter
|
||||
from trap.tools import draw_track, draw_track_predictions, draw_track_projected, draw_trackjectron_history, to_point
|
||||
from trap.utils import convert_world_points_to_img_points, convert_world_space_to_img_space
|
||||
|
||||
|
||||
from trap.tools import draw_track_predictions, draw_track_projected, to_point
|
||||
from trap.utils import convert_world_points_to_img_points
|
||||
|
||||
logger = logging.getLogger("trap.simple_renderer")
|
||||
|
||||
class CvRenderer:
|
||||
def __init__(self, config: Namespace, is_running: BaseEvent):
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
class CvRenderer(Node):
|
||||
def setup(self):
|
||||
self.prediction_sock = self.sub(self.config.zmq_prediction_addr)
|
||||
self.tracker_sock = self.sub(self.config.zmq_trajectory_addr)
|
||||
self.frame_sock = self.sub(self.config.zmq_frame_addr)
|
||||
|
||||
self.counter_listener = CounterListerner()
|
||||
|
||||
context = zmq.Context()
|
||||
self.prediction_sock = context.socket(zmq.SUB)
|
||||
self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
# self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
|
||||
self.prediction_sock.connect(config.zmq_prediction_addr)
|
||||
|
||||
self.tracker_sock = context.socket(zmq.SUB)
|
||||
self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.tracker_sock.connect(config.zmq_trajectory_addr)
|
||||
|
||||
self.frame_sock = context.socket(zmq.SUB)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.frame_sock.connect(config.zmq_frame_addr)
|
||||
|
||||
|
||||
self.H = self.config.H
|
||||
|
||||
|
||||
self.inv_H = np.linalg.pinv(self.H)
|
||||
# self.H = self.config.H
|
||||
# self.inv_H = np.linalg.pinv(self.H)
|
||||
|
||||
# TODO: get FPS from frame_emitter
|
||||
# self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
|
||||
self.fps = 60
|
||||
self.frame_size = (self.config.camera.projected_w,self.config.camera.projected_h)
|
||||
self.hide_stats = False
|
||||
self.frame_size = None # configure on first frame recv
|
||||
# self.frame_size = (self.config.camera.projected_w,self.config.camera.projected_h)
|
||||
|
||||
self.out_writer = self.start_writer() if self.config.render_file else None
|
||||
self.streaming_process = self.start_streaming() if self.config.render_url else None
|
||||
|
||||
|
@ -80,81 +50,6 @@ class CvRenderer:
|
|||
self.tracks: Dict[str, Track] = {}
|
||||
self.predictions: Dict[str, Track] = {}
|
||||
|
||||
|
||||
# self.init_shapes()
|
||||
|
||||
# self.init_labels()
|
||||
|
||||
|
||||
def init_shapes(self):
|
||||
'''
|
||||
Due to error when running headless, we need to configure options before extending the shapes class
|
||||
'''
|
||||
class GradientLine(shapes.Line):
|
||||
def __init__(self, x, y, x2, y2, width=1, color1=[255,255,255], color2=[255,255,255], batch=None, group=None):
|
||||
# print('colors!', colors)
|
||||
# assert len(colors) == 6
|
||||
|
||||
r, g, b, *a = color1
|
||||
self._rgba1 = (r, g, b, a[0] if a else 255)
|
||||
r, g, b, *a = color2
|
||||
self._rgba2 = (r, g, b, a[0] if a else 255)
|
||||
|
||||
# print('rgba', self._rgba)
|
||||
|
||||
super().__init__(x, y, x2, y2, width, color1, batch=None, group=None)
|
||||
# <pyglet.graphics.vertexdomain.VertexList
|
||||
# pyglet.graphics.vertexdomain
|
||||
# print(self._vertex_list)
|
||||
|
||||
def _create_vertex_list(self):
|
||||
'''
|
||||
copy of super()._create_vertex_list but with additional colors'''
|
||||
self._vertex_list = self._group.program.vertex_list(
|
||||
6, self._draw_mode, self._batch, self._group,
|
||||
position=('f', self._get_vertices()),
|
||||
colors=('Bn', self._rgba1+ self._rgba2 + self._rgba2 + self._rgba1 + self._rgba2 +self._rgba1 ),
|
||||
translation=('f', (self._x, self._y) * self._num_verts))
|
||||
|
||||
def _update_colors(self):
|
||||
self._vertex_list.colors[:] = self._rgba1+ self._rgba2 + self._rgba2 + self._rgba1 + self._rgba2 +self._rgba1
|
||||
|
||||
def color1(self, color):
|
||||
r, g, b, *a = color
|
||||
self._rgba1 = (r, g, b, a[0] if a else 255)
|
||||
self._update_colors()
|
||||
|
||||
def color2(self, color):
|
||||
r, g, b, *a = color
|
||||
self._rgba2 = (r, g, b, a[0] if a else 255)
|
||||
self._update_colors()
|
||||
|
||||
self.gradientLine = GradientLine
|
||||
|
||||
def init_labels(self):
|
||||
base_color = (255,)*4
|
||||
color_predictor = (255,255,0, 255)
|
||||
color_info = (255,0, 255, 255)
|
||||
color_tracker = (0,255, 255, 255)
|
||||
|
||||
options = []
|
||||
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
|
||||
options.append(f"{option}: {self.config.__dict__[option]}")
|
||||
|
||||
self.labels = {
|
||||
'waiting': pyglet.text.Label("Waiting for prediction"),
|
||||
'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||
'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||
'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
|
||||
'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||
'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay),
|
||||
'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||
'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
|
||||
'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||
'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
|
||||
'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
|
||||
}
|
||||
|
||||
def refresh_labels(self, dt: float):
|
||||
"""Every frame"""
|
||||
|
||||
|
@ -173,130 +68,6 @@ class CvRenderer:
|
|||
self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
|
||||
# self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
|
||||
|
||||
|
||||
# cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
# cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
|
||||
# if prediction_frame:
|
||||
# # render Δt and Δ frames
|
||||
# cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
|
||||
# cv2.putText(img, f"{prediction_frame.time - time.time():.2f}s", (200,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
|
||||
# cv2.putText(img, f"{len(prediction_frame.tracks)} tracks", (500,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
# cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()]):.2f}", (580,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
|
||||
# cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()]):.2f}", (660,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
|
||||
# cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()]):.2f}", (740,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
|
||||
|
||||
# options = []
|
||||
# for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
|
||||
# options.append(f"{option}: {config.__dict__[option]}")
|
||||
|
||||
|
||||
# cv2.putText(img, options.pop(-1), (20,img.shape[0]-30), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
# cv2.putText(img, " | ".join(options), (20,img.shape[0]-10), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
|
||||
|
||||
|
||||
def check_frames(self, dt):
|
||||
new_tracks = False
|
||||
try:
|
||||
self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
if not self.first_time:
|
||||
self.first_time = self.frame.time
|
||||
img = cv2.GaussianBlur(self.frame.img, (15, 15), 0)
|
||||
img = cv2.flip(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), 0)
|
||||
img = pyglet.image.ImageData(self.frame_size[0], self.frame_size[1], 'RGB', img.tobytes())
|
||||
# don't draw in batch, so that it is the background
|
||||
self.video_sprite = pyglet.sprite.Sprite(img=img, batch=self.batch_bg)
|
||||
self.video_sprite.opacity = 100
|
||||
except zmq.ZMQError as e:
|
||||
# idx = frame.index if frame else "NONE"
|
||||
# logger.debug(f"reuse video frame {idx}")
|
||||
pass
|
||||
try:
|
||||
self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
new_tracks = True
|
||||
except zmq.ZMQError as e:
|
||||
pass
|
||||
try:
|
||||
self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
new_tracks = True
|
||||
except zmq.ZMQError as e:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def on_key_press(self, symbol, modifiers):
|
||||
print('A key was pressed, use f to hide')
|
||||
if symbol == ord('f'):
|
||||
self.window.set_fullscreen(not self.window.fullscreen)
|
||||
if symbol == ord('h'):
|
||||
self.hide_stats = not self.hide_stats
|
||||
|
||||
def check_running(self, dt):
|
||||
if not self.is_running.is_set():
|
||||
self.window.close()
|
||||
self.event_loop.exit()
|
||||
|
||||
def on_close(self):
|
||||
self.is_running.clear()
|
||||
|
||||
def on_refresh(self, dt: float):
|
||||
# update shapes
|
||||
# self.bg =
|
||||
for track_id, track in self.drawn_tracks.items():
|
||||
track.update_drawn_positions(dt)
|
||||
|
||||
|
||||
self.refresh_labels(dt)
|
||||
|
||||
# self.shape1 = shapes.Circle(700, 150, 100, color=(50, 0, 30), batch=self.batch_anim)
|
||||
# self.shape3 = shapes.Circle(800, 150, 100, color=(100, 225, 30), batch=self.batch_anim)
|
||||
pass
|
||||
|
||||
def on_draw(self):
|
||||
self.window.clear()
|
||||
|
||||
self.batch_bg.draw()
|
||||
|
||||
for track in self.drawn_tracks.values():
|
||||
for shape in track.shapes:
|
||||
shape.draw() # for some reason the batches don't work
|
||||
for track in self.drawn_tracks.values():
|
||||
for shapes in track.pred_shapes:
|
||||
for shape in shapes:
|
||||
shape.draw()
|
||||
# self.batch_anim.draw()
|
||||
self.batch_overlay.draw()
|
||||
|
||||
|
||||
# pyglet.graphics.draw(3, pyglet.gl.GL_LINE, ("v2i", (100,200, 600,800)), ('c3B', (255,255,255, 255,255,255)))
|
||||
|
||||
|
||||
|
||||
if not self.hide_stats:
|
||||
self.fps_display.draw()
|
||||
|
||||
# if streaming, capture buffer and send
|
||||
try:
|
||||
if self.streaming_process or self.out_writer:
|
||||
buf = pyglet.image.get_buffer_manager().get_color_buffer()
|
||||
img_data = buf.get_image_data()
|
||||
data = img_data.get_data() # alternative: .get_data("RGBA", image_data.pitch)
|
||||
img = np.asanyarray(data).reshape((img_data.height, img_data.width, 4))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||
img = np.flip(img, 0)
|
||||
# img = cv2.flip(img, cv2.0)
|
||||
|
||||
# cv2.imshow('frame', img)
|
||||
# cv2.waitKey(1)
|
||||
if self.streaming_process:
|
||||
self.streaming_process.stdin.write(img.tobytes())
|
||||
if self.out_writer:
|
||||
self.out_writer.write(img)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
def start_writer(self):
|
||||
if not self.config.output_dir.exists():
|
||||
raise FileNotFoundError("Path does not exist")
|
||||
|
@ -305,16 +76,16 @@ class CvRenderer:
|
|||
filename = self.config.output_dir / f"render_predictions-{date_str}-{self.config.detector}.mp4"
|
||||
logger.info(f"Write to {filename}")
|
||||
|
||||
return FrameWriter(str(filename), self.fps, self.frame_size)
|
||||
return FrameWriter(str(filename), self.fps, None)
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'vp09')
|
||||
# fourcc = cv2.VideoWriter_fourcc(*'vp09')
|
||||
|
||||
return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
|
||||
# return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
|
||||
|
||||
def start_streaming(self):
|
||||
def start_streaming(self, frame_size=(1920,1080)):
|
||||
return (
|
||||
ffmpeg
|
||||
.input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*self.frame_size))
|
||||
.input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*frame_size))
|
||||
.output(
|
||||
self.config.render_url,
|
||||
#codec = "copy", # use same codecs of the original video
|
||||
|
@ -334,10 +105,7 @@ class CvRenderer:
|
|||
)
|
||||
# return process
|
||||
|
||||
|
||||
|
||||
|
||||
def run(self, timer_counter):
|
||||
def run(self):
|
||||
frame = None
|
||||
prediction_frame = None
|
||||
tracker_frame = None
|
||||
|
@ -348,14 +116,12 @@ class CvRenderer:
|
|||
cv2.namedWindow("frame", cv2.WINDOW_NORMAL)
|
||||
# https://gist.github.com/ronekko/dc3747211543165108b11073f929b85e
|
||||
cv2.moveWindow("frame", 1920, -1)
|
||||
cv2.setWindowProperty("frame",cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
|
||||
bgsub = cv2.createBackgroundSubtractorMOG2(120, 50, detectShadows=True)
|
||||
|
||||
while self.is_running.is_set():
|
||||
i+=1
|
||||
with timer_counter.get_lock():
|
||||
timer_counter.value+=1
|
||||
if self.config.full_screen:
|
||||
cv2.setWindowProperty("frame",cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
|
||||
# bgsub = cv2.createBackgroundSubtractorMOG2(120, 50, detectShadows=True)
|
||||
|
||||
while self.run_loop():
|
||||
i += 1
|
||||
|
||||
# zmq_ev = self.frame_sock.poll(timeout=2000)
|
||||
# if not zmq_ev:
|
||||
|
@ -397,7 +163,7 @@ class CvRenderer:
|
|||
first_time = frame.time
|
||||
|
||||
# img = frame.img
|
||||
img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.config.render_clusters, self.counter_listener,bgsub)
|
||||
img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.config.render_clusters)
|
||||
|
||||
logger.debug(f"write frame {frame.time - first_time:.3f}s")
|
||||
if self.out_writer:
|
||||
|
@ -431,6 +197,50 @@ class CvRenderer:
|
|||
self.streaming_process.wait()
|
||||
|
||||
logger.info('stopped')
|
||||
|
||||
|
||||
@classmethod
|
||||
def arg_parser(cls):
|
||||
render_parser = ArgumentParser()
|
||||
render_parser.add_argument('--zmq-frame-addr',
|
||||
help='Manually specity communication addr for the frame messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_frame")
|
||||
render_parser.add_argument('--zmq-trajectory-addr',
|
||||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_traj")
|
||||
render_parser.add_argument('--zmq-prediction-addr',
|
||||
help='Manually specity communication addr for the prediction messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_preds")
|
||||
|
||||
render_parser.add_argument("--render-file",
|
||||
help="Render a video file previewing the prediction, and its delay compared to the current frame",
|
||||
action='store_true')
|
||||
render_parser.add_argument("--render-window",
|
||||
help="Render a previewing to a window",
|
||||
action='store_true')
|
||||
|
||||
render_parser.add_argument("--full-screen",
|
||||
help="Set Window full screen",
|
||||
action='store_true')
|
||||
render_parser.add_argument("--render-clusters",
|
||||
help="renders arrowd clusters instead of individual predictions",
|
||||
action='store_true')
|
||||
|
||||
render_parser.add_argument("--render-url",
|
||||
help="""Stream renderer on given URL. Two easy approaches:
|
||||
- using zmq wrapper one can specify the LISTENING ip. To listen to any incoming connection: zmq:tcp://0.0.0.0:5556
|
||||
- alternatively, using e.g. UDP one needs to specify the IP of the client. E.g. udp://100.69.123.91:5556/stream
|
||||
Note that with ZMQ you can have multiple clients connecting simultaneously. E.g. using `ffplay zmq:tcp://100.109.175.82:5556`
|
||||
When using udp, connecting can be done using `ffplay udp://100.109.175.82:5556/stream`
|
||||
""",
|
||||
type=str,
|
||||
default=None)
|
||||
|
||||
return render_parser
|
||||
|
||||
# colorset = itertools.product([0,255], repeat=3) # but remove white
|
||||
# colorset = [(0, 0, 0),
|
||||
# (0, 0, 255),
|
||||
|
@ -460,7 +270,7 @@ def get_animation_position(track: Track, current_frame: Frame):
|
|||
|
||||
|
||||
|
||||
def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], as_clusters = True, counter_listener: CounterListerner|None = None, bg_subtractor = None) -> np.array:
|
||||
def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], as_clusters = True) -> np.array:
|
||||
scale = 100
|
||||
# TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
|
||||
# or https://github.com/pygobject/pycairo?tab=readme-ov-file
|
||||
|
@ -503,7 +313,7 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
|
|||
else:
|
||||
for track_id, track in tracks.items():
|
||||
inv_H = np.linalg.pinv(tracker_frame.H)
|
||||
draw_track_projected(img, track, int(track_id), config.camera, conversion)
|
||||
draw_track_projected(img, track, int(track_id), frame.camera, conversion)
|
||||
|
||||
if not prediction_frame:
|
||||
cv2.putText(img, f"Waiting for prediction...", (500,17), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
|
@ -514,7 +324,7 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
|
|||
# For debugging:
|
||||
# draw_trackjectron_history(img, track, int(track.track_id), conversion)
|
||||
anim_position = get_animation_position(track, frame)
|
||||
draw_track_predictions(img, track, int(track.track_id)+1, config.camera, conversion, anim_position=anim_position, as_clusters=as_clusters)
|
||||
draw_track_predictions(img, track, int(track.track_id)+1, frame.camera, conversion, anim_position=anim_position, as_clusters=as_clusters)
|
||||
cv2.putText(img, f"{len(track.predictor_history) if track.predictor_history else 'none'}", to_point(track.history[0].get_foot_coords()), cv2.FONT_HERSHEY_COMPLEX, 1, (255,255,255), 1)
|
||||
if prediction_frame.maps:
|
||||
for i, m in enumerate(prediction_frame.maps):
|
||||
|
@ -544,6 +354,8 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
|
|||
cv2.putText(img, f"{frame.time - first_time: >10.2f}s", (150,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
cv2.putText(img, f"{frame.time - time.time():.2f}s", (250,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
|
||||
options = []
|
||||
|
||||
if prediction_frame:
|
||||
# render Δt and Δ frames
|
||||
cv2.putText(img, f"{tracker_frame.index - frame.index}", (90,17), cv2.FONT_HERSHEY_PLAIN, 1, tracker_color, 1)
|
||||
|
@ -554,19 +366,15 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
|
|||
cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()]):.2f}", (700,17), cv2.FONT_HERSHEY_PLAIN, 1, tracker_color, 1)
|
||||
cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()]):.2f}", (780,17), cv2.FONT_HERSHEY_PLAIN, 1, predictor_color, 1)
|
||||
cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()]):.2f}", (860,17), cv2.FONT_HERSHEY_PLAIN, 1, predictor_color, 1)
|
||||
|
||||
options = []
|
||||
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
|
||||
options.append(f"{option}: {config.__dict__[option]}")
|
||||
|
||||
|
||||
|
||||
for option, value in prediction_frame.log['predictor'].items():
|
||||
options.append(f"{option}: {value}")
|
||||
|
||||
|
||||
cv2.putText(img, options.pop(-1), (20,img.shape[0]-30), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
cv2.putText(img, " | ".join(options), (20,img.shape[0]-10), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
|
||||
# for i, (k, v) in enumerate(counter_listener.get_latest().items()):
|
||||
# cv2.putText(img, f"{k} {v.value()}", (20,img.shape[0]-(40*i)-40), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
|
||||
|
||||
return img
|
||||
|
||||
|
||||
|
|
|
@ -1,75 +1,51 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
import dataclasses
|
||||
from enum import IntFlag
|
||||
from itertools import cycle
|
||||
import json
|
||||
import logging
|
||||
from multiprocessing import Event
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, List, Optional
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pandas as pd
|
||||
import zmq
|
||||
import os
|
||||
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
||||
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
|
||||
from bytetracker.byte_tracker import STrack as ByteTrackTrack
|
||||
from bytetracker.basetrack import TrackState as ByteTrackTrackState
|
||||
from trajectron.environment import Environment, Node, Scene
|
||||
from urllib.parse import urlparse
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from multiprocessing import Event
|
||||
from pathlib import Path
|
||||
|
||||
import zmq
|
||||
|
||||
from trap import node
|
||||
from trap.base import *
|
||||
from trap.base import LambdaParser
|
||||
from trap.timer import Timer
|
||||
from trap.utils import get_bins
|
||||
from trap.utils import inv_lerp, lerp
|
||||
from trap.video_sources import get_video_source
|
||||
|
||||
logger = logging.getLogger('trap.frame_emitter')
|
||||
|
||||
|
||||
|
||||
|
||||
class FrameEmitter:
|
||||
class FrameEmitter(node.Node):
|
||||
'''
|
||||
Emit frame in a separate threat so they can be throttled,
|
||||
or thrown away when the rest of the system cannot keep up
|
||||
'''
|
||||
def __init__(self, config: Namespace, is_running: Event) -> None:
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
|
||||
context = zmq.Context()
|
||||
# TODO: to make things faster, a multiprocessing.Array might be a tad faster: https://stackoverflow.com/a/65201859
|
||||
self.frame_sock = context.socket(zmq.PUB)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
|
||||
self.frame_sock.bind(config.zmq_frame_addr)
|
||||
|
||||
self.frame_noimg_sock = context.socket(zmq.PUB)
|
||||
self.frame_noimg_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
|
||||
self.frame_noimg_sock.bind(config.zmq_frame_noimg_addr)
|
||||
def setup(self) -> None:
|
||||
self.frame_sock = self.pub(self.config.zmq_frame_addr)
|
||||
self.frame_noimg_sock = self.pub(self.config.zmq_frame_noimg_addr)
|
||||
|
||||
logger.info(f"Connection socket {config.zmq_frame_addr}")
|
||||
logger.info(f"Connection socket {config.zmq_frame_noimg_addr}")
|
||||
logger.info(f"Connection socket {self.config.zmq_frame_addr}")
|
||||
logger.info(f"Connection socket {self.config.zmq_frame_noimg_addr}")
|
||||
|
||||
self.video_srcs = self.config.video_src
|
||||
|
||||
|
||||
|
||||
def emit_video(self, timer_counter):
|
||||
i = 0
|
||||
|
||||
source = get_video_source(self.video_srcs, self.config.camera, int(self.config.video_offset or 0), self.config.video_end, self.config.video_loop)
|
||||
for i, img in enumerate(source):
|
||||
def run(self):
|
||||
offset = int(self.config.video_offset or 0)
|
||||
source = get_video_source(self.video_srcs, self.config.camera, offset, self.config.video_end, self.config.video_loop)
|
||||
video_gen = enumerate(source, start = offset)
|
||||
while self.run_loop():
|
||||
|
||||
with timer_counter.get_lock():
|
||||
timer_counter.value += 1
|
||||
try:
|
||||
i, img = next(video_gen)
|
||||
except StopIteration as e:
|
||||
logger.info("Video source ended")
|
||||
break
|
||||
|
||||
frame = Frame(i, img=img, H=self.config.camera.H, camera=self.config.camera)
|
||||
|
||||
|
@ -78,73 +54,62 @@ class FrameEmitter:
|
|||
self.frame_noimg_sock.send(pickle.dumps(frame.without_img()))
|
||||
self.frame_sock.send(pickle.dumps(frame))
|
||||
|
||||
if not self.is_running.is_set():
|
||||
# if not running, also break out of infinite generator loop
|
||||
break
|
||||
|
||||
|
||||
logger.info("Stopping")
|
||||
|
||||
|
||||
@classmethod
|
||||
def arg_parser(cls) -> ArgumentParser:
|
||||
argparser = LambdaParser()
|
||||
argparser.add_argument('--zmq-frame-addr',
|
||||
help='Manually specity communication addr for the frame messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_frame")
|
||||
|
||||
argparser.add_argument('--zmq-frame-noimg-addr',
|
||||
help='Manually specity communication addr for the frame messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_frame2")
|
||||
|
||||
argparser.add_argument("--video-src",
|
||||
help="source video to track from can be either a relative or absolute path, or a url, like an RTSP resource, or use gige://RELATIVE_PATH_TO_GIGE_CONFIG_JSON",
|
||||
type=UrlOrPath,
|
||||
nargs='+',
|
||||
default=lambda: [UrlOrPath(p) for p in Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')])
|
||||
argparser.add_argument("--video-offset",
|
||||
help="Start playback from given frame. Note that when src is an array, this applies to all videos individually.",
|
||||
default=0,
|
||||
type=int)
|
||||
argparser.add_argument("--video-end",
|
||||
help="End (or loop) playback at given frame.",
|
||||
default=None,
|
||||
type=int)
|
||||
|
||||
argparser.add_argument("--video-loop",
|
||||
help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
|
||||
action='store_true')
|
||||
|
||||
argparser.add_argument("--camera-fps",
|
||||
help="Camera FPS",
|
||||
type=int,
|
||||
default=12)
|
||||
argparser.add_argument("--homography",
|
||||
help="File with homography params [Deprecated]",
|
||||
type=Path,
|
||||
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt',
|
||||
action=HomographyAction)
|
||||
argparser.add_argument("--calibration",
|
||||
help="File with camera intrinsics and lens distortion params (calibration.json)",
|
||||
# type=Path,
|
||||
required=True,
|
||||
# default=None,
|
||||
action=CameraAction)
|
||||
return argparser
|
||||
|
||||
|
||||
|
||||
|
||||
def run_frame_emitter(config: Namespace, is_running: Event, timer_counter: int):
|
||||
router = FrameEmitter(config, is_running)
|
||||
router.emit_video(timer_counter)
|
||||
router.run(timer_counter)
|
||||
is_running.clear()
|
||||
|
||||
def run():
|
||||
# Frame emitter
|
||||
import argparse
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument('--zmq-frame-addr',
|
||||
help='Manually specity communication addr for the frame messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_frame")
|
||||
|
||||
argparser.add_argument('--zmq-frame-noimg-addr',
|
||||
help='Manually specity communication addr for the frame messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_frame2")
|
||||
|
||||
argparser.add_argument("--video-src",
|
||||
help="source video to track from can be either a relative or absolute path, or a url, like an RTSP resource",
|
||||
type=UrlOrPath,
|
||||
nargs='+',
|
||||
default=lambda: [UrlOrPath(p) for p in Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')])
|
||||
argparser.add_argument("--video-offset",
|
||||
help="Start playback from given frame. Note that when src is an array, this applies to all videos individually.",
|
||||
default=0,
|
||||
type=int)
|
||||
#TODO: camera as source
|
||||
|
||||
argparser.add_argument("--video-loop",
|
||||
help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
|
||||
action='store_true')
|
||||
#TODO: camera as source
|
||||
|
||||
|
||||
# Tracker
|
||||
|
||||
argparser.add_argument("--camera-fps",
|
||||
help="Camera FPS",
|
||||
type=int,
|
||||
default=12)
|
||||
argparser.add_argument("--homography",
|
||||
help="File with homography params",
|
||||
type=Path,
|
||||
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt',
|
||||
action=HomographyAction)
|
||||
argparser.add_argument("--calibration",
|
||||
help="File with camera intrinsics and lens distortion params (calibration.json)",
|
||||
# type=Path,
|
||||
required=True,
|
||||
# default=None,
|
||||
action=CameraAction)
|
||||
config = argparser.parse_args()
|
||||
is_running = multiprocessing.Event()
|
||||
print(is_running.set())
|
||||
timer_counter = Timer('frame_emitter')
|
||||
|
||||
router = FrameEmitter(config, is_running)
|
||||
router.emit_video(timer_counter.iterations)
|
||||
is_running.clear()
|
37
trap/node.py
37
trap/node.py
|
@ -1,18 +1,20 @@
|
|||
import logging
|
||||
import multiprocessing
|
||||
from multiprocessing.synchronize import Event as BaseEvent
|
||||
from argparse import Namespace
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Optional
|
||||
|
||||
import zmq
|
||||
|
||||
from trap.counter import CounterFpsSender, CounterSender
|
||||
from trap.timer import Timer
|
||||
|
||||
|
||||
class Node():
|
||||
def __init__(self, config: Namespace, is_running: BaseEvent, timer_counter: Timer):
|
||||
def __init__(self, config: Namespace, is_running: BaseEvent, fps_counter: CounterFpsSender):
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
self.timer_counter = timer_counter
|
||||
self.fps_counter = fps_counter
|
||||
self.zmq_context = zmq.Context()
|
||||
self.logger = self._logger()
|
||||
|
||||
|
@ -23,8 +25,9 @@ class Node():
|
|||
return logging.getLogger(f"trap.{cls.__name__}")
|
||||
|
||||
def tick(self):
|
||||
with self.timer_counter.get_lock():
|
||||
self.timer_counter.value+=1
|
||||
self.fps_counter.tick()
|
||||
# with self.fps_counter.get_lock():
|
||||
# self.fps_counter.value+=1
|
||||
|
||||
def setup(self):
|
||||
raise RuntimeError("Not implemented setup()")
|
||||
|
@ -32,6 +35,17 @@ class Node():
|
|||
def run(self):
|
||||
raise RuntimeError("Not implemented run()")
|
||||
|
||||
def run_loop(self):
|
||||
"""Use in run(), to check if it should keep looping
|
||||
Takes care of tick()'ing the iterations/second counter
|
||||
"""
|
||||
self.tick()
|
||||
return self.is_running.is_set()
|
||||
|
||||
@classmethod
|
||||
def arg_parser(cls) -> ArgumentParser:
|
||||
raise RuntimeError("Not implemented arg_parser()")
|
||||
|
||||
def sub(self, addr: str):
|
||||
"Default zmq sub configuration"
|
||||
sock = self.zmq_context.socket(zmq.SUB)
|
||||
|
@ -51,4 +65,15 @@ class Node():
|
|||
def start(cls, config: Namespace, is_running: BaseEvent, timer_counter: Optional[Timer]):
|
||||
instance = cls(config, is_running, timer_counter)
|
||||
instance.run()
|
||||
instance.logger.info("Stopping")
|
||||
instance.logger.info("Stopping")
|
||||
|
||||
@classmethod
|
||||
def parse_and_start(cls):
|
||||
config = cls.arg_parser().parse_args()
|
||||
is_running = multiprocessing.Event()
|
||||
is_running.set()
|
||||
statsender = CounterSender()
|
||||
counter = CounterFpsSender(f"trap.{cls.__name__}", statsender)
|
||||
# timer_counter = Timer(cls.__name__)
|
||||
|
||||
cls.start(config, is_running, counter)
|
|
@ -1,33 +1,27 @@
|
|||
# adapted from Trajectron++ online_server.py
|
||||
from argparse import Namespace
|
||||
import logging
|
||||
from multiprocessing import Event, Queue
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import traceback
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import torch
|
||||
import dill
|
||||
import random
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import numpy as np
|
||||
from trajectron.environment.data_utils import derivative_of
|
||||
from trajectron.utils import prediction_output_to_trajectories
|
||||
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
||||
from trajectron.model.model_registrar import ModelRegistrar
|
||||
from trajectron.environment import Environment, Scene
|
||||
from trajectron.environment.node import Node
|
||||
from trajectron.environment.node_type import NodeType
|
||||
import matplotlib.pyplot as plt
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from multiprocessing import Event
|
||||
|
||||
import dill
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
from trajectron.environment import Environment, Scene
|
||||
from trajectron.model.model_registrar import ModelRegistrar
|
||||
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
||||
from trajectron.utils import prediction_output_to_trajectories
|
||||
|
||||
from trap.frame_emitter import DataclassJSONEncoder, Frame
|
||||
from trap.tracker import Track, Smoother
|
||||
from trap.node import Node
|
||||
from trap.tracker import Smoother
|
||||
|
||||
logger = logging.getLogger("trap.prediction")
|
||||
|
||||
|
@ -146,27 +140,18 @@ def offset_trajectron_dict(source, x, y):
|
|||
source[t][node][:,1] += y
|
||||
return source
|
||||
|
||||
class PredictionServer:
|
||||
def __init__(self, config: Namespace, is_running: Event):
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
|
||||
class PredictionServer(Node):
|
||||
def setup(self):
|
||||
if self.config.eval_device == 'cpu':
|
||||
logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
|
||||
|
||||
if self.config.smooth_predictions:
|
||||
self.smoother = Smoother(window_len=12, convolution=True) # convolution seems fine for predictions
|
||||
|
||||
context = zmq.Context()
|
||||
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
|
||||
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
|
||||
self.trajectory_socket.connect(config.zmq_trajectory_addr)
|
||||
|
||||
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
|
||||
self.prediction_socket.bind(config.zmq_prediction_addr)
|
||||
self.trajectory_socket = self.sub(self.config.zmq_trajectory_addr)
|
||||
self.prediction_socket = self.pub(self.config.zmq_prediction_addr)
|
||||
self.external_predictions = not self.config.zmq_prediction_addr.startswith("ipc://")
|
||||
# print(self.prediction_socket)
|
||||
|
||||
|
||||
def send_frame(self, frame: Frame):
|
||||
if self.external_predictions:
|
||||
|
@ -175,8 +160,7 @@ class PredictionServer:
|
|||
else:
|
||||
self.prediction_socket.send_pyobj(frame)
|
||||
|
||||
def run(self, timer_counter):
|
||||
print(self.config)
|
||||
def run(self):
|
||||
if self.config.seed is not None:
|
||||
random.seed(self.config.seed)
|
||||
np.random.seed(self.config.seed)
|
||||
|
@ -250,17 +234,9 @@ class PredictionServer:
|
|||
trajectron.set_environment(online_env, init_timestep)
|
||||
|
||||
timestep = init_timestep + 1
|
||||
prev_run_time = 0
|
||||
while self.is_running.is_set():
|
||||
timestep += 1
|
||||
with timer_counter.get_lock():
|
||||
timer_counter.value+=1
|
||||
|
||||
# this_run_time = time.time()
|
||||
# logger.debug(f'test {prev_run_time - this_run_time}')
|
||||
# time.sleep(max(0, prev_run_time - this_run_time + .5))
|
||||
# prev_run_time = time.time()
|
||||
|
||||
while self.run_loop():
|
||||
|
||||
|
||||
# TODO: see process_data.py on how to create a node, the provide nodes + incoming data columns
|
||||
# data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
|
||||
|
@ -283,7 +259,6 @@ class PredictionServer:
|
|||
if self.config.predict_training_data:
|
||||
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
|
||||
else:
|
||||
# print('await', self.config.zmq_trajectory_addr)
|
||||
zmq_ev = self.trajectory_socket.poll(timeout=2000)
|
||||
if not zmq_ev:
|
||||
# on no data loop so that is_running is checked
|
||||
|
@ -295,6 +270,12 @@ class PredictionServer:
|
|||
# print('recv tracker frame')
|
||||
frame: Frame = pickle.loads(data)
|
||||
|
||||
# add settings to log
|
||||
frame.log['predictor'] = {}
|
||||
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
|
||||
frame.log['predictor'][option] = self.config.__dict__[option]
|
||||
|
||||
|
||||
# print('indexrecv', [frame.tracks[t].frame_index for t in frame.tracks])
|
||||
# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
|
||||
# trajectory_data = json.loads(data)
|
||||
|
@ -489,6 +470,160 @@ class PredictionServer:
|
|||
self.send_frame(frame)
|
||||
|
||||
logger.info('Stopping')
|
||||
|
||||
@classmethod
|
||||
def arg_parser(cls) -> ArgumentParser:
|
||||
inference_parser = ArgumentParser()
|
||||
inference_parser.add_argument('--zmq-trajectory-addr',
|
||||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_traj")
|
||||
inference_parser.add_argument('--zmq-prediction-addr',
|
||||
help='Manually specity communication addr for the prediction messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_preds")
|
||||
|
||||
|
||||
inference_parser.add_argument("--step-size",
|
||||
# TODO)) Make dataset/model metadata
|
||||
help="sample step size (should be the same as for data processing and augmentation)",
|
||||
type=int,
|
||||
default=1,
|
||||
)
|
||||
inference_parser.add_argument("--model_dir",
|
||||
help="directory with the model to use for inference",
|
||||
type=str, # TODO: make into Path
|
||||
default='../Trajectron-plus-plus/experiments/trap/models/models_18_Oct_2023_19_56_22_virat_vel_ar3/')
|
||||
# default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
|
||||
|
||||
inference_parser.add_argument("--conf",
|
||||
help="path to json config file for hyperparameters, relative to model_dir",
|
||||
type=str,
|
||||
default='config.json')
|
||||
|
||||
# Model Parameters (hyperparameters)
|
||||
inference_parser.add_argument("--offline_scene_graph",
|
||||
help="whether to precompute the scene graphs offline, options are 'no' and 'yes'",
|
||||
type=str,
|
||||
default='yes')
|
||||
|
||||
inference_parser.add_argument("--dynamic_edges",
|
||||
help="whether to use dynamic edges or not, options are 'no' and 'yes'",
|
||||
type=str,
|
||||
default='yes')
|
||||
|
||||
inference_parser.add_argument("--edge_state_combine_method",
|
||||
help="the method to use for combining edges of the same type",
|
||||
type=str,
|
||||
default='sum')
|
||||
|
||||
inference_parser.add_argument("--edge_influence_combine_method",
|
||||
help="the method to use for combining edge influences",
|
||||
type=str,
|
||||
default='attention')
|
||||
|
||||
inference_parser.add_argument('--edge_addition_filter',
|
||||
nargs='+',
|
||||
help="what scaling to use for edges as they're created",
|
||||
type=float,
|
||||
default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp
|
||||
# and short edge addition, then you need to have a 0.0 at the
|
||||
# beginning, e.g. [0.0, 1.0].
|
||||
|
||||
inference_parser.add_argument('--edge_removal_filter',
|
||||
nargs='+',
|
||||
help="what scaling to use for edges as they're removed",
|
||||
type=float,
|
||||
default=[1.0, 0.0]) # We don't automatically pad right with 0.0, if you want a sharp drop off like
|
||||
# the default, then you need to have a 0.0 at the end.
|
||||
|
||||
|
||||
inference_parser.add_argument('--incl_robot_node',
|
||||
help="whether to include a robot node in the graph or simply model all agents",
|
||||
action='store_true')
|
||||
|
||||
inference_parser.add_argument('--map_encoding',
|
||||
help="Whether to use map encoding or not",
|
||||
action='store_true')
|
||||
|
||||
inference_parser.add_argument('--no_edge_encoding',
|
||||
help="Whether to use neighbors edge encoding",
|
||||
action='store_true')
|
||||
|
||||
|
||||
inference_parser.add_argument('--batch_size',
|
||||
help='training batch size',
|
||||
type=int,
|
||||
default=256)
|
||||
|
||||
inference_parser.add_argument('--k_eval',
|
||||
help='how many samples to take during evaluation',
|
||||
type=int,
|
||||
default=25)
|
||||
|
||||
# Data Parameters
|
||||
inference_parser.add_argument("--eval_data_dict",
|
||||
help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)",
|
||||
type=str,
|
||||
default='../Trajectron-plus-plus/experiments/processed/eth_test.pkl')
|
||||
|
||||
inference_parser.add_argument("--output_dir",
|
||||
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
|
||||
type=pathlib.Path,
|
||||
default='./OUT/test_inference')
|
||||
|
||||
|
||||
# inference_parser.add_argument('--device',
|
||||
# help='what device to perform training on',
|
||||
# type=str,
|
||||
# default='cuda:0')
|
||||
|
||||
inference_parser.add_argument("--eval_device",
|
||||
help="what device to use during inference",
|
||||
type=str,
|
||||
default="cpu")
|
||||
|
||||
|
||||
inference_parser.add_argument('--seed',
|
||||
help='manual seed to use, default is 123',
|
||||
type=int,
|
||||
default=123)
|
||||
|
||||
inference_parser.add_argument('--predict_training_data',
|
||||
help='Ignore tracker and predict data from the training dataset',
|
||||
action='store_true')
|
||||
|
||||
inference_parser.add_argument("--smooth-predictions",
|
||||
help="Smooth the predicted tracks",
|
||||
action='store_true')
|
||||
|
||||
inference_parser.add_argument('--prediction-horizon',
|
||||
help='Trajectron.incremental_forward parameter',
|
||||
type=int,
|
||||
default=30)
|
||||
inference_parser.add_argument('--num-samples',
|
||||
help='Trajectron.incremental_forward parameter',
|
||||
type=int,
|
||||
default=5)
|
||||
inference_parser.add_argument("--full-dist",
|
||||
help="Trajectron.incremental_forward parameter",
|
||||
action='store_true')
|
||||
inference_parser.add_argument("--gmm-mode",
|
||||
help="Trajectron.incremental_forward parameter",
|
||||
type=bool,
|
||||
default=True)
|
||||
inference_parser.add_argument("--z-mode",
|
||||
help="Trajectron.incremental_forward parameter",
|
||||
action='store_true')
|
||||
inference_parser.add_argument('--cm-to-m',
|
||||
help="Correct for homography that is in cm (i.e. {x,y}/100). Should also be used when processing data",
|
||||
action='store_true')
|
||||
inference_parser.add_argument('--center-data',
|
||||
help="Center data around cx and cy. Should also be used when processing data",
|
||||
action='store_true')
|
||||
|
||||
|
||||
return inference_parser
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -298,7 +298,7 @@ class FrameWriter:
|
|||
framerate.
|
||||
See https://video.stackexchange.com/questions/25811/ffmpeg-make-video-with-non-constant-framerate-from-image-filenames
|
||||
"""
|
||||
def __init__(self, filename: str, fps: float, frame_size: tuple) -> None:
|
||||
def __init__(self, filename: str, fps: float, frame_size: Optional[tuple] = None) -> None:
|
||||
self.filename = filename
|
||||
self.fps = fps
|
||||
self.frame_size = frame_size
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
@ -870,4 +871,21 @@ class Stage(Node):
|
|||
|
||||
# print(json.dumps(rl, cls=DataclassJSONEncoder))
|
||||
|
||||
@classmethod
|
||||
def arg_parser(cls) -> ArgumentParser:
|
||||
argparser = ArgumentParser()
|
||||
argparser.add_argument('--zmq-trajectory-addr',
|
||||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_traj")
|
||||
argparser.add_argument('--zmq-prediction-addr',
|
||||
help='Manually specity communication addr for the prediction messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_preds")
|
||||
argparser.add_argument('--zmq-stage-addr',
|
||||
help='Manually specity communication addr for the stage messages (the rendered lines)',
|
||||
type=str,
|
||||
default="tcp://0.0.0.0:99174")
|
||||
return argparser
|
||||
|
||||
|
||||
|
|
124
trap/tracker.py
124
trap/tracker.py
|
@ -1,39 +1,40 @@
|
|||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
import argparse
|
||||
import csv
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import logging
|
||||
from math import nan
|
||||
from multiprocessing import Event
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import time
|
||||
from typing import DefaultDict, Dict, Optional, List
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from multiprocessing import Event
|
||||
from pathlib import Path
|
||||
from typing import DefaultDict, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import ultralytics
|
||||
import zmq
|
||||
import cv2
|
||||
|
||||
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights, FasterRCNN_ResNet50_FPN_V2_Weights, fasterrcnn_resnet50_fpn_v2
|
||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||
from torchvision.models import ResNet50_Weights
|
||||
from bytetracker import BYTETracker
|
||||
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
||||
|
||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
|
||||
KeypointRCNN_ResNet50_FPN_Weights,
|
||||
MaskRCNN_ResNet50_FPN_V2_Weights,
|
||||
fasterrcnn_resnet50_fpn_v2,
|
||||
keypointrcnn_resnet50_fpn,
|
||||
maskrcnn_resnet50_fpn_v2)
|
||||
from tsmoothie.smoother import ConvolutionSmoother, KalmanSmoother
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.engine.results import Results as YOLOResult
|
||||
|
||||
from trap import timer
|
||||
from trap.frame_emitter import Camera, DataclassJSONEncoder, DetectionState, Frame, Detection, Track
|
||||
from bytetracker import BYTETracker
|
||||
|
||||
from tsmoothie.smoother import KalmanSmoother, ConvolutionSmoother
|
||||
import tsmoothie.smoother
|
||||
from datetime import datetime, timedelta
|
||||
from trap.frame_emitter import (Camera, DataclassJSONEncoder, Detection,
|
||||
DetectionState, Frame, Track)
|
||||
from trap.node import Node
|
||||
|
||||
# Detection = [int, int, int, int, float, int]
|
||||
# Detections = [Detection]
|
||||
|
@ -387,9 +388,8 @@ class ByteTrackWrapper(TrackerWrapper):
|
|||
|
||||
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, config: Namespace):
|
||||
self.config = config
|
||||
class Tracker(Node):
|
||||
def setup(self):
|
||||
|
||||
|
||||
# # TODO: config device
|
||||
|
@ -453,6 +453,9 @@ class Tracker:
|
|||
logger.info("Smoother Disabled (enable with --smooth-tracks)")
|
||||
|
||||
|
||||
self.frame_sock = self.sub(self.config.zmq_frame_addr)
|
||||
self.trajectory_socket = self.pub(self.config.zmq_trajectory_addr)
|
||||
|
||||
logger.debug("Set up tracker")
|
||||
|
||||
def track_frame(self, frame: Frame):
|
||||
|
@ -474,42 +477,11 @@ class Tracker:
|
|||
return detections
|
||||
|
||||
|
||||
def track(self, is_running: Event, timer_counter: int = 0):
|
||||
def run(self):
|
||||
"""
|
||||
Live tracking of frames coming in over zmq
|
||||
"""
|
||||
|
||||
self.is_running = is_running
|
||||
|
||||
|
||||
context = zmq.Context()
|
||||
self.frame_sock = context.socket(zmq.SUB)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.frame_sock.connect(self.config.zmq_frame_addr)
|
||||
|
||||
self.trajectory_socket = context.socket(zmq.PUB)
|
||||
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||
self.trajectory_socket.bind(self.config.zmq_trajectory_addr)
|
||||
|
||||
prev_run_time = 0
|
||||
|
||||
# training_fp = None
|
||||
# training_csv = None
|
||||
# training_frames = 0
|
||||
|
||||
# if self.config.save_for_training is not None:
|
||||
# if not isinstance(self.config.save_for_training, Path):
|
||||
# raise ValueError("save-for-training should be a path")
|
||||
# if not self.config.save_for_training.exists():
|
||||
# logger.info(f"Making path for training data: {self.config.save_for_training}")
|
||||
# self.config.save_for_training.mkdir(parents=True, exist_ok=False)
|
||||
# else:
|
||||
# logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
|
||||
# training_fp = open(self.config.save_for_training / 'all.txt', 'w')
|
||||
# # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
|
||||
# training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)
|
||||
|
||||
prev_frame_i = -1
|
||||
|
||||
with TrainingDataWriter(self.config.save_for_training) as writer:
|
||||
|
@ -518,9 +490,6 @@ class Tracker:
|
|||
w_time = None
|
||||
displacement_filter = FinalDisplacementFilter(.8)
|
||||
while self.is_running.is_set():
|
||||
|
||||
with timer_counter.get_lock():
|
||||
timer_counter.value += 1
|
||||
# this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
|
||||
# skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
|
||||
# so for now, timing should move to emitter
|
||||
|
@ -532,10 +501,12 @@ class Tracker:
|
|||
poll_time = time.time()
|
||||
zmq_ev = self.frame_sock.poll(timeout=2000)
|
||||
if not zmq_ev:
|
||||
logger.warning('skip poll after 2000ms')
|
||||
logger.warning('no frame for 2000ms')
|
||||
# when there's no data after timeout, loop so that is_running is checked
|
||||
continue
|
||||
|
||||
self.tick() # only tick if something is actually received
|
||||
|
||||
start_time = time.time()
|
||||
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
|
||||
|
||||
|
@ -696,11 +667,42 @@ class Tracker:
|
|||
different nesting
|
||||
"""
|
||||
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
|
||||
|
||||
@classmethod
|
||||
def arg_parser(cls):
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument('--zmq-frame-addr',
|
||||
help='Manually specity communication addr for the frame messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_frame")
|
||||
argparser.add_argument('--zmq-trajectory-addr',
|
||||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_traj")
|
||||
|
||||
argparser.add_argument("--save-for-training",
|
||||
help="Specify the path in which to save",
|
||||
type=Path,
|
||||
default=None)
|
||||
argparser.add_argument("--detector",
|
||||
help="Specify the detector to use",
|
||||
type=str,
|
||||
default=DETECTOR_YOLOv8,
|
||||
choices=DETECTORS)
|
||||
argparser.add_argument("--tracker",
|
||||
help="Specify the detector to use",
|
||||
type=str,
|
||||
default=TRACKER_BYTETRACK,
|
||||
choices=TRACKERS)
|
||||
argparser.add_argument("--smooth-tracks",
|
||||
help="Smooth the tracker tracks before sending them to the predictor",
|
||||
action='store_true')
|
||||
return argparser
|
||||
|
||||
|
||||
def run_tracker(config: Namespace, is_running: Event, timer_counter):
|
||||
router = Tracker(config)
|
||||
router.track(is_running, timer_counter)
|
||||
router.run(is_running, timer_counter)
|
||||
|
||||
def run():
|
||||
# Frame emitter
|
||||
|
@ -738,7 +740,7 @@ def run():
|
|||
timer_counter = timer.Timer('frame_emitter')
|
||||
|
||||
router = Tracker(config)
|
||||
router.track(is_running, timer_counter.iterations)
|
||||
router.run(is_running, timer_counter.iterations)
|
||||
is_running.clear()
|
||||
|
||||
|
||||
|
|
14
uv.lock
14
uv.lock
|
@ -2191,6 +2191,18 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/a0/4b/528ccf7a982216885a1ff4908e886b8fb5f19862d1962f56a3fce2435a70/starlette-0.46.1-py3-none-any.whl", hash = "sha256:77c74ed9d2720138b25875133f3a2dae6d854af2ec37dceb56aef370c1d8a227", size = 71995 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "supervisor"
|
||||
version = "4.2.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "setuptools" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ce/37/517989b05849dd6eaa76c148f24517544704895830a50289cbbf53c7efb9/supervisor-4.2.5.tar.gz", hash = "sha256:34761bae1a23c58192281a5115fb07fbf22c9b0133c08166beffc70fed3ebc12", size = 466073 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/7a/0ad3973941590c040475046fef37a2b08a76691e61aa59540828ee235a6e/supervisor-4.2.5-py2.py3-none-any.whl", hash = "sha256:2ecaede32fc25af814696374b79e42644ecaba5c09494c51016ffda9602d0f08", size = 319561 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tensorboard"
|
||||
version = "2.19.0"
|
||||
|
@ -2495,6 +2507,7 @@ dependencies = [
|
|||
{ name = "setproctitle" },
|
||||
{ name = "shapely" },
|
||||
{ name = "simplification" },
|
||||
{ name = "supervisor" },
|
||||
{ name = "tensorboardx" },
|
||||
{ name = "torch", version = "1.12.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" },
|
||||
{ name = "torch", version = "1.12.1+cu113", source = { url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl" }, marker = "sys_platform == 'linux'" },
|
||||
|
@ -2527,6 +2540,7 @@ requires-dist = [
|
|||
{ name = "setproctitle", specifier = ">=1.3.3,<2" },
|
||||
{ name = "shapely", specifier = ">=2.1" },
|
||||
{ name = "simplification", specifier = ">=0.7.12" },
|
||||
{ name = "supervisor", specifier = ">=4.2.5" },
|
||||
{ name = "tensorboardx", specifier = ">=2.6.2.2,<3" },
|
||||
{ name = "torch", marker = "python_full_version < '3.10' or python_full_version >= '4' or sys_platform != 'linux'", specifier = "==1.12.1" },
|
||||
{ name = "torch", marker = "python_full_version >= '3.10' and python_full_version < '4' and sys_platform == 'linux'", url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl" },
|
||||
|
|
Loading…
Reference in a new issue