Refactoring for NN, scenarios, laser tests

This commit is contained in:
Ruben van de Ven 2025-04-03 21:02:39 +02:00
parent e3224aa47f
commit eee73d675a

125
trap/scenarios.py Normal file
View file

@ -0,0 +1,125 @@
from enum import Enum
import time
from typing import Optional
from statemachine import Event, State, StateMachine
from statemachine.exceptions import TransitionNotAllowed
from trap.base import Track
class ScenarioScene(Enum):
DETECTED = 1
FIRST_PREDICTION = 2
CORRECTED_PREDICTION = 3
LOITERING = 4
PLAY = 4
LOST = -1
class TrackScenario(StateMachine):
detected = State(initial=True)
substantial = State()
first_prediction = State()
corrected_prediction = State()
loitering = State()
play = State()
lost = State(final=True)
receive_track = lost.from_(
detected, first_prediction, corrected_prediction, loitering, play, substantial, cond="track_is_lost"
) | corrected_prediction.to(loitering, cond="track_is_loitering") | detected.to(substantial, cond="track_is_long")
receive_prediction = detected.to(first_prediction) | first_prediction.to(corrected_prediction, cond="prediction_is_stale") | corrected_prediction.to(play, cond="prediction_is_playing")
def __init__(self, track: Track):
self._track = track
self.first_prediction_track: Optional[Track] = None
self.prediction_track: Optional[Track] = None
super().__init__()
def track_is_long(self, track: Track):
return len(track.history) > 20
def track_is_lost(self, track: Track):
return track.lost
def track_is_loitering(self, track: Track):
# TODO)) Change to measure displacement over the last n seconds
return len(track.history) > (track.fps * 60) # seconds after which someone is loitering
def prediction_is_stale(self, track: Track):
# TODO use displacement instead of time
return bool(self.prediction_track and self.prediction_track.created_at < (time.perf_counter() - 2))
def prediction_is_playing(self, Track):
return False
# @property
# def track(self):
# return self._track
def set_track(self, track: Track):
self._track = track
try:
self.receive_track(track)
except TransitionNotAllowed as e:
# state change is optional
pass
def set_prediction(self, track: Track):
if not self.first_prediction_track:
self.first_prediction_track = track
self.prediction_track = track
try:
self.receive_prediction(track)
except TransitionNotAllowed as e:
# state change is optional
pass
def after_receive_track(self, track: Track):
print('change state')
def on_receive_track(self, track: Track):
# on event, because it happens for every receive, despite transition
print('updating track!')
# self.track = track
def on_receive_prediction(self, track: Track):
# on event, because it happens for every receive, despite transition
print('updating prediction!')
# self.track = track
def after_receive_prediction(self, track: Track):
# after
self.prediction_track = track
if not self.first_prediction_track:
self.first_prediction_track = track
def on_enter_corrected_prediction(self):
print('corrected!')
def on_enter_detected(self):
print("DETECTED!")
def on_enter_first_prediction(self):
print("Hello!")
def on_enter_detected(self):
print(f"enter {self.current_state.id}")
def on_enter_substantial(self):
print(f"enter {self.current_state.id}")
def on_enter_first_prediction(self):
print(f"enter {self.current_state.id}")
def on_enter_corrected_prediction(self):
print(f"enter {self.current_state.id}")
def on_enter_loitering(self):
print(f"enter {self.current_state.id}")
def on_enter_play(self):
print(f"enter {self.current_state.id}")
def on_enter_lost(self):
print(f"enter {self.current_state.id}")