Run as separate processes for simple auto-restart of individual nodes

Ruben van de Ven 2025-05-19 11:58:37 +02:00
parent 9cc08f09af
commit 907d5a2157
14 changed files with 612 additions and 536 deletions


@ -21,6 +21,7 @@ These are roughly the steps to go from datagathering to training
* `uv run process_data --src-dir EXPERIMENTS/raw/NAME --dst-dir EXPERIMENTS/trajectron-data/ --name NAME --smooth-tracks --camera-fps 12 --homography ../DATASETS/NAME/homography.json --calibration ../DATASETS/NAME/calibration.json --filter-displacement 2 --map-img-path ../DATASETS/NAME/map.png`
5. Train Trajectron model `uv run trajectron_train --eval_every 10 --vis_every 1 --train_data_dict NAME_train.pkl --eval_data_dict NAME_val.pkl --offline_scene_graph no --preprocess_workers 8 --log_dir EXPERIMENTS/models --log_tag _NAME --train_epochs 100 --conf EXPERIMENTS/config.json --batch_size 256 --data_dir EXPERIMENTS/trajectron-data `
6. Then run!
* On a video file (you can use a wildcard) `DISPLAY=:1 uv run trapserv --remote-log-addr 100.69.123.91 --eval_device cuda:0 --detector ultralytics --homography ../DATASETS/NAME/homography.json --eval_data_dict EXPERIMENTS/trajectron-data/hof2s-m_test.pkl --video-src ../DATASETS/NAME/*.mp4 --model_dir EXPERIMENTS/models/models_DATE_NAME/ --smooth-predictions --smooth-tracks --num-samples 3 --render-window --calibration ../DATASETS/NAME/calibration.json` (the DISPLAY environment variable is used here to allow running over an SSH connection while displaying on the local monitor)
* `uv run supervisord`
<!-- * On a video file (you can use a wildcard) `DISPLAY=:1 uv run trapserv --remote-log-addr 100.69.123.91 --eval_device cuda:0 --detector ultralytics --homography ../DATASETS/NAME/homography.json --eval_data_dict EXPERIMENTS/trajectron-data/hof2s-m_test.pkl --video-src ../DATASETS/NAME/*.mp4 --model_dir EXPERIMENTS/models/models_DATE_NAME/ --smooth-predictions --smooth-tracks --num-samples 3 --render-window --calibration ../DATASETS/NAME/calibration.json` (the DISPLAY environment variable is used here to allow running over an SSH connection while displaying on the local monitor)
* or on the RTSP stream, which uses GStreamer to substantially reduce latency compared to the default ffmpeg bindings in OpenCV.
* To just have a single trajectory pulled from distribution use `--full-dist`. Also try `--z_mode`.
* To just have a single trajectory pulled from distribution use `--full-dist`. Also try `--z_mode`. -->


@ -33,6 +33,7 @@ dependencies = [
"python-statemachine>=2.5.0",
"facenet-pytorch>=2.5.3",
"simplification>=0.7.12",
"supervisor>=4.2.5",
]
[project.scripts]
@ -43,8 +44,13 @@ compare = "trap.tools:tracker_compare"
process_data = "trap.process_data:main"
blacklist = "trap.tools:blacklist_tracks"
rewrite_tracks = "trap.tools:rewrite_raw_track_files"
live_video_source = "trap.frame_emitter:run"
live_tracker = "trap.tracker:run"
trap_video_source = "trap.frame_emitter:FrameEmitter.parse_and_start"
trap_tracker = "trap.tracker:Tracker.parse_and_start"
trap_stage = "trap.stage:Stage.parse_and_start"
trap_prediction = "trap.prediction_server:PredictionServer.parse_and_start"
trap_render_cv = "trap.cv_renderer:CvRenderer.parse_and_start"
trap_monitor = "trap.monitor:Monitor.parse_and_start" # migrate timer
[tool.uv]
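Each of these console-script entries is a standard `module:attribute` path, so `uv run trap_tracker` resolves the class and calls its `parse_and_start()` classmethod. Roughly equivalent, as a sketch (not part of the commit):

```python
# roughly what the trap_tracker entry point does
from trap.tracker import Tracker

Tracker.parse_and_start()  # builds the node's arg parser, then sets it up and runs it
```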

supervisord.conf (new file, 48 additions)

@ -0,0 +1,48 @@
[inet_http_server]
port = *:8293
# username = user
# password = 123
[supervisord]
nodaemon = True
; The rpcinterface:supervisor section must remain in the config file for
; RPC (supervisorctl/web interface) to work. Additional interfaces may be
; added by defining them in separate [rpcinterface:x] sections.
[rpcinterface:supervisor]
supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
[supervisorctl]
serverurl = http://localhost:8293
[program:monitor]
command=uv run trap_monitor
numprocs=1
directory=%(here)s
[program:video]
command=uv run trap_video_source --homography ../DATASETS/hof3/homography.json --video-src ../DATASETS/hof3/hof3-cam-demo-twoperson.mp4 --calibration ../DATASETS/hof3/calibration.json --video-loop
# command=uv run trap_video_source --homography ../DATASETS/hof3-cam-baumer/homography.json --video-src gige://../DATASETS/hof3-cam-baumer/gige_config.json --calibration ../DATASETS/hof3-cam-baumer/calibration.json
directory=%(here)s
[program:tracker]
command=uv run trap_tracker
directory=%(here)s
[program:stage]
command=uv run trap_stage
directory=%(here)s
[program:predictor]
command=uv run trap_prediction --eval_device cuda:0 --model_dir EXPERIMENTS/models/models_20241229_21_35_13_hof3-m2-ud-split-conv12-f2.0-map-2024-12-29/ --num-samples 1 --map_encoding --eval_data_dict EXPERIMENTS/trajectron-data/hof3-m2-ud-split-nostep-conv12-f2.0-map-2024-12-29_val.pkl --prediction-horizon 120 --gmm-mode True --z-mode
directory=%(here)s
[program:render_cv]
command=uv run trap_render_cv
directory=%(here)s
environment=DISPLAY=":0"
autostart=false
; quitting can take long when rendering to a video file
stopwaitsecs=60
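With this config each node runs as its own supervised process, so a single node can be restarted without touching the others, e.g. `uv run supervisorctl restart tracker`, and supervisord's default `autorestart=unexpected` brings a crashed node back up on its own. A minimal sketch of doing the same over the XML-RPC interface exposed by `[inet_http_server]` above (assumes port 8293 and no authentication, as configured here):

```python
# restart only the tracker node via supervisord's XML-RPC API
from xmlrpc.client import ServerProxy

server = ServerProxy("http://localhost:8293/RPC2")
print(server.supervisor.getAllProcessInfo())  # state of all configured programs
# supervisord has no single-process restart call, so restart = stop + start
server.supervisor.stopProcess("tracker")
server.supervisor.startProcess("tracker")
```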


@ -2,12 +2,14 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import argparse
from collections import defaultdict
from enum import IntFlag
from itertools import cycle
import json
import logging
from pathlib import Path
import time
import types
from typing import Iterable, Optional, Tuple, Union, List
import cv2
from dataclasses import dataclass, field
@ -617,6 +619,7 @@ class Frame:
H: Optional[np.array] = None
camera: Optional[Camera] = None
maps: Optional[List[cv2.Mat]] = None
log: dict = field(default_factory=lambda: {}) # settings used during processing. All intermediate nodes can store their config here
def aslist(self) -> List[dict]:
return { t.track_id:
@ -724,3 +727,17 @@ class CameraAction(argparse.Action):
# camera = Camera(np.array(data['camera_matrix']), np.array(data['dist_coeff']), data['dim']['width'], data['dim']['height'], namespace.H, namespace.camera_fps)
setattr(namespace, 'camera', camera)
class LambdaParser(argparse.ArgumentParser):
"""Execute lambda functions
"""
def parse_args(self, args=None, namespace=None):
args = super().parse_args(args, namespace)
for key in vars(args):
f = args.__dict__[key]
if type(f) == types.LambdaType:
print(f'Getting default value for {key}')
args.__dict__[key] = f()
return args
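A short usage sketch of this parser, mirroring the lazy-default pattern used for `--video-src` further down (the glob path here is illustrative):

```python
from pathlib import Path
from trap.base import LambdaParser

parser = LambdaParser()
# wrapping the default in a lambda defers the glob until the default is actually needed
parser.add_argument("--video-src", nargs="+",
                    default=lambda: list(Path("../DATASETS").glob("*.mp4")))
args = parser.parse_args([])  # parse_args() detects the lambda and replaces it with its result
```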


@ -6,23 +6,10 @@ import json
from trap.tracker import DETECTORS, TRACKER_BYTETRACK, TRACKERS
from trap.frame_emitter import Camera
from trap.base import CameraAction, HomographyAction
from trap.base import CameraAction, HomographyAction, LambdaParser
from pyparsing import Optional
from trap.frame_emitter import UrlOrPath
class LambdaParser(argparse.ArgumentParser):
"""Execute lambda functions
"""
def parse_args(self, args=None, namespace=None):
args = super().parse_args(args, namespace)
for key in vars(args):
f = args.__dict__[key]
if type(f) == types.LambdaType:
print(f'Getting default value for {key}')
args.__dict__[key] = f()
return args
parser = LambdaParser()
# parser.parse_args()
@ -261,7 +248,6 @@ frame_emitter_parser.add_argument("--video-loop",
#TODO: camera as source
# Tracker
tracker_parser.add_argument("--camera-fps",
help="Camera FPS",
@ -277,6 +263,8 @@ tracker_parser.add_argument("--calibration",
# type=Path,
default=None,
action=CameraAction)
# Tracker
tracker_parser.add_argument("--save-for-training",
help="Specify the path in which to save",
type=Path,


@ -1,6 +1,8 @@
import collections
from gc import is_finalized
import logging
import statistics
import threading
import time
from typing import MutableSequence
import zmq
@ -8,20 +10,59 @@ import zmq
logger = logging.getLogger('counter')
class CounterSender:
def __init__(self, address = "ipc:///tmp/trap-counters"):
def __init__(self, address = "ipc:///tmp/trap-counters2"):
# self.name = name
self.context = zmq.Context()
self.sock = self.context.socket(zmq.PUB)
# self.sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
# self.sock.sndhwm = 1
self.sock.bind(address)
self.sock.connect(address)
def set(self, name:str, value:float):
try:
self.sock.send_multipart([name.encode('utf8'), str(value).encode("utf8")], flags=zmq.NOBLOCK)
# we cannot use send_multipart in combination with conflate
self.sock.send_pyobj([name, value], flags=zmq.NOBLOCK)
except zmq.ZMQError as e:
logger.warning(f"No space in que to count {name} as {value}")
class CounterFpsSender():
def __init__(self, name:str , sender: CounterSender):
self.name = name
self.sender = sender
self.tocs: MutableSequence[tuple[float, int]] = collections.deque(maxlen=5)
self.iterations: int = 0
# threading.Event.wait()
# TODO: run the interval in a daemonic loop so it stops automatically
self.thread = threading.Thread(target=self.interval, daemon=True)
self.is_finished = threading.Event()
def tick(self):
self.iterations += 1
self.snapshot()
def snapshot(self):
self.tocs.append((time.perf_counter(), self.iterations))
self.sender.set(self.name, self.fps)
@property
def fps(self):
if len(self.tocs) < 2:
return 0
dt = self.tocs[-1][0] - self.tocs[0][0]
di = self.tocs[-1][1] - self.tocs[0][1]
return di/dt
def interval(self):
while True:
self.is_finished.wait(.5)
if self.is_finished.is_set():
break
self.snapshot()
# timer = threading.Timer(.5, self.interval)
# timer.start()
class CounterLog():
def __init__(self, history = 20):
self.history: MutableSequence[tuple[float, float]] = collections.deque(maxlen=history)
@ -30,29 +71,37 @@ class CounterLog():
self.history.append((time.perf_counter(), value))
def value(self):
if not len(self.history):
return None
return self.history[-1][1]
def has_value(self):
if not len(self.history):
return False
if (time.perf_counter() - self.history[-1][0]) > 4:
# no update in 4s: very slow. Dead thread?
return False
return True
def avg(self):
if not len(self.history):
return 0.
return statistics.fmean([h[1] for h in self.history])
class CounterListerner():
def __init__(self, address = "ipc:///tmp/trap-counters"):
def __init__(self, address = "ipc:///tmp/trap-counters2"):
self.context = zmq.Context()
self.sock = self.context.socket(zmq.SUB)
self.sock.connect(address)
self.sock.bind(address)
self.sock.subscribe( b'')
self.values: collections.defaultdict[str, CounterLog] = collections.defaultdict(lambda: CounterLog())
def snapshot(self):
messages = []
while self.sock.poll(0) == zmq.POLLIN:
name, value = self.sock.recv_multipart()
name, value = name.decode('utf8'),float(value.decode('utf8'))
self.values[name].add(value)
msg = self.sock.recv_pyobj()
# print(msg)
name, value = msg
# name, value = name.decode('utf8'),float(value.decode('utf8'))
self.values[name].add(float(value))
def get_latest(self):
@ -60,7 +109,7 @@ class CounterListerner():
return self.values
def to_string(self):
strs = [f"{k}: {v.value()} ({v.avg()})" for (k,v) in self.values.items()]
strs = [(f"{k}: {v.value():.2f} ({v.avg():.2f})" if v.has_value() else f"{k}: --") for (k,v) in self.values.items()]
return " ".join(strs)


@ -1,74 +1,44 @@
# used for "Forward Referencing of type annotations"
from __future__ import annotations
import time
import ffmpeg
from argparse import Namespace
import datetime
import logging
from multiprocessing import Event
import time
from argparse import ArgumentParser, Namespace
from multiprocessing.synchronize import Event as BaseEvent
from typing import Dict
import cv2
import ffmpeg
import numpy as np
import json
import pyglet
import pyglet.event
import zmq
import tempfile
from pathlib import Path
import shutil
import math
from typing import Dict, Iterable, Optional
from pyglet import shapes
from PIL import Image
from trap.counter import CounterListerner
from trap.frame_emitter import DetectionState, Frame, Track, Camera
from trap.frame_emitter import Frame, Track
from trap.node import Node
from trap.preview_renderer import FrameWriter
from trap.tools import draw_track, draw_track_predictions, draw_track_projected, draw_trackjectron_history, to_point
from trap.utils import convert_world_points_to_img_points, convert_world_space_to_img_space
from trap.tools import draw_track_predictions, draw_track_projected, to_point
from trap.utils import convert_world_points_to_img_points
logger = logging.getLogger("trap.simple_renderer")
class CvRenderer:
def __init__(self, config: Namespace, is_running: BaseEvent):
self.config = config
self.is_running = is_running
class CvRenderer(Node):
def setup(self):
self.prediction_sock = self.sub(self.config.zmq_prediction_addr)
self.tracker_sock = self.sub(self.config.zmq_trajectory_addr)
self.frame_sock = self.sub(self.config.zmq_frame_addr)
self.counter_listener = CounterListerner()
context = zmq.Context()
self.prediction_sock = context.socket(zmq.SUB)
self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
# self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
self.prediction_sock.connect(config.zmq_prediction_addr)
self.tracker_sock = context.socket(zmq.SUB)
self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.tracker_sock.connect(config.zmq_trajectory_addr)
self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.connect(config.zmq_frame_addr)
self.H = self.config.H
self.inv_H = np.linalg.pinv(self.H)
# self.H = self.config.H
# self.inv_H = np.linalg.pinv(self.H)
# TODO: get FPS from frame_emitter
# self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
self.fps = 60
self.frame_size = (self.config.camera.projected_w,self.config.camera.projected_h)
self.hide_stats = False
self.frame_size = None # configure on first frame recv
# self.frame_size = (self.config.camera.projected_w,self.config.camera.projected_h)
self.out_writer = self.start_writer() if self.config.render_file else None
self.streaming_process = self.start_streaming() if self.config.render_url else None
@ -80,81 +50,6 @@ class CvRenderer:
self.tracks: Dict[str, Track] = {}
self.predictions: Dict[str, Track] = {}
# self.init_shapes()
# self.init_labels()
def init_shapes(self):
'''
Due to error when running headless, we need to configure options before extending the shapes class
'''
class GradientLine(shapes.Line):
def __init__(self, x, y, x2, y2, width=1, color1=[255,255,255], color2=[255,255,255], batch=None, group=None):
# print('colors!', colors)
# assert len(colors) == 6
r, g, b, *a = color1
self._rgba1 = (r, g, b, a[0] if a else 255)
r, g, b, *a = color2
self._rgba2 = (r, g, b, a[0] if a else 255)
# print('rgba', self._rgba)
super().__init__(x, y, x2, y2, width, color1, batch=None, group=None)
# <pyglet.graphics.vertexdomain.VertexList
# pyglet.graphics.vertexdomain
# print(self._vertex_list)
def _create_vertex_list(self):
'''
copy of super()._create_vertex_list but with additional colors'''
self._vertex_list = self._group.program.vertex_list(
6, self._draw_mode, self._batch, self._group,
position=('f', self._get_vertices()),
colors=('Bn', self._rgba1+ self._rgba2 + self._rgba2 + self._rgba1 + self._rgba2 +self._rgba1 ),
translation=('f', (self._x, self._y) * self._num_verts))
def _update_colors(self):
self._vertex_list.colors[:] = self._rgba1+ self._rgba2 + self._rgba2 + self._rgba1 + self._rgba2 +self._rgba1
def color1(self, color):
r, g, b, *a = color
self._rgba1 = (r, g, b, a[0] if a else 255)
self._update_colors()
def color2(self, color):
r, g, b, *a = color
self._rgba2 = (r, g, b, a[0] if a else 255)
self._update_colors()
self.gradientLine = GradientLine
def init_labels(self):
base_color = (255,)*4
color_predictor = (255,255,0, 255)
color_info = (255,0, 255, 255)
color_tracker = (0,255, 255, 255)
options = []
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
options.append(f"{option}: {self.config.__dict__[option]}")
self.labels = {
'waiting': pyglet.text.Label("Waiting for prediction"),
'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay),
'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
}
def refresh_labels(self, dt: float):
"""Every frame"""
@ -173,130 +68,6 @@ class CvRenderer:
self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
# self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
# cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
# cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
# if prediction_frame:
# # render Δt and Δ frames
# cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
# cv2.putText(img, f"{prediction_frame.time - time.time():.2f}s", (200,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
# cv2.putText(img, f"{len(prediction_frame.tracks)} tracks", (500,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
# cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()]):.2f}", (580,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
# cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()]):.2f}", (660,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
# cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()]):.2f}", (740,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
# options = []
# for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
# options.append(f"{option}: {config.__dict__[option]}")
# cv2.putText(img, options.pop(-1), (20,img.shape[0]-30), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
# cv2.putText(img, " | ".join(options), (20,img.shape[0]-10), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
def check_frames(self, dt):
new_tracks = False
try:
self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
if not self.first_time:
self.first_time = self.frame.time
img = cv2.GaussianBlur(self.frame.img, (15, 15), 0)
img = cv2.flip(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), 0)
img = pyglet.image.ImageData(self.frame_size[0], self.frame_size[1], 'RGB', img.tobytes())
# don't draw in batch, so that it is the background
self.video_sprite = pyglet.sprite.Sprite(img=img, batch=self.batch_bg)
self.video_sprite.opacity = 100
except zmq.ZMQError as e:
# idx = frame.index if frame else "NONE"
# logger.debug(f"reuse video frame {idx}")
pass
try:
self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
new_tracks = True
except zmq.ZMQError as e:
pass
try:
self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
new_tracks = True
except zmq.ZMQError as e:
pass
def on_key_press(self, symbol, modifiers):
print('A key was pressed, use f to hide')
if symbol == ord('f'):
self.window.set_fullscreen(not self.window.fullscreen)
if symbol == ord('h'):
self.hide_stats = not self.hide_stats
def check_running(self, dt):
if not self.is_running.is_set():
self.window.close()
self.event_loop.exit()
def on_close(self):
self.is_running.clear()
def on_refresh(self, dt: float):
# update shapes
# self.bg =
for track_id, track in self.drawn_tracks.items():
track.update_drawn_positions(dt)
self.refresh_labels(dt)
# self.shape1 = shapes.Circle(700, 150, 100, color=(50, 0, 30), batch=self.batch_anim)
# self.shape3 = shapes.Circle(800, 150, 100, color=(100, 225, 30), batch=self.batch_anim)
pass
def on_draw(self):
self.window.clear()
self.batch_bg.draw()
for track in self.drawn_tracks.values():
for shape in track.shapes:
shape.draw() # for some reason the batches don't work
for track in self.drawn_tracks.values():
for shapes in track.pred_shapes:
for shape in shapes:
shape.draw()
# self.batch_anim.draw()
self.batch_overlay.draw()
# pyglet.graphics.draw(3, pyglet.gl.GL_LINE, ("v2i", (100,200, 600,800)), ('c3B', (255,255,255, 255,255,255)))
if not self.hide_stats:
self.fps_display.draw()
# if streaming, capture buffer and send
try:
if self.streaming_process or self.out_writer:
buf = pyglet.image.get_buffer_manager().get_color_buffer()
img_data = buf.get_image_data()
data = img_data.get_data() # alternative: .get_data("RGBA", image_data.pitch)
img = np.asanyarray(data).reshape((img_data.height, img_data.width, 4))
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
img = np.flip(img, 0)
# img = cv2.flip(img, cv2.0)
# cv2.imshow('frame', img)
# cv2.waitKey(1)
if self.streaming_process:
self.streaming_process.stdin.write(img.tobytes())
if self.out_writer:
self.out_writer.write(img)
except Exception as e:
logger.exception(e)
def start_writer(self):
if not self.config.output_dir.exists():
raise FileNotFoundError("Path does not exist")
@ -305,16 +76,16 @@ class CvRenderer:
filename = self.config.output_dir / f"render_predictions-{date_str}-{self.config.detector}.mp4"
logger.info(f"Write to {filename}")
return FrameWriter(str(filename), self.fps, self.frame_size)
return FrameWriter(str(filename), self.fps, None)
fourcc = cv2.VideoWriter_fourcc(*'vp09')
# fourcc = cv2.VideoWriter_fourcc(*'vp09')
return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
# return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
def start_streaming(self):
def start_streaming(self, frame_size=(1920,1080)):
return (
ffmpeg
.input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*self.frame_size))
.input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*frame_size))
.output(
self.config.render_url,
#codec = "copy", # use same codecs of the original video
@ -334,10 +105,7 @@ class CvRenderer:
)
# return process
def run(self, timer_counter):
def run(self):
frame = None
prediction_frame = None
tracker_frame = None
@ -348,14 +116,12 @@ class CvRenderer:
cv2.namedWindow("frame", cv2.WINDOW_NORMAL)
# https://gist.github.com/ronekko/dc3747211543165108b11073f929b85e
cv2.moveWindow("frame", 1920, -1)
cv2.setWindowProperty("frame",cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
bgsub = cv2.createBackgroundSubtractorMOG2(120, 50, detectShadows=True)
while self.is_running.is_set():
i+=1
with timer_counter.get_lock():
timer_counter.value+=1
if self.config.full_screen:
cv2.setWindowProperty("frame",cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
# bgsub = cv2.createBackgroundSubtractorMOG2(120, 50, detectShadows=True)
while self.run_loop():
i += 1
# zmq_ev = self.frame_sock.poll(timeout=2000)
# if not zmq_ev:
@ -397,7 +163,7 @@ class CvRenderer:
first_time = frame.time
# img = frame.img
img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.config.render_clusters, self.counter_listener,bgsub)
img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.config.render_clusters)
logger.debug(f"write frame {frame.time - first_time:.3f}s")
if self.out_writer:
@ -431,6 +197,50 @@ class CvRenderer:
self.streaming_process.wait()
logger.info('stopped')
@classmethod
def arg_parser(cls):
render_parser = ArgumentParser()
render_parser.add_argument('--zmq-frame-addr',
help='Manually specify communication addr for the frame messages',
type=str,
default="ipc:///tmp/feeds_frame")
render_parser.add_argument('--zmq-trajectory-addr',
help='Manually specify communication addr for the trajectory messages',
type=str,
default="ipc:///tmp/feeds_traj")
render_parser.add_argument('--zmq-prediction-addr',
help='Manually specify communication addr for the prediction messages',
type=str,
default="ipc:///tmp/feeds_preds")
render_parser.add_argument("--render-file",
help="Render a video file previewing the prediction, and its delay compared to the current frame",
action='store_true')
render_parser.add_argument("--render-window",
help="Render a previewing to a window",
action='store_true')
render_parser.add_argument("--full-screen",
help="Set Window full screen",
action='store_true')
render_parser.add_argument("--render-clusters",
help="renders arrowd clusters instead of individual predictions",
action='store_true')
render_parser.add_argument("--render-url",
help="""Stream renderer on given URL. Two easy approaches:
- using zmq wrapper one can specify the LISTENING ip. To listen to any incoming connection: zmq:tcp://0.0.0.0:5556
- alternatively, using e.g. UDP one needs to specify the IP of the client. E.g. udp://100.69.123.91:5556/stream
Note that with ZMQ you can have multiple clients connecting simultaneously. E.g. using `ffplay zmq:tcp://100.109.175.82:5556`
When using udp, connecting can be done using `ffplay udp://100.109.175.82:5556/stream`
""",
type=str,
default=None)
return render_parser
# colorset = itertools.product([0,255], repeat=3) # but remove white
# colorset = [(0, 0, 0),
# (0, 0, 255),
@ -460,7 +270,7 @@ def get_animation_position(track: Track, current_frame: Frame):
def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], as_clusters = True, counter_listener: CounterListerner|None = None, bg_subtractor = None) -> np.array:
def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], as_clusters = True) -> np.array:
scale = 100
# TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
# or https://github.com/pygobject/pycairo?tab=readme-ov-file
@ -503,7 +313,7 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
else:
for track_id, track in tracks.items():
inv_H = np.linalg.pinv(tracker_frame.H)
draw_track_projected(img, track, int(track_id), config.camera, conversion)
draw_track_projected(img, track, int(track_id), frame.camera, conversion)
if not prediction_frame:
cv2.putText(img, f"Waiting for prediction...", (500,17), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
@ -514,7 +324,7 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
# For debugging:
# draw_trackjectron_history(img, track, int(track.track_id), conversion)
anim_position = get_animation_position(track, frame)
draw_track_predictions(img, track, int(track.track_id)+1, config.camera, conversion, anim_position=anim_position, as_clusters=as_clusters)
draw_track_predictions(img, track, int(track.track_id)+1, frame.camera, conversion, anim_position=anim_position, as_clusters=as_clusters)
cv2.putText(img, f"{len(track.predictor_history) if track.predictor_history else 'none'}", to_point(track.history[0].get_foot_coords()), cv2.FONT_HERSHEY_COMPLEX, 1, (255,255,255), 1)
if prediction_frame.maps:
for i, m in enumerate(prediction_frame.maps):
@ -544,6 +354,8 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
cv2.putText(img, f"{frame.time - first_time: >10.2f}s", (150,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
cv2.putText(img, f"{frame.time - time.time():.2f}s", (250,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
options = []
if prediction_frame:
# render Δt and Δ frames
cv2.putText(img, f"{tracker_frame.index - frame.index}", (90,17), cv2.FONT_HERSHEY_PLAIN, 1, tracker_color, 1)
@ -555,18 +367,14 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()]):.2f}", (780,17), cv2.FONT_HERSHEY_PLAIN, 1, predictor_color, 1)
cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()]):.2f}", (860,17), cv2.FONT_HERSHEY_PLAIN, 1, predictor_color, 1)
options = []
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
options.append(f"{option}: {config.__dict__[option]}")
for option, value in prediction_frame.log['predictor'].items():
options.append(f"{option}: {value}")
cv2.putText(img, options.pop(-1), (20,img.shape[0]-30), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
cv2.putText(img, " | ".join(options), (20,img.shape[0]-10), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
# for i, (k, v) in enumerate(counter_listener.get_latest().items()):
# cv2.putText(img, f"{k} {v.value()}", (20,img.shape[0]-(40*i)-40), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
return img
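The new `Frame.log` field is what ties this renderer change to the predictor change below: the prediction server records the settings it ran with on the frame, and the renderer only reads them back instead of consulting its own config. Roughly (option names abbreviated):

```python
# in PredictionServer.run(): record the settings used for this frame
frame.log['predictor'] = {opt: config.__dict__[opt] for opt in ('num_samples', 'z_mode')}

# in decorate_frame(): render whatever the predictor reported
overlay = " | ".join(f"{k}: {v}" for k, v in prediction_frame.log['predictor'].items())
```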


@ -1,75 +1,51 @@
from __future__ import annotations
from argparse import Namespace
from dataclasses import dataclass, field
import dataclasses
from enum import IntFlag
from itertools import cycle
import json
import logging
from multiprocessing import Event
import multiprocessing
from pathlib import Path
import pickle
import sys
import time
from typing import Iterable, List, Optional
import numpy as np
import cv2
import pandas as pd
import zmq
import os
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
from bytetracker.byte_tracker import STrack as ByteTrackTrack
from bytetracker.basetrack import TrackState as ByteTrackTrackState
from trajectron.environment import Environment, Node, Scene
from urllib.parse import urlparse
from argparse import ArgumentParser, Namespace
from multiprocessing import Event
from pathlib import Path
import zmq
from trap import node
from trap.base import *
from trap.base import LambdaParser
from trap.timer import Timer
from trap.utils import get_bins
from trap.utils import inv_lerp, lerp
from trap.video_sources import get_video_source
logger = logging.getLogger('trap.frame_emitter')
class FrameEmitter:
class FrameEmitter(node.Node):
'''
Emit frames in a separate thread so they can be throttled,
or thrown away when the rest of the system cannot keep up
'''
def __init__(self, config: Namespace, is_running: Event) -> None:
self.config = config
self.is_running = is_running
def setup(self) -> None:
self.frame_sock = self.pub(self.config.zmq_frame_addr)
self.frame_noimg_sock = self.pub(self.config.zmq_frame_noimg_addr)
context = zmq.Context()
# TODO: to make things faster, a multiprocessing.Array might be a tad faster: https://stackoverflow.com/a/65201859
self.frame_sock = context.socket(zmq.PUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
self.frame_sock.bind(config.zmq_frame_addr)
self.frame_noimg_sock = context.socket(zmq.PUB)
self.frame_noimg_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
self.frame_noimg_sock.bind(config.zmq_frame_noimg_addr)
logger.info(f"Connection socket {config.zmq_frame_addr}")
logger.info(f"Connection socket {config.zmq_frame_noimg_addr}")
logger.info(f"Connection socket {self.config.zmq_frame_addr}")
logger.info(f"Connection socket {self.config.zmq_frame_noimg_addr}")
self.video_srcs = self.config.video_src
def emit_video(self, timer_counter):
i = 0
source = get_video_source(self.video_srcs, self.config.camera, int(self.config.video_offset or 0), self.config.video_end, self.config.video_loop)
for i, img in enumerate(source):
with timer_counter.get_lock():
timer_counter.value += 1
def run(self):
offset = int(self.config.video_offset or 0)
source = get_video_source(self.video_srcs, self.config.camera, offset, self.config.video_end, self.config.video_loop)
video_gen = enumerate(source, start = offset)
while self.run_loop():
try:
i, img = next(video_gen)
except StopIteration as e:
logger.info("Video source ended")
break
frame = Frame(i, img=img, H=self.config.camera.H, camera=self.config.camera)
@ -78,73 +54,62 @@ class FrameEmitter:
self.frame_noimg_sock.send(pickle.dumps(frame.without_img()))
self.frame_sock.send(pickle.dumps(frame))
if not self.is_running.is_set():
# if not running, also break out of infinite generator loop
break
logger.info("Stopping")
@classmethod
def arg_parser(cls) -> ArgumentParser:
argparser = LambdaParser()
argparser.add_argument('--zmq-frame-addr',
help='Manually specify communication addr for the frame messages',
type=str,
default="ipc:///tmp/feeds_frame")
argparser.add_argument('--zmq-frame-noimg-addr',
help='Manually specify communication addr for the frame messages',
type=str,
default="ipc:///tmp/feeds_frame2")
argparser.add_argument("--video-src",
help="source video to track from can be either a relative or absolute path, or a url, like an RTSP resource, or use gige://RELATIVE_PATH_TO_GIGE_CONFIG_JSON",
type=UrlOrPath,
nargs='+',
default=lambda: [UrlOrPath(p) for p in Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')])
argparser.add_argument("--video-offset",
help="Start playback from given frame. Note that when src is an array, this applies to all videos individually.",
default=0,
type=int)
argparser.add_argument("--video-end",
help="End (or loop) playback at given frame.",
default=None,
type=int)
argparser.add_argument("--video-loop",
help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
action='store_true')
argparser.add_argument("--camera-fps",
help="Camera FPS",
type=int,
default=12)
argparser.add_argument("--homography",
help="File with homography params [Deprecated]",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt',
action=HomographyAction)
argparser.add_argument("--calibration",
help="File with camera intrinsics and lens distortion params (calibration.json)",
# type=Path,
required=True,
# default=None,
action=CameraAction)
return argparser
def run_frame_emitter(config: Namespace, is_running: Event, timer_counter: int):
router = FrameEmitter(config, is_running)
router.emit_video(timer_counter)
is_running.clear()
def run():
# Frame emitter
import argparse
argparser = argparse.ArgumentParser()
argparser.add_argument('--zmq-frame-addr',
help='Manually specify communication addr for the frame messages',
type=str,
default="ipc:///tmp/feeds_frame")
argparser.add_argument('--zmq-frame-noimg-addr',
help='Manually specify communication addr for the frame messages',
type=str,
default="ipc:///tmp/feeds_frame2")
argparser.add_argument("--video-src",
help="source video to track from can be either a relative or absolute path, or a url, like an RTSP resource",
type=UrlOrPath,
nargs='+',
default=lambda: [UrlOrPath(p) for p in Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')])
argparser.add_argument("--video-offset",
help="Start playback from given frame. Note that when src is an array, this applies to all videos individually.",
default=0,
type=int)
#TODO: camera as source
argparser.add_argument("--video-loop",
help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
action='store_true')
#TODO: camera as source
# Tracker
argparser.add_argument("--camera-fps",
help="Camera FPS",
type=int,
default=12)
argparser.add_argument("--homography",
help="File with homography params",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt',
action=HomographyAction)
argparser.add_argument("--calibration",
help="File with camera intrinsics and lens distortion params (calibration.json)",
# type=Path,
required=True,
# default=None,
action=CameraAction)
config = argparser.parse_args()
is_running = multiprocessing.Event()
print(is_running.set())
timer_counter = Timer('frame_emitter')
router = FrameEmitter(config, is_running)
router.emit_video(timer_counter.iterations)
router.run(timer_counter)
is_running.clear()


@ -1,18 +1,20 @@
import logging
import multiprocessing
from multiprocessing.synchronize import Event as BaseEvent
from argparse import Namespace
from argparse import ArgumentParser, Namespace
from typing import Optional
import zmq
from trap.counter import CounterFpsSender, CounterSender
from trap.timer import Timer
class Node():
def __init__(self, config: Namespace, is_running: BaseEvent, timer_counter: Timer):
def __init__(self, config: Namespace, is_running: BaseEvent, fps_counter: CounterFpsSender):
self.config = config
self.is_running = is_running
self.timer_counter = timer_counter
self.fps_counter = fps_counter
self.zmq_context = zmq.Context()
self.logger = self._logger()
@ -23,8 +25,9 @@ class Node():
return logging.getLogger(f"trap.{cls.__name__}")
def tick(self):
with self.timer_counter.get_lock():
self.timer_counter.value+=1
self.fps_counter.tick()
# with self.fps_counter.get_lock():
# self.fps_counter.value+=1
def setup(self):
raise RuntimeError("Not implemented setup()")
@ -32,6 +35,17 @@ class Node():
def run(self):
raise RuntimeError("Not implemented run()")
def run_loop(self):
"""Use in run(), to check if it should keep looping
Takes care of tick()'ing the iterations/second counter
"""
self.tick()
return self.is_running.is_set()
@classmethod
def arg_parser(cls) -> ArgumentParser:
raise RuntimeError("Not implemented arg_parser()")
def sub(self, addr: str):
"Default zmq sub configuration"
sock = self.zmq_context.socket(zmq.SUB)
@ -52,3 +66,14 @@ class Node():
instance = cls(config, is_running, timer_counter)
instance.run()
instance.logger.info("Stopping")
@classmethod
def parse_and_start(cls):
config = cls.arg_parser().parse_args()
is_running = multiprocessing.Event()
is_running.set()
statsender = CounterSender()
counter = CounterFpsSender(f"trap.{cls.__name__}", statsender)
# timer_counter = Timer(cls.__name__)
cls.start(config, is_running, counter)
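Putting the pieces together, a new node only implements `setup()`, `run()` and `arg_parser()`; `parse_and_start()` then wires up the running Event and the FPS counter. A hypothetical minimal node following the same pattern as FrameEmitter/Tracker/CvRenderer (names here are illustrative, not part of the commit):

```python
from argparse import ArgumentParser

from trap.node import Node


class EchoNode(Node):
    def setup(self):
        # SUB socket with CONFLATE, as provided by the Node.sub() helper
        self.frame_sock = self.sub(self.config.zmq_frame_addr)

    def run(self):
        while self.run_loop():                  # ticks the fps counter and checks is_running
            if self.frame_sock.poll(timeout=1000):
                frame = self.frame_sock.recv_pyobj()
                self.logger.info(f"frame {frame.index}")

    @classmethod
    def arg_parser(cls) -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument('--zmq-frame-addr', type=str, default="ipc:///tmp/feeds_frame")
        return parser


# a matching [project.scripts] entry would point at trap.echo:EchoNode.parse_and_start
if __name__ == "__main__":
    EchoNode.parse_and_start()
```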


@ -1,33 +1,27 @@
# adapted from Trajectron++ online_server.py
from argparse import Namespace
import logging
from multiprocessing import Event, Queue
import os
import pickle
import sys
import time
import json
import traceback
import warnings
import pandas as pd
import torch
import dill
import random
import logging
import os
import pathlib
import numpy as np
from trajectron.environment.data_utils import derivative_of
from trajectron.utils import prediction_output_to_trajectories
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene
from trajectron.environment.node import Node
from trajectron.environment.node_type import NodeType
import matplotlib.pyplot as plt
import pickle
import random
import time
import warnings
from argparse import ArgumentParser, Namespace
from multiprocessing import Event
import dill
import numpy as np
import torch
import zmq
from trajectron.environment import Environment, Scene
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.utils import prediction_output_to_trajectories
from trap.frame_emitter import DataclassJSONEncoder, Frame
from trap.tracker import Track, Smoother
from trap.node import Node
from trap.tracker import Smoother
logger = logging.getLogger("trap.prediction")
@ -146,27 +140,18 @@ def offset_trajectron_dict(source, x, y):
source[t][node][:,1] += y
return source
class PredictionServer:
def __init__(self, config: Namespace, is_running: Event):
self.config = config
self.is_running = is_running
class PredictionServer(Node):
def setup(self):
if self.config.eval_device == 'cpu':
logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
if self.config.smooth_predictions:
self.smoother = Smoother(window_len=12, convolution=True) # convolution seems fine for predictions
context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
self.trajectory_socket = self.sub(self.config.zmq_trajectory_addr)
self.prediction_socket = self.pub(self.config.zmq_prediction_addr)
self.external_predictions = not self.config.zmq_prediction_addr.startswith("ipc://")
# print(self.prediction_socket)
def send_frame(self, frame: Frame):
if self.external_predictions:
@ -175,8 +160,7 @@ class PredictionServer:
else:
self.prediction_socket.send_pyobj(frame)
def run(self, timer_counter):
print(self.config)
def run(self):
if self.config.seed is not None:
random.seed(self.config.seed)
np.random.seed(self.config.seed)
@ -250,16 +234,8 @@ class PredictionServer:
trajectron.set_environment(online_env, init_timestep)
timestep = init_timestep + 1
prev_run_time = 0
while self.is_running.is_set():
timestep += 1
with timer_counter.get_lock():
timer_counter.value+=1
# this_run_time = time.time()
# logger.debug(f'test {prev_run_time - this_run_time}')
# time.sleep(max(0, prev_run_time - this_run_time + .5))
# prev_run_time = time.time()
while self.run_loop():
# TODO: see process_data.py on how to create a node, the provide nodes + incoming data columns
@ -283,7 +259,6 @@ class PredictionServer:
if self.config.predict_training_data:
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
else:
# print('await', self.config.zmq_trajectory_addr)
zmq_ev = self.trajectory_socket.poll(timeout=2000)
if not zmq_ev:
# on no data loop so that is_running is checked
@ -295,6 +270,12 @@ class PredictionServer:
# print('recv tracker frame')
frame: Frame = pickle.loads(data)
# add settings to log
frame.log['predictor'] = {}
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
frame.log['predictor'][option] = self.config.__dict__[option]
# print('indexrecv', [frame.tracks[t].frame_index for t in frame.tracks])
# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
# trajectory_data = json.loads(data)
@ -490,6 +471,160 @@ class PredictionServer:
logger.info('Stopping')
@classmethod
def arg_parser(cls) -> ArgumentParser:
inference_parser = ArgumentParser()
inference_parser.add_argument('--zmq-trajectory-addr',
help='Manually specify communication addr for the trajectory messages',
type=str,
default="ipc:///tmp/feeds_traj")
inference_parser.add_argument('--zmq-prediction-addr',
help='Manually specify communication addr for the prediction messages',
type=str,
default="ipc:///tmp/feeds_preds")
inference_parser.add_argument("--step-size",
# TODO)) Make dataset/model metadata
help="sample step size (should be the same as for data processing and augmentation)",
type=int,
default=1,
)
inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference",
type=str, # TODO: make into Path
default='../Trajectron-plus-plus/experiments/trap/models/models_18_Oct_2023_19_56_22_virat_vel_ar3/')
# default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
inference_parser.add_argument("--conf",
help="path to json config file for hyperparameters, relative to model_dir",
type=str,
default='config.json')
# Model Parameters (hyperparameters)
inference_parser.add_argument("--offline_scene_graph",
help="whether to precompute the scene graphs offline, options are 'no' and 'yes'",
type=str,
default='yes')
inference_parser.add_argument("--dynamic_edges",
help="whether to use dynamic edges or not, options are 'no' and 'yes'",
type=str,
default='yes')
inference_parser.add_argument("--edge_state_combine_method",
help="the method to use for combining edges of the same type",
type=str,
default='sum')
inference_parser.add_argument("--edge_influence_combine_method",
help="the method to use for combining edge influences",
type=str,
default='attention')
inference_parser.add_argument('--edge_addition_filter',
nargs='+',
help="what scaling to use for edges as they're created",
type=float,
default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp
# and short edge addition, then you need to have a 0.0 at the
# beginning, e.g. [0.0, 1.0].
inference_parser.add_argument('--edge_removal_filter',
nargs='+',
help="what scaling to use for edges as they're removed",
type=float,
default=[1.0, 0.0]) # We don't automatically pad right with 0.0, if you want a sharp drop off like
# the default, then you need to have a 0.0 at the end.
inference_parser.add_argument('--incl_robot_node',
help="whether to include a robot node in the graph or simply model all agents",
action='store_true')
inference_parser.add_argument('--map_encoding',
help="Whether to use map encoding or not",
action='store_true')
inference_parser.add_argument('--no_edge_encoding',
help="Whether to use neighbors edge encoding",
action='store_true')
inference_parser.add_argument('--batch_size',
help='training batch size',
type=int,
default=256)
inference_parser.add_argument('--k_eval',
help='how many samples to take during evaluation',
type=int,
default=25)
# Data Parameters
inference_parser.add_argument("--eval_data_dict",
help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)",
type=str,
default='../Trajectron-plus-plus/experiments/processed/eth_test.pkl')
inference_parser.add_argument("--output_dir",
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
type=pathlib.Path,
default='./OUT/test_inference')
# inference_parser.add_argument('--device',
# help='what device to perform training on',
# type=str,
# default='cuda:0')
inference_parser.add_argument("--eval_device",
help="what device to use during inference",
type=str,
default="cpu")
inference_parser.add_argument('--seed',
help='manual seed to use, default is 123',
type=int,
default=123)
inference_parser.add_argument('--predict_training_data',
help='Ignore tracker and predict data from the training dataset',
action='store_true')
inference_parser.add_argument("--smooth-predictions",
help="Smooth the predicted tracks",
action='store_true')
inference_parser.add_argument('--prediction-horizon',
help='Trajectron.incremental_forward parameter',
type=int,
default=30)
inference_parser.add_argument('--num-samples',
help='Trajectron.incremental_forward parameter',
type=int,
default=5)
inference_parser.add_argument("--full-dist",
help="Trajectron.incremental_forward parameter",
action='store_true')
inference_parser.add_argument("--gmm-mode",
help="Trajectron.incremental_forward parameter",
type=bool,
default=True)
inference_parser.add_argument("--z-mode",
help="Trajectron.incremental_forward parameter",
action='store_true')
inference_parser.add_argument('--cm-to-m',
help="Correct for homography that is in cm (i.e. {x,y}/100). Should also be used when processing data",
action='store_true')
inference_parser.add_argument('--center-data',
help="Center data around cx and cy. Should also be used when processing data",
action='store_true')
return inference_parser
def run_prediction_server(config: Namespace, is_running: Event, timer_counter):


@ -298,7 +298,7 @@ class FrameWriter:
framerate.
See https://video.stackexchange.com/questions/25811/ffmpeg-make-video-with-non-constant-framerate-from-image-filenames
"""
def __init__(self, filename: str, fps: float, frame_size: tuple) -> None:
def __init__(self, filename: str, fps: float, frame_size: Optional[tuple] = None) -> None:
self.filename = filename
self.fps = fps
self.frame_size = frame_size


@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from argparse import ArgumentParser
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
@ -870,4 +871,21 @@ class Stage(Node):
# print(json.dumps(rl, cls=DataclassJSONEncoder))
@classmethod
def arg_parser(cls) -> ArgumentParser:
argparser = ArgumentParser()
argparser.add_argument('--zmq-trajectory-addr',
help='Manually specify communication addr for the trajectory messages',
type=str,
default="ipc:///tmp/feeds_traj")
argparser.add_argument('--zmq-prediction-addr',
help='Manually specify communication addr for the prediction messages',
type=str,
default="ipc:///tmp/feeds_preds")
argparser.add_argument('--zmq-stage-addr',
help='Manually specify communication addr for the stage messages (the rendered lines)',
type=str,
default="tcp://0.0.0.0:99174")
return argparser


@ -1,39 +1,40 @@
from argparse import Namespace
from collections import defaultdict
import argparse
import csv
from dataclasses import dataclass, field
import json
import logging
from math import nan
from multiprocessing import Event
import multiprocessing
from pathlib import Path
import pickle
import time
from typing import DefaultDict, Dict, Optional, List
from argparse import Namespace
from collections import defaultdict
from datetime import datetime, timedelta
from multiprocessing import Event
from pathlib import Path
from typing import DefaultDict, Dict, List, Optional
import cv2
import jsonlines
import numpy as np
import torch
import torchvision
import ultralytics
import zmq
import cv2
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights, FasterRCNN_ResNet50_FPN_V2_Weights, fasterrcnn_resnet50_fpn_v2
from deep_sort_realtime.deepsort_tracker import DeepSort
from torchvision.models import ResNet50_Weights
from bytetracker import BYTETracker
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from deep_sort_realtime.deepsort_tracker import DeepSort
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
KeypointRCNN_ResNet50_FPN_Weights,
MaskRCNN_ResNet50_FPN_V2_Weights,
fasterrcnn_resnet50_fpn_v2,
keypointrcnn_resnet50_fpn,
maskrcnn_resnet50_fpn_v2)
from tsmoothie.smoother import ConvolutionSmoother, KalmanSmoother
from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult
from trap import timer
from trap.frame_emitter import Camera, DataclassJSONEncoder, DetectionState, Frame, Detection, Track
from bytetracker import BYTETracker
from tsmoothie.smoother import KalmanSmoother, ConvolutionSmoother
import tsmoothie.smoother
from datetime import datetime, timedelta
from trap.frame_emitter import (Camera, DataclassJSONEncoder, Detection,
DetectionState, Frame, Track)
from trap.node import Node
# Detection = [int, int, int, int, float, int]
# Detections = [Detection]
@ -387,9 +388,8 @@ class ByteTrackWrapper(TrackerWrapper):
class Tracker:
def __init__(self, config: Namespace):
self.config = config
class Tracker(Node):
def setup(self):
# # TODO: config device
@ -453,6 +453,9 @@ class Tracker:
logger.info("Smoother Disabled (enable with --smooth-tracks)")
self.frame_sock = self.sub(self.config.zmq_frame_addr)
self.trajectory_socket = self.pub(self.config.zmq_trajectory_addr)
logger.debug("Set up tracker")
def track_frame(self, frame: Frame):
@ -474,42 +477,11 @@ class Tracker:
return detections
def track(self, is_running: Event, timer_counter: int = 0):
def run(self):
"""
Live tracking of frames coming in over zmq
"""
self.is_running = is_running
context = zmq.Context()
self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.connect(self.config.zmq_frame_addr)
self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.trajectory_socket.bind(self.config.zmq_trajectory_addr)
prev_run_time = 0
# training_fp = None
# training_csv = None
# training_frames = 0
# if self.config.save_for_training is not None:
# if not isinstance(self.config.save_for_training, Path):
# raise ValueError("save-for-training should be a path")
# if not self.config.save_for_training.exists():
# logger.info(f"Making path for training data: {self.config.save_for_training}")
# self.config.save_for_training.mkdir(parents=True, exist_ok=False)
# else:
# logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
# training_fp = open(self.config.save_for_training / 'all.txt', 'w')
# # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
# training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)
prev_frame_i = -1
with TrainingDataWriter(self.config.save_for_training) as writer:
@ -518,9 +490,6 @@ class Tracker:
w_time = None
displacement_filter = FinalDisplacementFilter(.8)
while self.is_running.is_set():
with timer_counter.get_lock():
timer_counter.value += 1
# this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
# skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
# so for now, timing should move to emitter
@ -532,10 +501,12 @@ class Tracker:
poll_time = time.time()
zmq_ev = self.frame_sock.poll(timeout=2000)
if not zmq_ev:
logger.warning('skip poll after 2000ms')
logger.warning('no frame for 2000ms')
# when there's no data after timeout, loop so that is_running is checked
continue
self.tick() # only tick if something is actually received
start_time = time.time()
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
@ -697,10 +668,41 @@ class Tracker:
"""
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
@classmethod
def arg_parser(cls):
argparser = argparse.ArgumentParser()
argparser.add_argument('--zmq-frame-addr',
help='Manually specify communication addr for the frame messages',
type=str,
default="ipc:///tmp/feeds_frame")
argparser.add_argument('--zmq-trajectory-addr',
help='Manually specify communication addr for the trajectory messages',
type=str,
default="ipc:///tmp/feeds_traj")
argparser.add_argument("--save-for-training",
help="Specify the path in which to save",
type=Path,
default=None)
argparser.add_argument("--detector",
help="Specify the detector to use",
type=str,
default=DETECTOR_YOLOv8,
choices=DETECTORS)
argparser.add_argument("--tracker",
help="Specify the detector to use",
type=str,
default=TRACKER_BYTETRACK,
choices=TRACKERS)
argparser.add_argument("--smooth-tracks",
help="Smooth the tracker tracks before sending them to the predictor",
action='store_true')
return argparser
def run_tracker(config: Namespace, is_running: Event, timer_counter):
router = Tracker(config)
router.track(is_running, timer_counter)
router.run(is_running, timer_counter)
def run():
# Frame emitter
@ -738,7 +740,7 @@ def run():
timer_counter = timer.Timer('frame_emitter')
router = Tracker(config)
router.track(is_running, timer_counter.iterations)
router.run(is_running, timer_counter.iterations)
is_running.clear()

uv.lock (14 additions)

@ -2191,6 +2191,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/4b/528ccf7a982216885a1ff4908e886b8fb5f19862d1962f56a3fce2435a70/starlette-0.46.1-py3-none-any.whl", hash = "sha256:77c74ed9d2720138b25875133f3a2dae6d854af2ec37dceb56aef370c1d8a227", size = 71995 },
]
[[package]]
name = "supervisor"
version = "4.2.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ce/37/517989b05849dd6eaa76c148f24517544704895830a50289cbbf53c7efb9/supervisor-4.2.5.tar.gz", hash = "sha256:34761bae1a23c58192281a5115fb07fbf22c9b0133c08166beffc70fed3ebc12", size = 466073 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/7a/0ad3973941590c040475046fef37a2b08a76691e61aa59540828ee235a6e/supervisor-4.2.5-py2.py3-none-any.whl", hash = "sha256:2ecaede32fc25af814696374b79e42644ecaba5c09494c51016ffda9602d0f08", size = 319561 },
]
[[package]]
name = "tensorboard"
version = "2.19.0"
@ -2495,6 +2507,7 @@ dependencies = [
{ name = "setproctitle" },
{ name = "shapely" },
{ name = "simplification" },
{ name = "supervisor" },
{ name = "tensorboardx" },
{ name = "torch", version = "1.12.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" },
{ name = "torch", version = "1.12.1+cu113", source = { url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl" }, marker = "sys_platform == 'linux'" },
@ -2527,6 +2540,7 @@ requires-dist = [
{ name = "setproctitle", specifier = ">=1.3.3,<2" },
{ name = "shapely", specifier = ">=2.1" },
{ name = "simplification", specifier = ">=0.7.12" },
{ name = "supervisor", specifier = ">=4.2.5" },
{ name = "tensorboardx", specifier = ">=2.6.2.2,<3" },
{ name = "torch", marker = "python_full_version < '3.10' or python_full_version >= '4' or sys_platform != 'linux'", specifier = "==1.12.1" },
{ name = "torch", marker = "python_full_version >= '3.10' and python_full_version < '4' and sys_platform == 'linux'", url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl" },