WIP diffing

This commit is contained in:
Ruben van de Ven 2025-05-12 20:54:39 +02:00
parent d1703a7a86
commit 4415e2dcb6

View file

@ -16,7 +16,7 @@ import zmq
from sgan.sgan import data from sgan.sgan import data
from trap import shapes from trap import shapes
from trap.base import DataclassJSONEncoder, Frame, Track from trap.base import Camera, DataclassJSONEncoder, DistortedCamera, Frame, Track
from trap.counter import CounterSender from trap.counter import CounterSender
from trap.laser_renderer import circle_points, rotateMatrix from trap.laser_renderer import circle_points, rotateMatrix
from trap.lines import RenderableLine, RenderableLines, RenderablePoint, SrgbaColor, circle_arc from trap.lines import RenderableLine, RenderableLines, RenderablePoint, SrgbaColor, circle_arc
@ -25,6 +25,8 @@ from trap.timer import Timer
from trap.utils import exponentialDecay, relativePointToPolar, relativePolarToPoint from trap.utils import exponentialDecay, relativePointToPolar, relativePolarToPoint
logger = logging.getLogger('trap.stage')
class ScenarioScene(Enum): class ScenarioScene(Enum):
DETECTED = 1 DETECTED = 1
FIRST_PREDICTION = 2 FIRST_PREDICTION = 2
@ -34,7 +36,7 @@ class ScenarioScene(Enum):
LOST = -1 LOST = -1
LOST_FADEOUT = 3 LOST_FADEOUT = 3
PREDICTION_INTERVAL: float|None = 15 # frames PREDICTION_INTERVAL: float|None = 20 # frames
PREDICTION_FADE_IN: float = 3 PREDICTION_FADE_IN: float = 3
PREDICTION_FADE_SLOPE: float = -10 PREDICTION_FADE_SLOPE: float = -10
PREDICTION_FADE_AFTER_DURATION: float = 10 # seconds PREDICTION_FADE_AFTER_DURATION: float = 10 # seconds
@ -247,6 +249,7 @@ class DrawnScenario(TrackScenario):
# 3. predictions # 3. predictions
self.drawn_predictions = [] self.drawn_predictions = []
self.drawn_diffs = []
for a, (ptrack, next_ptrack) in enumerate(zip(self._predictions, [*self._predictions[1:], None])): for a, (ptrack, next_ptrack) in enumerate(zip(self._predictions, [*self._predictions[1:], None])):
prediction = ptrack.predictions[0] # only use one prediction per timestep/frame/track prediction = ptrack.predictions[0] # only use one prediction per timestep/frame/track
@ -254,10 +257,29 @@ class DrawnScenario(TrackScenario):
# not the last one, cut off # not the last one, cut off
next_ptrack: Track = self._predictions[a+1] next_ptrack: Track = self._predictions[a+1]
end_step = next_ptrack.frame_index - ptrack.frame_index end_step = next_ptrack.frame_index - ptrack.frame_index
# diff
diff_steps_back = ptrack.frame_index - self._track.frame_index
if len(self.drawn_positions) < -1*diff_steps_back:
logger.warning("Track history doesn't reach prediction start. Should not be possible. Skip")
pass
else:
# trajectory_range = self.camera.[d.get_foot_coords() for d in trajectory_det_range] # in frame coordinate space
trajectory_range = self.drawn_positions[diff_steps_back:diff_steps_back+end_step]
prediction_range = ptrack.predictions[0][:end_step] # in world coordinate space
line = []
for p1, p2 in zip(trajectory_range[::4], prediction_range[::4]):
line.extend([
p1, p2
])
if len(line):
self.drawn_diffs.append(line)
else: else:
end_step = None # not last item; show all end_step = None # not last item; show all
self.drawn_predictions.append(ptrack.predictions[0][:end_step]) self.drawn_predictions.append(ptrack.predictions[0][:end_step])
@ -340,6 +362,11 @@ class DrawnScenario(TrackScenario):
# points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction[PREDICTION_OFFSET:], colors[PREDICTION_OFFSET:])] # points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction[PREDICTION_OFFSET:], colors[PREDICTION_OFFSET:])]
points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction, colors)] points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction, colors)]
lines.append(RenderableLine(points)) lines.append(RenderableLine(points))
for drawn_diff in self.drawn_diffs:
color = SrgbaColor(0.,1,1.,1.-self.lost_factor())
colors = [color.as_faded(1) for a2 in range(len(drawn_diff))]
points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_diff, colors)]
lines.append(RenderableLine(points))
# # print(self.current_state) # # print(self.current_state)
# if self.current_state is self.first_prediction or self.current_state is self.corrected_prediction: # if self.current_state is self.first_prediction or self.current_state is self.corrected_prediction:
@ -460,6 +487,7 @@ class Stage(Node):
self.stage_sock = self.pub(self.config.zmq_stage_addr) self.stage_sock = self.pub(self.config.zmq_stage_addr)
self.counter = CounterSender() self.counter = CounterSender()
self.camera: Optional[DistortedCamera] = None
def run(self): def run(self):