WIP diffing
This commit is contained in:
parent
d1703a7a86
commit
4415e2dcb6
1 changed files with 30 additions and 2 deletions
|
@ -16,7 +16,7 @@ import zmq
|
||||||
|
|
||||||
from sgan.sgan import data
|
from sgan.sgan import data
|
||||||
from trap import shapes
|
from trap import shapes
|
||||||
from trap.base import DataclassJSONEncoder, Frame, Track
|
from trap.base import Camera, DataclassJSONEncoder, DistortedCamera, Frame, Track
|
||||||
from trap.counter import CounterSender
|
from trap.counter import CounterSender
|
||||||
from trap.laser_renderer import circle_points, rotateMatrix
|
from trap.laser_renderer import circle_points, rotateMatrix
|
||||||
from trap.lines import RenderableLine, RenderableLines, RenderablePoint, SrgbaColor, circle_arc
|
from trap.lines import RenderableLine, RenderableLines, RenderablePoint, SrgbaColor, circle_arc
|
||||||
|
@ -25,6 +25,8 @@ from trap.timer import Timer
|
||||||
from trap.utils import exponentialDecay, relativePointToPolar, relativePolarToPoint
|
from trap.utils import exponentialDecay, relativePointToPolar, relativePolarToPoint
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger('trap.stage')
|
||||||
|
|
||||||
class ScenarioScene(Enum):
|
class ScenarioScene(Enum):
|
||||||
DETECTED = 1
|
DETECTED = 1
|
||||||
FIRST_PREDICTION = 2
|
FIRST_PREDICTION = 2
|
||||||
|
@ -34,7 +36,7 @@ class ScenarioScene(Enum):
|
||||||
LOST = -1
|
LOST = -1
|
||||||
|
|
||||||
LOST_FADEOUT = 3
|
LOST_FADEOUT = 3
|
||||||
PREDICTION_INTERVAL: float|None = 15 # frames
|
PREDICTION_INTERVAL: float|None = 20 # frames
|
||||||
PREDICTION_FADE_IN: float = 3
|
PREDICTION_FADE_IN: float = 3
|
||||||
PREDICTION_FADE_SLOPE: float = -10
|
PREDICTION_FADE_SLOPE: float = -10
|
||||||
PREDICTION_FADE_AFTER_DURATION: float = 10 # seconds
|
PREDICTION_FADE_AFTER_DURATION: float = 10 # seconds
|
||||||
|
@ -247,6 +249,7 @@ class DrawnScenario(TrackScenario):
|
||||||
|
|
||||||
# 3. predictions
|
# 3. predictions
|
||||||
self.drawn_predictions = []
|
self.drawn_predictions = []
|
||||||
|
self.drawn_diffs = []
|
||||||
for a, (ptrack, next_ptrack) in enumerate(zip(self._predictions, [*self._predictions[1:], None])):
|
for a, (ptrack, next_ptrack) in enumerate(zip(self._predictions, [*self._predictions[1:], None])):
|
||||||
|
|
||||||
prediction = ptrack.predictions[0] # only use one prediction per timestep/frame/track
|
prediction = ptrack.predictions[0] # only use one prediction per timestep/frame/track
|
||||||
|
@ -254,10 +257,29 @@ class DrawnScenario(TrackScenario):
|
||||||
# not the last one, cut off
|
# not the last one, cut off
|
||||||
next_ptrack: Track = self._predictions[a+1]
|
next_ptrack: Track = self._predictions[a+1]
|
||||||
end_step = next_ptrack.frame_index - ptrack.frame_index
|
end_step = next_ptrack.frame_index - ptrack.frame_index
|
||||||
|
|
||||||
|
# diff
|
||||||
|
diff_steps_back = ptrack.frame_index - self._track.frame_index
|
||||||
|
if len(self.drawn_positions) < -1*diff_steps_back:
|
||||||
|
logger.warning("Track history doesn't reach prediction start. Should not be possible. Skip")
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# trajectory_range = self.camera.[d.get_foot_coords() for d in trajectory_det_range] # in frame coordinate space
|
||||||
|
trajectory_range = self.drawn_positions[diff_steps_back:diff_steps_back+end_step]
|
||||||
|
prediction_range = ptrack.predictions[0][:end_step] # in world coordinate space
|
||||||
|
line = []
|
||||||
|
for p1, p2 in zip(trajectory_range[::4], prediction_range[::4]):
|
||||||
|
line.extend([
|
||||||
|
p1, p2
|
||||||
|
])
|
||||||
|
if len(line):
|
||||||
|
self.drawn_diffs.append(line)
|
||||||
else:
|
else:
|
||||||
end_step = None # not last item; show all
|
end_step = None # not last item; show all
|
||||||
self.drawn_predictions.append(ptrack.predictions[0][:end_step])
|
self.drawn_predictions.append(ptrack.predictions[0][:end_step])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -340,6 +362,11 @@ class DrawnScenario(TrackScenario):
|
||||||
# points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction[PREDICTION_OFFSET:], colors[PREDICTION_OFFSET:])]
|
# points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction[PREDICTION_OFFSET:], colors[PREDICTION_OFFSET:])]
|
||||||
points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction, colors)]
|
points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_prediction, colors)]
|
||||||
lines.append(RenderableLine(points))
|
lines.append(RenderableLine(points))
|
||||||
|
for drawn_diff in self.drawn_diffs:
|
||||||
|
color = SrgbaColor(0.,1,1.,1.-self.lost_factor())
|
||||||
|
colors = [color.as_faded(1) for a2 in range(len(drawn_diff))]
|
||||||
|
points = [RenderablePoint(pos, pos_color) for pos, pos_color in zip(drawn_diff, colors)]
|
||||||
|
lines.append(RenderableLine(points))
|
||||||
|
|
||||||
# # print(self.current_state)
|
# # print(self.current_state)
|
||||||
# if self.current_state is self.first_prediction or self.current_state is self.corrected_prediction:
|
# if self.current_state is self.first_prediction or self.current_state is self.corrected_prediction:
|
||||||
|
@ -460,6 +487,7 @@ class Stage(Node):
|
||||||
self.stage_sock = self.pub(self.config.zmq_stage_addr)
|
self.stage_sock = self.pub(self.config.zmq_stage_addr)
|
||||||
|
|
||||||
self.counter = CounterSender()
|
self.counter = CounterSender()
|
||||||
|
self.camera: Optional[DistortedCamera] = None
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
Loading…
Reference in a new issue