trap/trap/stage.py
2025-10-29 12:05:18 +01:00

521 lines
20 KiB
Python

from argparse import ArgumentParser
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from functools import partial
import logging
import time
import threading
from typing import Dict, List, Optional, Type, TypeVar
import zmq
from trap.anomaly import DiffSegment, calc_anomaly, calculate_loitering_scores
from trap.base import DataclassJSONEncoder, Frame, ProjectedTrack, Track
from trap.counter import CounterSender
from trap.lines import AppendableLine, AppendableLineAnimator, Coordinate, CropLine, DashedLine, DeltaT, FadeOutJitterLine, FadeOutLine, FadedTailLine, LineAnimationStack, LineAnimator, NoiseLine, RenderableLayers, RenderableLine, RenderableLines, SegmentLine, SimplifyMethod, SrgbaColor, StaticLine, load_lines_from_svg
from trap.node import Node
logger = logging.getLogger('trap.stage')
OPTION_RENDER_DEBUG = False
OPTION_POSITION_MARKER = False
OPTION_GROW_ANOMALY_CIRCLE = False
# OPTION_RENDER_DIFF_SEGMENT = True
OPTION_TRACK_NOISE = False
TRACK_ASSUMED_FPS = 12
TAKEOVER_FADEOUT = 3
LOST_FADEOUT = 2 # seconds
PREDICTION_INTERVAL: float|None = 20 # frames
PREDICTION_FADE_IN: float = 3
PREDICTION_FADE_SLOPE: float = -10
PREDICTION_FADE_AFTER_DURATION: float = 8 # seconds
PREDICTION_END_FADE = 2 #frames
# TRACK_MAX_POINTS = 100
TRACK_FADE_AFTER_DURATION = 15. # seconds
TRACK_END_FADE = 30 # points
TRACK_FADE_ASSUME_FPS = TRACK_ASSUMED_FPS
# LOITERING_WINDOW = 8 * TRACK_ASSUMED_FPS
# LOITERING_DISTANCE = 1 # meter diff in LOITERING_WINDOW time
# LOITERING_MEDIAN_FILTER = TRACK_ASSUMED_FPS // 3 # frames: smooth out velocity over n frames
LOITERING_VELOCITY_TRESHOLD = .5 # m/s
LOITERING_DURATION_TO_LINGER = TRACK_ASSUMED_FPS * 1 # start counting as lingering after this many frames
LOITERING_LINGER_FACTOR = TRACK_ASSUMED_FPS * 4 # number of frames to reach loitering score of 1 (+LOITERING_DURATION_TO_LINGER)
class DefaultDictKeyed(dict):
    """Dictionary that builds missing entries from their own key.

    Similar to ``collections.defaultdict``, except that ``factory`` receives
    the missing key as its only argument. The produced value is stored before
    being returned, so later lookups of the same key return the same object.
    Note that ``get()`` and ``in`` do not trigger the factory.
    """

    def __init__(self, factory):
        self.factory = factory

    def __missing__(self, key):
        created = self.factory(key)
        self[key] = created
        return created
@dataclass
class SceneInfo:
    """Static metadata attached to every ScenarioScene member."""
    # Higher values are preferred when the stage chooses which scenarios to show.
    priority: int
    # Human-readable label for the scene.
    description: str = ""
    # whether to allow for other scenarios to steal the stage
    takeover_possible: bool = False


class ScenarioScene(Enum):
    """The scenes a track's Scenario can be in, each carrying its SceneInfo.

    Members whose SceneInfo values compare equal would silently become Enum
    aliases (dataclass equality compares all fields). LOITERING and PLAY share
    the same priority and flags, so keep their descriptions distinct.
    """
    DETECTED = SceneInfo(4, "First detection")
    SUBSTANTIAL = SceneInfo(6, "Multiple detections")
    FIRST_PREDICTION = SceneInfo(10, "Prediction is ready")
    CORRECTED_PREDICTION = SceneInfo(11, "Multiple predictions")
    LOITERING = SceneInfo(7, "Found to be loitering", takeover_possible=True)
    PLAY = SceneInfo(7, description="After many predictions; just fooling around", takeover_possible=True)
    LOST = SceneInfo(-1, description="Track lost", takeover_possible=True)


# Timestamps in seconds, as returned by time.time().
Time = float
class Scenario:
    """Per-track state machine that picks the ScenarioScene for one track id.

    It keeps the latest tracker track plus the predictions received for the
    same track id, and re-evaluates its scene whenever new data arrives
    (see update_state). Drawing is implemented by the DrawnScenario subclass.
    """

    def __init__(self, track_id):
        self.track_id = track_id
        self.scene: ScenarioScene = ScenarioScene.DETECTED
        self.start_time = 0.
        self.current_time = 0
        # time.time() of the first take_over() request; None when not being taken over
        self.take_over_at: Optional[Time] = None
        # latest tracker track, replaced by recv_track()
        self.track: Optional[ProjectedTrack] = None
        # accepted predictions in arrival order, appended by recv_prediction()
        self.prediction_tracks: List[ProjectedTrack] = []
        self._last_diff_frame_idx: Optional[int] = 0
        # one DiffSegment per accepted prediction; only the last one is still updated
        self.prediction_diffs: List[DiffSegment] = []
        # time.time() of the last scene change, set by set_scene()
        self.state_change_at = None
        self.is_running = False
        logger.info(f"Found {self.track_id}: {self.scene.name}")

    def start(self):
        """Mark the scenario as running (the stage calls this when it gets a slot)."""
        # change when visible
        logger.info(f"Start {self.track_id}: {self.scene.name}")
        self.is_running = True

    def track_age(self):
        """Return seconds since the current track was last updated (0 without a track)."""
        if not self.track:
            return 0
        return time.time() - self.track.updated_at

    def take_over(self):
        """Request a take-over; only the timestamp of the first request is kept."""
        if self.take_over_at:
            return
        self.take_over_at = time.time()

    def taken_over(self):
        """Complete a take-over: stop running and clear the pending request."""
        self.is_running = False
        self.take_over_at = None

    def takenover_for(self):
        """Return seconds since take_over() was requested, or None if not requested."""
        if self.take_over_at:
            return time.time() - self.take_over_at
        return None

    def takeover_factor(self):
        """Return take-over progress.

        The value is 0 when idle and reaches 1 after TAKEOVER_FADEOUT seconds.
        It is not clamped.
        """
        elapsed = self.takenover_for()
        if not elapsed:
            return 0
        return elapsed / TAKEOVER_FADEOUT

    def lost_for(self):
        """Return seconds spent in the LOST scene, or None when not lost."""
        if self.scene is ScenarioScene.LOST:
            return time.time() - self.state_change_at
        return None

    def lost_factor(self):
        """Return lost-fade progress.

        The value is 0 unless LOST and reaches 1 after LOST_FADEOUT seconds.
        It is not clamped.
        """
        elapsed = self.lost_for()
        if not elapsed:
            return 0
        return elapsed / LOST_FADEOUT

    def anomaly_factor(self):
        """Return calc_anomaly() over the accumulated prediction diffs (called with 10)."""
        return calc_anomaly(self.prediction_diffs, 10)

    def deactivate(self):
        """Cancel any pending take-over request."""
        self.take_over_at = None

    def update(self):
        """Animation tick, check state."""
        # 1) lost_score: unlike other states, this runs for each rendering pass to handle crashing tracker
        self.check_lost()

    def set_scene(self, scene: ScenarioScene):
        """Switch to `scene` and record when it happened; no-op if already in it."""
        if self.scene is scene:
            return
        logger.info(f"Changing scene for {self.track_id}: {self.scene.name} -> {scene.name}")
        self.scene = scene
        self.state_change_at = time.time()

    def update_state(self):
        """Re-evaluate the scene; the first check that applies wins (lost, loitering, track)."""
        self.check_lost() or self.check_loitering() or self.check_track()

    def check_lost(self):
        """Switch to LOST when the track is flagged lost or was not updated for 5 seconds."""
        if self.track and (self.track.lost or self.track.updated_at < time.time() - 5):
            self.set_scene(ScenarioScene.LOST)
            return True
        return False

    def check_loitering(self):
        """Switch to LOITERING when the most recent loitering score exceeds 0.99.

        Returns True if the scene was set to LOITERING.
        """
        if not self.track:
            return False
        scores = list(calculate_loitering_scores(
            self.track,
            LOITERING_DURATION_TO_LINGER,
            LOITERING_LINGER_FACTOR,
            LOITERING_VELOCITY_TRESHOLD / TRACK_ASSUMED_FPS,  # m/s -> per frame
            150,
        ))
        # Guard against an empty score sequence: indexing [-1] would raise IndexError.
        if scores and scores[-1] > .99:
            self.set_scene(ScenarioScene.LOITERING)
            return True
        return False

    def check_track(self):
        """Choose a scene from the prediction count, or else from the history length.

        1 prediction -> FIRST_PREDICTION, more than 10 -> PLAY, any other
        non-zero count -> CORRECTED_PREDICTION. Without predictions, a history
        longer than 3 s worth of frames -> SUBSTANTIAL, otherwise DETECTED.
        Returns False only when there is neither a prediction nor a track.
        """
        predictions = len(self.prediction_tracks)
        if predictions == 1:
            self.set_scene(ScenarioScene.FIRST_PREDICTION)
            return True
        if predictions > 10:
            self.set_scene(ScenarioScene.PLAY)
            return True
        if predictions:
            self.set_scene(ScenarioScene.CORRECTED_PREDICTION)
            return True
        if self.track:
            if len(self.track.projected_history) > TRACK_ASSUMED_FPS * 3:
                self.set_scene(ScenarioScene.SUBSTANTIAL)
            else:
                self.set_scene(ScenarioScene.DETECTED)
            return True
        return False

    # the tracker track: replace
    def recv_track(self, track: ProjectedTrack):
        """Replace the current track (unless `track` was created earlier) and re-evaluate."""
        if self.track and self.track.created_at > track.created_at:
            # ignore old track
            return
        self.track = track
        self.update_prediction_diff()
        self.update_state()

    def update_prediction_diff(self):
        """
        gather the diffs of the trajectory with the most recent prediction
        """
        if len(self.prediction_diffs) == 0:
            return
        self.prediction_diffs[-1].update_track(self.track)

    # receive new predictions: accumulate
    def recv_prediction(self, track: ProjectedTrack):
        """Accept a prediction track, throttled to one per PREDICTION_INTERVAL frames.

        Predictions without prediction data are ignored. Each accepted
        prediction finishes the previous DiffSegment and starts a new one.
        """
        if not self.track:
            # in case of the unlikely event that prediction was received sooner
            self.recv_track(track)
        if PREDICTION_INTERVAL is not None and len(self.prediction_tracks) and (track.frame_index - self.prediction_tracks[-1].frame_index) < PREDICTION_INTERVAL:
            # just drop tracks if the predictions come to quick
            return
        if track._track.predictions is None or not len(track._track.predictions):
            # don't count to predictions if no prediction is set of given track (e.g. young tracks, that are still passed by the predictor)
            return
        self.prediction_tracks.append(track)
        if len(self.prediction_diffs):
            self.prediction_diffs[-1].finish() # existing diffing can end
        # and create a new one
        self.prediction_diffs.append(DiffSegment(track))
        # self.prediction_diffs.append(DiffSegmentScan(track))
        self.update_state()
class DrawnScenario(Scenario):
    """
    Scenario that also animates and renders its track as lines.

    Scenario contains the controls (scene, target positions).
    DrawnScenario class does the actual drawing of points incl. transitions.
    This distinction is only for ordering the code.

    Two animation stacks are kept: ``line_history`` for the tracked trajectory
    and ``line_prediction`` for the prediction that is currently displayed.
    """
    MAX_HISTORY = 300 # points of history of trajectory to display (preventing too long lines)
    CUT_GAP = 5 # when adding a new prediction, keep the existing prediction until that point + this CUT_GAP margin

    def __init__(self, track_id):
        super().__init__(track_id)
        self.last_update_t = time.perf_counter()
        # History stack: each animator added below wraps the current tail of the stack.
        history_color = SrgbaColor(1.,0.,1.,1.)
        history = StaticLine([], history_color)
        self.line_history = LineAnimationStack(history)
        self.line_history.add(AppendableLineAnimator(self.line_history.tail, draw_decay_speed=25))
        # limit the drawn history to MAX_HISTORY points
        self.line_history.add(CropLine(self.line_history.tail, self.MAX_HISTORY))
        # TRACK_FADE_AFTER_DURATION is in seconds; it is converted to frames here
        self.line_history.add(FadedTailLine(self.line_history.tail, TRACK_FADE_AFTER_DURATION * TRACK_ASSUMED_FPS, TRACK_END_FADE))
        # amplitude starts at 0; to_renderable_lines() sets it from lost_factor()
        self.line_history.add(NoiseLine(self.line_history.tail, amplitude=0, t_factor=.3))
        # alpha is set from lost_factor() in to_renderable_lines()
        self.line_history.add(FadeOutJitterLine(self.line_history.tail, frequency=5, t_factor=.5))
        # the prediction currently drawn (taken from self.prediction_tracks)
        self.active_ptrack: Optional[ProjectedTrack] = None
        self.prediction_color = SrgbaColor(0,1,0,1)
        # Prediction stack: SegmentLine -> DashedLine (skipped until stage 2) -> FadeOutLine
        self.line_prediction = LineAnimationStack(StaticLine([], self.prediction_color))
        self.line_prediction.add(SegmentLine(self.line_prediction.tail, duration=.5))
        self.line_prediction.add(DashedLine(self.line_prediction.tail, t_factor=4, loop_offset=True))
        self.line_prediction.get(DashedLine).skip = True
        self.line_prediction.add(FadeOutLine(self.line_prediction.tail))
        # self.line_prediction_drawn = self.line_prediction_faded

    def update(self):
        """Advance one tick: check for loss, then sync both line stacks with the latest data.

        Prediction animation phases, driven by SegmentLine/DashedLine settings:
          * first prediction: make it active and start the stack;
          * a newer prediction arrived and the animation is ready: switch to it,
            animate with ``anim_arrive`` for .5s, dashes off, restart;
          * no newer prediction, ready, dashes still skipped: "stage 2",
            ``anim_grow`` for 1s with dashes on, restart;
          * no newer prediction, ready, dashes on: reverse ``anim_grow`` for 2s,
            restarting only the SegmentLine.
        """
        super().update()
        if self.track:
            self.line_history.root.points = self.track.projected_history
        if len(self.prediction_tracks):
            # TODO: only when animation is ready for it? or collect lines
            if not self.active_ptrack:
                self.active_ptrack = self.prediction_tracks[-1]
                self.line_prediction.start() # reset positions
            elif self.active_ptrack._track.updated_at < self.prediction_tracks[-1]._track.updated_at:
                # switch only if drawing animation is ready
                if self.line_prediction.is_ready():
                    self.active_ptrack = self.prediction_tracks[-1]
                    self.line_prediction.get(SegmentLine).anim_f = partial(SegmentLine.anim_arrive, length=.3)
                    self.line_prediction.get(SegmentLine).duration = .5
                    self.line_prediction.get(DashedLine).skip = True
                    # print('restart')
                    self.line_prediction.start() # reset positions
                    # print(self.line_prediction.get(SegmentLine).running_for())
            else:
                if self.line_prediction.is_ready():
                    # little hack: check if dashedline skips, to only run this once per animation:
                    if self.line_prediction.get(DashedLine).skip:
                        # no new yet, but ready with anim, start stage 2
                        self.line_prediction.get(SegmentLine).anim_f = partial(SegmentLine.anim_grow)
                        self.line_prediction.get(SegmentLine).duration = 1
                        # self.line_prediction.get(SegmentLine).skip = True
                        self.line_prediction.get(DashedLine).skip = False
                        self.line_prediction.start()
                    else:
                        # stage 3: shrink the segment back (reverse grow) over 2s
                        self.line_prediction.get(SegmentLine).anim_f = partial(SegmentLine.anim_grow, reverse=True)
                        self.line_prediction.get(SegmentLine).duration = 2
                        self.line_prediction.get(SegmentLine).start()
            # self.line_prediction_dashed.set_offset_t(self.active_ptrack._track.track_update_dt() * 4)
            # draw the first of the active track's predictions
            self.line_prediction.root.points = self.active_ptrack._track.predictions[0]
        if self.scene is ScenarioScene.LOITERING:
            # special case: PLAY
            # transition goes from 0 to 1 over 1.4 s after entering LOITERING
            transition = min(1, (time.time() - self.state_change_at)/1.4)
            # TODO: transition fade, using to_alpha(), so it can fade back in again:
            # NOTE(review): to_renderable_lines() calls set_alpha(1 - lost_factor()) on this
            # same FadeOutJitterLine every render, which presumably discards this value
            # before it is drawn — confirm the loitering fade is actually visible.
            self.line_history.get(FadeOutJitterLine).set_alpha(1 - transition)
            if transition > .999:
                # fetch lines nearby
                pass
        elif self.scene is ScenarioScene.PLAY:
            # special case: PLAY
            pass
        # if self.scene is ScenarioScene.CORRECTED_PREDICTION:
        #     self.line_prediction.get(DashedLine).skip = False

    def to_renderable_lines(self, dt: DeltaT) -> RenderableLines:
        """Return the history line and the prediction line for this frame.

        Both lines fade with lost_factor(). The history line also gets noise
        proportional to lost_factor(), and its tail fades based on the track age.
        """
        # each scene is handled differently:
        # 1) history, fade out when lost
        # NOTE(review): this overwrites the alpha set for LOITERING in update().
        self.line_history.get(FadeOutJitterLine).set_alpha(1-self.lost_factor())
        self.line_prediction.get(FadeOutLine).set_alpha(1-self.lost_factor())
        self.line_history.get(NoiseLine).amplitude = self.lost_factor()
        # fade out history after max duration, given in frames
        track_age_in_frames = self.track_age() * TRACK_ASSUMED_FPS
        self.line_history.get(FadedTailLine).set_frame_offset(track_age_in_frames)
        # 2) also fade-out when moving into loitering mode.
        # when fading out is done, start drawing historical data
        history_line = self.line_history.as_renderable_line(dt)
        prediction_line = self.line_prediction.as_renderable_line(dt)
        # print(history_line)
        # print(self.track_id, len(self.line_history.points), len(history_line))
        return RenderableLines([
            history_line,
            prediction_line
        ])
class Stage(Node):
    """Node that turns tracker tracks and predictions into line layers.

    Each loop tick it receives, without blocking, at most one prediction
    frame and one trajectory frame. It then updates every DrawnScenario,
    decides which scenarios occupy the ``max_active_scenarios`` slots
    (including priority-based take-overs), and publishes the resulting line
    layers as JSON on the stage socket.
    """
    FPS = 60  # loop rate passed to run_loop_capped_fps()

    def setup(self):
        """Create the scenario registry, the ZMQ sockets and the optional debug map."""
        self.active_scenarios: List[DrawnScenario] = [] # List of currently running Scenario instances
        # one DrawnScenario per track id, created lazily on first lookup
        self.scenarios: Dict[str, DrawnScenario] = DefaultDictKeyed(DrawnScenario)
        self.frame_noimg_sock = self.sub(self.config.zmq_frame_noimg_addr)
        self.trajectory_sock = self.sub(self.config.zmq_trajectory_addr)
        self.prediction_sock = self.sub(self.config.zmq_prediction_addr)
        self.stage_sock = self.pub(self.config.zmq_stage_addr)
        self.counter = CounterSender()
        # loop_render() always publishes debug_lines. Default to an empty set so
        # running without --debug-map does not raise AttributeError.
        self.debug_lines = RenderableLines([])
        if self.config.debug_map:
            debug_color = SrgbaColor(0.,0.,1.,1.)
            self.debug_lines = RenderableLines(load_lines_from_svg(self.config.debug_map, 100, debug_color))

    def run(self):
        """Main loop: receive, update and render, at most FPS times per second."""
        while self.run_loop_capped_fps(self.FPS):
            dt = max(1/ self.FPS, self.dt_since_last_tick) # never dt of 0
            self.loop_receive()
            self.loop_update_scenarios()
            self.loop_render(dt)

    def _recv_frame_nonblocking(self, sock) -> Optional[Frame]:
        """Return the next Frame waiting on `sock`, or None when nothing can be read."""
        # NOTE(review): recv_pyobj unpickles the payload; only connect these
        # sockets to trusted publishers.
        try:
            return sock.recv_pyobj(zmq.NOBLOCK)
        except zmq.Again:
            return None  # no message waiting
        except zmq.ZMQError as e:
            self.logger.warning(f"Failed to receive frame: {e}")
            return None

    def loop_receive(self):
        """Forward at most one prediction frame and one trajectory frame to the scenarios."""
        # 1) receive predictions
        prediction_frame = self._recv_frame_nonblocking(self.prediction_sock)
        if prediction_frame is not None:
            for track_id, track in prediction_frame.tracks.items():
                proj_track = ProjectedTrack(track, prediction_frame.camera)
                self.scenarios[track_id].recv_prediction(proj_track)
        # 2) receive tracker tracks
        trajectory_frame = self._recv_frame_nonblocking(self.trajectory_sock)
        if trajectory_frame is not None:
            for track_id, track in trajectory_frame.tracks.items():
                proj_track = ProjectedTrack(track, trajectory_frame.camera)
                self.scenarios[track_id].recv_track(proj_track)

    def _remove_active(self, scenario: DrawnScenario):
        """Remove `scenario` from active_scenarios, comparing by identity."""
        # An identity filter replaces filter(scenario.__ne__, ...). object.__ne__
        # returns NotImplemented for other objects; using that as a bool is
        # deprecated since Python 3.9 and raises TypeError from 3.14.
        if scenario in self.active_scenarios:
            self.active_scenarios = [s for s in self.active_scenarios if s is not scenario]

    def loop_update_scenarios(self):
        """Update active scenarios and handle pauses/completions."""
        # 1) process timestep for all scenarios
        for s in self.scenarios.values():
            s.update()

        # 2) Remove stale tracks and take-overs
        for track_id, scenario in list(self.scenarios.items()):
            if scenario.lost_factor() >= 1:
                self._remove_active(scenario)
                self.logger.info(f"rm lost track {track_id}")
                del self.scenarios[track_id]
                continue  # discarded; no take-over bookkeeping needed
            if scenario.takeover_factor() >= 1:
                self._remove_active(scenario)
                scenario.taken_over()

        # 3) determine set of pending scenarios (all except running)
        pending_scenarios = [s for s in self.scenarios.values() if s not in self.active_scenarios]
        # ... highest priority first
        pending_scenarios.sort(key=lambda s: s.scene.value.priority, reverse=True)

        # 4) check if there's a slot free:
        while len(self.active_scenarios) < self.config.max_active_scenarios and len(pending_scenarios):
            scenario = pending_scenarios.pop(0)
            self.active_scenarios.append(scenario)
            scenario.start()

        # 5) Takeover Logic: If no space, try to replace a lower-priority active scenario
        # which is in a scene in which takeover is possible
        eligible_active_scenarios = [
            s for s in self.active_scenarios if s.scene.value.takeover_possible
        ]
        eligible_active_scenarios.sort(key=lambda s: s.scene.value.priority)
        if eligible_active_scenarios and pending_scenarios:
            lowest_priority_active = eligible_active_scenarios[0]
            highest_priority_waiting = pending_scenarios[0]
            if highest_priority_waiting.scene.value.priority > lowest_priority_active.scene.value.priority:
                # Takeover! Stop the active scenario
                # will be cleaned up in update() loop after animation finishes
                # automatically triggering the start of the highest priority scene
                lowest_priority_active.take_over()

    def loop_render(self, dt: DeltaT):
        """Draw all active scenarios and publish them, with the debug map, as layers."""
        lines = RenderableLines([])
        for scenario in self.active_scenarios:
            lines.append_lines(scenario.to_renderable_lines(dt))
        # NOTE(review): the simplified `rl` only feeds the point counter; layer 1
        # still publishes the unsimplified `lines` — confirm this is intended.
        rl = lines.as_simplified(SimplifyMethod.RDP, .003) # or segmentise (see shapely)
        self.counter.set("stage.lines", len(lines.lines))
        self.counter.set("stage.points_orig", lines.point_count())
        self.counter.set("stage.points", rl.point_count())
        layers: RenderableLayers = {
            1: lines,
            2: self.debug_lines,
        }
        self.stage_sock.send_json(obj=layers, cls=DataclassJSONEncoder)

    @classmethod
    def arg_parser(cls) -> ArgumentParser:
        """Return the command-line parser for the Stage node's options."""
        argparser = ArgumentParser()
        argparser.add_argument('--zmq-frame-noimg-addr',
                               help='Manually specify communication addr for the frame messages',
                               type=str,
                               default="ipc:///tmp/feeds_frame2")
        argparser.add_argument('--zmq-trajectory-addr',
                               help='Manually specify communication addr for the trajectory messages',
                               type=str,
                               default="ipc:///tmp/feeds_traj")
        argparser.add_argument('--zmq-prediction-addr',
                               help='Manually specify communication addr for the prediction messages',
                               type=str,
                               default="ipc:///tmp/feeds_preds")
        # NOTE(review): 99174 exceeds the 16-bit TCP port range. libzmq may
        # truncate it rather than reject it. Consumers presumably use the same
        # string, so it is left unchanged — verify before changing it.
        argparser.add_argument('--zmq-stage-addr',
                               help='Manually specify communication addr for the stage messages (the rendered lines)',
                               type=str,
                               default="tcp://0.0.0.0:99174")
        argparser.add_argument('--debug-map',
                               help='specify a map (svg-file) from which to load lines which will be overlayed',
                               type=str,
                               default="../DATASETS/hof3/map_hof.svg")
        argparser.add_argument('--max-active-scenarios',
                               help='Maximum number of active scenarios that can be drawn at once (to not overload the laser)',
                               type=int,
                               default=2)
        return argparser