WIP diffing

This commit is contained in:
Ruben van de Ven 2025-05-12 20:54:39 +02:00
parent d1703a7a86
commit 4415e2dcb6

View file

@ -16,7 +16,7 @@ import zmq
from sgan.sgan import data
from trap import shapes
from trap.base import DataclassJSONEncoder, Frame, Track
from trap.base import Camera, DataclassJSONEncoder, DistortedCamera, Frame, Track
from trap.counter import CounterSender
from trap.laser_renderer import circle_points, rotateMatrix
from trap.lines import RenderableLine, RenderableLines, RenderablePoint, SrgbaColor, circle_arc
@ -25,6 +25,8 @@ from trap.timer import Timer
from trap.utils import exponentialDecay, relativePointToPolar, relativePolarToPoint
logger = logging.getLogger('trap.stage')
class ScenarioScene(Enum):
DETECTED = 1
FIRST_PREDICTION = 2
@ -34,7 +36,7 @@ class ScenarioScene(Enum):
LOST = -1
LOST_FADEOUT = 3
PREDICTION_INTERVAL: float|None = 15 # frames
PREDICTION_INTERVAL: float|None = 20 # frames
PREDICTION_FADE_IN: float = 3
PREDICTION_FADE_SLOPE: float = -10
PREDICTION_FADE_AFTER_DURATION: float = 10 # seconds
@ -247,6 +249,7 @@ class DrawnScenario(TrackScenario):
# 3. predictions
self.drawn_predictions = []
self.drawn_diffs = []
for a, (ptrack, next_ptrack) in enumerate(zip(self._predictions, [*self._predictions[1:], None])):
prediction = ptrack.predictions[0] # only use one prediction per timestep/frame/track
@ -254,10 +257,29 @@ class DrawnScenario(TrackScenario):
# not the last one, cut off
next_ptrack: Track = self._predictions[a+1]
end_step = next_ptrack.frame_index - ptrack.frame_index
# diff
diff_steps_back = ptrack.frame_index - self._track.frame_index
if len(self.drawn_positions) < -1*diff_steps_back:
logger.warning("Track history doesn't reach prediction start. Should not be possible. Skip")
pass
else:
# trajectory_range = self.camera.[d.get_foot_coords() for d in trajectory_det_range] # in frame coordinate space
trajectory_range = self.drawn_positions[diff_steps_back:diff_steps_back+end_step]
prediction_range = ptrack.predictions[0][:end_step] # in world coordinate space
line = []
for p1, p2 in zip(trajectory_range[::4], prediction_range[::4]):
line.extend([
p1, p2
])
if len(line):
self.drawn_diffs.append(line)
else:
end_step = None # not last item; show all
self.drawn_predictions.append(ptrack.predictions[0][:end_step])
@ -340,6 +362,11 @@ class DrawnScenario(TrackScenario):
# points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction[PREDICTION_OFFSET:], colors[PREDICTION_OFFSET:])]
points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction, colors)]
lines.append(RenderableLine(points))
for drawn_diff in self.drawn_diffs:
color = SrgbaColor(0.,1,1.,1.-self.lost_factor())
colors = [color.as_faded(1) for a2 in range(len(drawn_diff))]
points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_diff, colors)]
lines.append(RenderableLine(points))
# # print(self.current_state)
# if self.current_state is self.first_prediction or self.current_state is self.corrected_prediction:
@ -460,6 +487,7 @@ class Stage(Node):
self.stage_sock = self.pub(self.config.zmq_stage_addr)
self.counter = CounterSender()
self.camera: Optional[DistortedCamera] = None
def run(self):