Provide history to stage

This commit is contained in:
Ruben van de Ven 2025-10-30 15:45:29 +01:00
parent afe5accb9c
commit 1ac199732c
8 changed files with 141 additions and 21 deletions

View file

@ -39,6 +39,7 @@ dependencies = [
"svgpathtools>=1.7.1", "svgpathtools>=1.7.1",
"velodyne-decoder>=3.1.0", "velodyne-decoder>=3.1.0",
"open3d>=0.19.0", "open3d>=0.19.0",
"nptyping>=2.5.0",
] ]
[project.scripts] [project.scripts]

View file

@ -33,7 +33,8 @@ command=uv run trap_tracker --smooth-tracks
directory=%(here)s directory=%(here)s
[program:stage] [program:stage]
command=uv run trap_stage # command=uv run trap_stage
command=uv run trap_stage --verbose --camera-fps 12 --homography ../DATASETS/hof3/homography.json --calibration ../DATASETS/hof3/calibration.json --cache-path /tmp/history_cache-hof3.pcl --tracker-output-dir EXPERIMENTS/raw/hof3/
directory=%(here)s directory=%(here)s
[program:predictor] [program:predictor]

View file

@ -16,6 +16,7 @@ import cv2
from dataclasses import dataclass, field from dataclasses import dataclass, field
import dataclasses import dataclasses
from nptyping import Float64, NDArray, Shape
import numpy as np import numpy as np
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
@ -395,7 +396,7 @@ class Track:
def track_update_dt(self) -> float: def track_update_dt(self) -> float:
return time.time() - self.updated_at return time.time() - self.updated_at
def get_projected_history(self, H: Optional[cv2.Mat] = None, camera: Optional[DistortedCamera]= None) -> np.array: def get_projected_history(self, H: Optional[cv2.Mat] = None, camera: Optional[DistortedCamera]= None) -> NDArray[Shape["*, 2"], Float64]:
foot_coordinates = [d.get_foot_coords() for d in self.history] foot_coordinates = [d.get_foot_coords() for d in self.history]
# TODO)) Undistort points before perspective transform # TODO)) Undistort points before perspective transform
if len(foot_coordinates): if len(foot_coordinates):
@ -408,7 +409,7 @@ class Track:
else: else:
coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H) coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
return coords[0] return coords[0]
return np.array([]) return np.empty(shape=(0,2)) #np.array([], shape)
def get_projected_history_as_dict(self, H, camera: Optional[DistortedCamera]= None) -> dict: def get_projected_history_as_dict(self, H, camera: Optional[DistortedCamera]= None) -> dict:
coords = self.get_projected_history(H, camera) coords = self.get_projected_history(H, camera)

View file

@ -786,14 +786,19 @@ class SegmentLine(LineAnimator):
self.anim_f = anim_f or partial(self.anim_arrive, length=self.length) self.anim_f = anim_f or partial(self.anim_arrive, length=self.length)
@classmethod @classmethod
def anim_arrive(cls, t: float, ls: shapely.geometry.LineString, length: float): def anim_arrive(cls, t: float, ls: shapely.geometry.LineString, length: float, reverse=True):
t = 1-t # reverse if reverse:
t = 1-t # reverse
start_pos = t * ls.length start_pos = t * ls.length
end_pos = start_pos + length end_pos = start_pos + length
return (start_pos, end_pos) return (start_pos, end_pos)
@classmethod @classmethod
def anim_grow(cls, t: float, ls: shapely.geometry.LineString, reverse=False): def anim_grow(cls, t: float, ls: shapely.geometry.LineString, reverse=False, in_and_out=False, max_len=None):
if in_and_out:
l = ls.length
offset = max_len if max_len else l
return (max((l+offset) * t - offset, 0), min((l+offset) * t, l))
if reverse: if reverse:
return (ls.length * t, ls.length) return (ls.length * t, ls.length)
else: else:

View file

@ -48,13 +48,16 @@ class Node():
self.tick() self.tick()
return self.is_running.is_set() return self.is_running.is_set()
def run_loop_capped_fps(self, max_fps: float): def run_loop_capped_fps(self, max_fps: float, warn_below_fps: float = 0.):
"""Use in run(), to check if it should keep looping """Use in run(), to check if it should keep looping
Takes care of tick()'ing the iterations/second counter Takes care of tick()'ing the iterations/second counter
""" """
now = time.perf_counter() now = time.perf_counter()
time_diff = (now - self._prev_loop_time) time_diff = (now - self._prev_loop_time)
if warn_below_fps > 0 and time_diff > 1/warn_below_fps:
self.logger.warning(f"Running below {warn_below_fps} FPS: measured {1/time_diff} FPS")
if time_diff < 1/max_fps: if time_diff < 1/max_fps:
# print(f"sleep {1/max_fps - time_diff}") # print(f"sleep {1/max_fps - time_diff}")
time.sleep(1/max_fps - time_diff) time.sleep(1/max_fps - time_diff)

View file

@ -1,20 +1,26 @@
from __future__ import annotations
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import partial from functools import partial
import logging import logging
from math import inf
from pathlib import Path
import time import time
import threading import threading
from typing import Dict, List, Optional, Type, TypeVar from typing import Dict, Generator, List, Optional, Type, TypeVar
import numpy as np
import zmq import zmq
from trap.anomaly import DiffSegment, calc_anomaly, calculate_loitering_scores from trap.anomaly import DiffSegment, calc_anomaly, calculate_loitering_scores
from trap.base import DataclassJSONEncoder, Frame, ProjectedTrack, Track from trap.base import CameraAction, DataclassJSONEncoder, Frame, HomographyAction, ProjectedTrack, Track
from trap.counter import CounterSender from trap.counter import CounterSender
from trap.lines import AppendableLine, AppendableLineAnimator, Coordinate, CropLine, DashedLine, DeltaT, FadeOutJitterLine, FadeOutLine, FadedTailLine, LineAnimationStack, LineAnimator, NoiseLine, RenderableLayers, RenderableLine, RenderableLines, SegmentLine, SimplifyMethod, SrgbaColor, StaticLine, load_lines_from_svg from trap.lines import AppendableLine, AppendableLineAnimator, Coordinate, CropLine, DashedLine, DeltaT, FadeOutJitterLine, FadeOutLine, FadedTailLine, LineAnimationStack, LineAnimator, NoiseLine, RenderableLayers, RenderableLine, RenderableLines, SegmentLine, SimplifyMethod, SrgbaColor, StaticLine, load_lines_from_svg
from trap.node import Node from trap.node import Node
from trap.track_history import TrackHistory
logger = logging.getLogger('trap.stage') logger = logging.getLogger('trap.stage')
@ -255,6 +261,7 @@ class DrawnScenario(Scenario):
def __init__(self, track_id): def __init__(self, track_id):
super().__init__(track_id) super().__init__(track_id)
self.last_update_t = time.perf_counter() self.last_update_t = time.perf_counter()
self.active_ptrack: Optional[ProjectedTrack] = None
history_color = SrgbaColor(1.,0.,1.,1.) history_color = SrgbaColor(1.,0.,1.,1.)
history = StaticLine([], history_color) history = StaticLine([], history_color)
@ -266,17 +273,29 @@ class DrawnScenario(Scenario):
self.line_history.add(NoiseLine(self.line_history.tail, amplitude=0, t_factor=.3)) self.line_history.add(NoiseLine(self.line_history.tail, amplitude=0, t_factor=.3))
self.line_history.add(FadeOutJitterLine(self.line_history.tail, frequency=5, t_factor=.5)) self.line_history.add(FadeOutJitterLine(self.line_history.tail, frequency=5, t_factor=.5))
self.active_ptrack: Optional[ProjectedTrack] = None
self.prediction_color = SrgbaColor(0,1,0,1) self.prediction_color = SrgbaColor(0,1,0,1)
self.line_prediction = LineAnimationStack(StaticLine([], self.prediction_color)) self.line_prediction = LineAnimationStack(StaticLine([], self.prediction_color))
self.line_prediction.add(SegmentLine(self.line_prediction.tail, duration=.5)) self.line_prediction.add(SegmentLine(self.line_prediction.tail, duration=.5))
self.line_prediction.add(DashedLine(self.line_prediction.tail, t_factor=4, loop_offset=True)) self.line_prediction.add(DashedLine(self.line_prediction.tail, t_factor=4, loop_offset=True))
self.line_prediction.get(DashedLine).skip = True self.line_prediction.get(DashedLine).skip = True
self.line_prediction.add(FadeOutLine(self.line_prediction.tail)) self.line_prediction.add(FadeOutLine(self.line_prediction.tail))
# when rendering tracks from others similar/close to the current one
self.others_color = SrgbaColor(1,1,0,1)
self.line_others = LineAnimationStack(StaticLine([], self.others_color))
self.line_others.add(SegmentLine(self.line_others.tail, duration=3, anim_f=partial(SegmentLine.anim_grow, in_and_out=True, max_len=8)))
# self.line_others.add(DashedLine(self.line_others.tail, t_factor=4, loop_offset=True))
# self.line_others.get(DashedLine).skip = True
self.line_others.add(FadeOutLine(self.line_others.tail))
self.line_others.get(FadeOutLine).set_alpha(0)
self.tracks_to_self: Optional[Generator] = None
self.tracks_to_self_pos = None
self.tracks_to_self_fetched_at = None
# self.line_prediction_drawn = self.line_prediction_faded # self.line_prediction_drawn = self.line_prediction_faded
def update(self): def update(self, stage: Stage):
super().update() super().update()
if self.track: if self.track:
self.line_history.root.points = self.track.projected_history self.line_history.root.points = self.track.projected_history
@ -307,7 +326,7 @@ class DrawnScenario(Scenario):
# self.line_prediction.get(SegmentLine).skip = True # self.line_prediction.get(SegmentLine).skip = True
self.line_prediction.get(DashedLine).skip = False self.line_prediction.get(DashedLine).skip = False
self.line_prediction.start() self.line_prediction.start()
else: elif self.line_prediction.get(SegmentLine).duration != 2: # hack to only play once
self.line_prediction.get(SegmentLine).anim_f = partial(SegmentLine.anim_grow, reverse=True) self.line_prediction.get(SegmentLine).anim_f = partial(SegmentLine.anim_grow, reverse=True)
self.line_prediction.get(SegmentLine).duration = 2 self.line_prediction.get(SegmentLine).duration = 2
self.line_prediction.get(SegmentLine).start() self.line_prediction.get(SegmentLine).start()
@ -319,16 +338,49 @@ class DrawnScenario(Scenario):
self.line_prediction.root.points = self.active_ptrack._track.predictions[0] self.line_prediction.root.points = self.active_ptrack._track.predictions[0]
if self.scene is ScenarioScene.LOITERING: # special case: LOITERING
# special case: PLAY if self.scene is ScenarioScene.LOITERING or self.state_change_at:
# logger.info('loitering')
transition = min(1, (time.time() - self.state_change_at)/1.4) transition = min(1, (time.time() - self.state_change_at)/1.4)
# TODO: transition fade, using to_alpha(), so it can fade back in again: # TODO: transition fade, using to_alpha(), so it can fade back in again:
self.line_history.get(FadeOutJitterLine).set_alpha(1 - transition) self.line_history.get(FadeOutJitterLine).set_alpha(1 - transition)
if transition > .999: self.line_prediction.get(FadeOutLine).set_alpha(1 - transition)
current_position = self.track.projected_history[-1]
current_position_rounded = np.round(current_position*2) # cache per 1/2 meter
time_diff = inf if not self.tracks_to_self_fetched_at else time.perf_counter() - self.tracks_to_self_fetched_at
# print(transition > .999, self.is_running, current_position_rounded, time_diff)
if transition > .999 and self.is_running and not all(self.tracks_to_self_pos == current_position_rounded) and time_diff > 5: # only do these expensive calls when running
logger.info(f"Fetch similar tracks for {self.track_id}")
t = time.perf_counter()
self.tracks_to_self_pos = current_position_rounded
self.tracks_to_self_fetched_at = time.perf_counter()
# fetch lines nearby # fetch lines nearby
pass track_ids = stage.history.get_nearest_tracks(current_position, 15)
self.track_ids_to_self = iter(track_ids)
self.tracks_to_self = stage.history.ids_as_trajectory(track_ids)
print(time.perf_counter() - t, "fetch delya")
if self.tracks_to_self and self.line_others.is_ready():
current_history_id = next(self.track_ids_to_self)
current_history = next(self.tracks_to_self)
logger.info(f"play history item: {current_history_id}")
self.line_others.get(FadeOutLine).set_alpha(1)
self.line_others.root.points = current_history
# print(self.line_others.root.points)
self.line_others.start()
# special case: PLAY
elif self.scene is ScenarioScene.PLAY: elif self.scene is ScenarioScene.PLAY:
# special case: PLAY
pass pass
# if self.scene is ScenarioScene.CORRECTED_PREDICTION: # if self.scene is ScenarioScene.CORRECTED_PREDICTION:
# self.line_prediction.get(DashedLine).skip = False # self.line_prediction.get(DashedLine).skip = False
@ -355,13 +407,15 @@ class DrawnScenario(Scenario):
history_line = self.line_history.as_renderable_line(dt) history_line = self.line_history.as_renderable_line(dt)
prediction_line = self.line_prediction.as_renderable_line(dt) prediction_line = self.line_prediction.as_renderable_line(dt)
others_line = self.line_others.as_renderable_line(dt)
# print(history_line) # print(history_line)
# print(self.track_id, len(self.line_history.points), len(history_line)) # print(self.track_id, len(self.line_history.points), len(history_line))
return RenderableLines([ return RenderableLines([
history_line, history_line,
prediction_line prediction_line,
others_line
]) ])
@ -384,15 +438,23 @@ class Stage(Node):
if self.config.debug_map: if self.config.debug_map:
debug_color = SrgbaColor(0.,0.,1.,1.) debug_color = SrgbaColor(0.,0.,1.,1.)
self.debug_lines = RenderableLines(load_lines_from_svg(self.config.debug_map, 100, debug_color)) self.debug_lines = RenderableLines(load_lines_from_svg(self.config.debug_map, 100, debug_color))
self.history = TrackHistory(self.config.tracker_output_dir, self.config.camera, self.config.cache_path)
def run(self): def run(self):
while self.run_loop_capped_fps(self.FPS): while self.run_loop_capped_fps(self.FPS, warn_below_fps=10):
dt = max(1/ self.FPS, self.dt_since_last_tick) # never dt of 0 dt = max(1/ self.FPS, self.dt_since_last_tick) # never dt of 0
# t1 = time.perf_counter()
self.loop_receive() self.loop_receive()
# t2 = time.perf_counter()
self.loop_update_scenarios() self.loop_update_scenarios()
# t3 = time.perf_counter()
self.loop_render(dt) self.loop_render(dt)
# t4 = time.perf_counter()
# print(t2-t1, t3-t2, t4-t3)
def loop_receive(self): def loop_receive(self):
@ -421,7 +483,7 @@ class Stage(Node):
"""Update active scenarios and handle pauses/completions.""" """Update active scenarios and handle pauses/completions."""
# 1) process timestep for all scenarios # 1) process timestep for all scenarios
for s in self.scenarios.values(): for s in self.scenarios.values():
s.update() s.update(self)
# 2) Remove stale tracks and take-overs # 2) Remove stale tracks and take-overs
@ -471,6 +533,7 @@ class Stage(Node):
"""Draw all active scenarios onto the canvas.""" """Draw all active scenarios onto the canvas."""
lines = RenderableLines([]) lines = RenderableLines([])
# TODO: sometimes very slow!
for scenario in self.active_scenarios: for scenario in self.active_scenarios:
lines.append_lines(scenario.to_renderable_lines(dt)) lines.append_lines(scenario.to_renderable_lines(dt))
@ -516,6 +579,38 @@ class Stage(Node):
help='Maximum number of active scenarios that can be drawn at once (to not overlod the laser)', help='Maximum number of active scenarios that can be drawn at once (to not overlod the laser)',
type=int, type=int,
default=2) default=2)
# TODO: this should be subsumed to some sort of Track Dataset loader
historyargs = argparser.add_argument_group("Track History Loader")
historyargs.add_argument("--camera-fps",
help="Camera FPS",
type=int,
default=12)
historyargs.add_argument("--homography",
help="File with homography params [Deprecated]",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt',
action=HomographyAction)
historyargs.add_argument("--calibration",
help="File with camera intrinsics and lens distortion params (calibration.json)",
# type=Path,
required=True,
# default=None,
action=CameraAction)
historyargs.add_argument("--cache-path",
help="Where to cache the Track History dataset",
type=Path,
required=True,
)
historyargs.add_argument("--tracker-output-dir",
help="Directory for the track reader (e.g. EXPERIMENT/raw/_name_)",
type=Path,
required=True,
)
return argparser return argparser

View file

@ -158,7 +158,7 @@ class TrackReader:
def __len__(self): def __len__(self):
return len(self._tracks) return len(self._tracks)
def get(self, track_id): def get(self, track_id) -> Track:
return self._tracks[track_id] return self._tracks[track_id]
# detection_values = self._tracks[track_id] # detection_values = self._tracks[track_id]
# history = [] # history = []

14
uv.lock
View file

@ -1437,6 +1437,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/33/bd5b9137445ea4b680023eb0469b2bb969d61303dedb2aac6560ff3d14a1/notebook_shim-0.2.4-py3-none-any.whl", hash = "sha256:411a5be4e9dc882a074ccbcae671eda64cceb068767e9a3419096986560e1cef", size = 13307 }, { url = "https://files.pythonhosted.org/packages/f9/33/bd5b9137445ea4b680023eb0469b2bb969d61303dedb2aac6560ff3d14a1/notebook_shim-0.2.4-py3-none-any.whl", hash = "sha256:411a5be4e9dc882a074ccbcae671eda64cceb068767e9a3419096986560e1cef", size = 13307 },
] ]
[[package]]
name = "nptyping"
version = "2.5.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e1/b7/ffe533358c32506b1708feec0fb04ba0a35a959a94163fff5333671909da/nptyping-2.5.0.tar.gz", hash = "sha256:e3d35b53af967e6fb407c3016ff9abae954d3a0568f7cc13a461084224e8e20a", size = 71623 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/28/92edc05378175de13a3d4986cee7531853634a22b7e5e21a988fa84fde3f/nptyping-2.5.0-py3-none-any.whl", hash = "sha256:764e51836faae33a7ae2e928af574cfb701355647accadcc89f2ad793630b7c8", size = 37602 },
]
[[package]] [[package]]
name = "numpy" name = "numpy"
version = "1.26.4" version = "1.26.4"
@ -2703,6 +2715,7 @@ dependencies = [
{ name = "ipywidgets" }, { name = "ipywidgets" },
{ name = "jsonlines" }, { name = "jsonlines" },
{ name = "noise" }, { name = "noise" },
{ name = "nptyping" },
{ name = "open3d" }, { name = "open3d" },
{ name = "opencv-python" }, { name = "opencv-python" },
{ name = "pandas-helper-calc" }, { name = "pandas-helper-calc" },
@ -2741,6 +2754,7 @@ requires-dist = [
{ name = "ipywidgets", specifier = ">=8.1.5,<9" }, { name = "ipywidgets", specifier = ">=8.1.5,<9" },
{ name = "jsonlines", specifier = ">=4.0.0,<5" }, { name = "jsonlines", specifier = ">=4.0.0,<5" },
{ name = "noise", specifier = ">=1.2.2" }, { name = "noise", specifier = ">=1.2.2" },
{ name = "nptyping", specifier = ">=2.5.0" },
{ name = "open3d", specifier = ">=0.19.0" }, { name = "open3d", specifier = ">=0.19.0" },
{ name = "opencv-python", path = "opencv_python-4.10.0.84-cp310-cp310-linux_x86_64.whl" }, { name = "opencv-python", path = "opencv_python-4.10.0.84-cp310-cp310-linux_x86_64.whl" },
{ name = "pandas-helper-calc", git = "https://github.com/scls19fr/pandas-helper-calc" }, { name = "pandas-helper-calc", git = "https://github.com/scls19fr/pandas-helper-calc" },