For dortmund
This commit is contained in:
parent
bd7f0cf2f3
commit
8dcf70959b
15 changed files with 368 additions and 137 deletions
|
|
@ -30,6 +30,8 @@ These are roughly the steps to go from datagathering to training
|
|||
* See [[tests/trajectron_maps.ipynb]] for more info how to do so (e.g. the homography map/scale settings, which are also set in process_data)
|
||||
|
||||
5. Train Trajectron model `uv run trajectron_train --eval_every 10 --vis_every 1 --train_data_dict NAME_train.pkl --eval_data_dict NAME_val.pkl --offline_scene_graph no --preprocess_workers 8 --log_dir EXPERIMENTS/models --log_tag _NAME --train_epochs 100 --conf EXPERIMENTS/config.json --batch_size 256 --data_dir EXPERIMENTS/trajectron-data `
|
||||
* For faster training disalble edges:
|
||||
` uv run trajectron_train --eval_every 200 --train_data_dict dortmund-nostep-nosmooth-noise2-offsets1-f2.0-map-2025-11-11_train.pkl --eval_data_dict dortmund-nostep-nosmooth-noise2-offsets1-f2.0-map-2025-11-11_val.pkl --offline_scene_graph no --preprocess_workers 8 --log_dir /home/ruben/suspicion/trap/SETTINGS/2025-11-dortmund/models --log_tag _dortmund-nostep-nosmooth-noise2-offsets1-f2.0-map-2025-11-11 --train_epochs 100 --conf /home/ruben/suspicion/trap/SETTINGS/2025-11-dortmund/trajectron.json --data_dir SETTINGS/2025-11-dortmund/trajectron --map_encoding --no_edge_encoding --dynamic_edges yes --no_edge_encoding --edge_influence_combine_method max --batch_size 512`
|
||||
6. The run!
|
||||
* `uv run supervisord`
|
||||
<!-- * On a video file (you can use a wildcard) `DISPLAY=:1 uv run trapserv --remote-log-addr 100.69.123.91 --eval_device cuda:0 --detector ultralytics --homography ../DATASETS/NAME/homography.json --eval_data_dict EXPERIMENTS/trajectron-data/hof2s-m_test.pkl --video-src ../DATASETS/NAME/*.mp4 --model_dir EXPERIMENTS/models/models_DATE_NAME/--smooth-predictions --smooth-tracks --num-samples 3 --render-window --calibration ../DATASETS/NAME/calibration.json` (the DISPLAY environment variable is used here to running over SSH connection and display on local monitor)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@
|
|||
"hidden_channels": [
|
||||
10,
|
||||
20,
|
||||
10,
|
||||
5,
|
||||
1
|
||||
],
|
||||
"output_size": 32,
|
||||
|
|
@ -41,7 +41,7 @@
|
|||
}
|
||||
},
|
||||
"k": 1,
|
||||
"k_eval": 25,
|
||||
"k_eval": 1,
|
||||
"kl_min": 0.07,
|
||||
"kl_weight": 100.0,
|
||||
"kl_weight_start": 0,
|
||||
|
|
@ -52,8 +52,8 @@
|
|||
"dropout_keep_prob": 0.75
|
||||
},
|
||||
"MLP_dropout_keep_prob": 0.9,
|
||||
"enc_rnn_dim_edge": 32,
|
||||
"enc_rnn_dim_edge_influence": 32,
|
||||
"enc_rnn_dim_edge": 1,
|
||||
"enc_rnn_dim_edge_influence": 1,
|
||||
"enc_rnn_dim_history": 32,
|
||||
"enc_rnn_dim_future": 32,
|
||||
"dec_rnn_dim": 128,
|
||||
|
|
@ -105,7 +105,7 @@
|
|||
"log_histograms": false,
|
||||
"dynamic_edges": "yes",
|
||||
"edge_state_combine_method": "sum",
|
||||
"edge_influence_combine_method": "attention",
|
||||
"edge_influence_combine_method": "max",
|
||||
"edge_addition_filter": [
|
||||
0.25,
|
||||
0.5,
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ directory=%(here)s
|
|||
|
||||
[program:predictor]
|
||||
# command=uv run trap_prediction --eval_device cuda:0 --model_dir EXPERIMENTS/models/models_20241229_21_35_13_hof3-m2-ud-split-conv12-f2.0-map-2024-12-29/ --num-samples 1 --map_encoding --eval_data_dict EXPERIMENTS/trajectron-data/hof3-m2-ud-split-nostep-conv12-f2.0-map-2024-12-29_val.pkl --prediction-horizon 120 --gmm-mode True --z-mode
|
||||
command=uv run trap_prediction --eval_device cuda:0 --model_dir EXPERIMENTS/models/models_20251107_14_27_55_hof-lidar-m2-ud-nostep-kalsmooth-noise2-offsets1-f2.0-map-2025-11-07/ --num-samples 1 --map_encoding --eval_ata_dict EXPERIMENTS/trajectron-data/hof-lidar-m2-ud-nostep-kalsmooth-noise2-offsets1-f2.0-map-2025-11-07_val.pkl --prediction-horizon 120 --gmm-mode True --z-mode
|
||||
command=uv run trap_prediction --eval_device cuda:0 --model_dir SETTINGS/2025-11-dortmund/models/models_20251111_19_06_29_dortmund-nostep-nosmooth-noise2-offsets1-f2.0-map-2025-11-11/ --num-samples 1 --map_encoding --eval_data_dict SETTINGS/2025-11-dortmund/trajectron/dortmund-nostep-nosmooth-noise2-offsets1-f2.0-map-2025-11-12_val.pkl --prediction-horizon 120 --gmm-mode True --z-mode --conf SETTINGS/2025-11-dortmund/trajectron.json
|
||||
# command=uv run trap_prediction --eval_device cuda:0 --model_dir EXPERIMENTS/models/models_20251106_11_51_00_hof-lidar-m2-ud-nostep-kalsmooth-noise2-offsets2-f2.0-map-2025-11-06/ --num-samples 1 --map_encoding --eval_data_dict EXPERIMENTS/trajectron-data/hof-lidar-m2-ud-nostep-kalsmooth-noise2-offsets2-f2.0-map-2025-11-06_val.pkl --prediction-horizon 120 --gmm-mode True --z-mode
|
||||
# uv run trajectron_train --continue_training_from EXPERIMENTS/models/models_20241229_21_35_13_hof3-m2-ud-split-conv12-f2.0-map-2024-12-29/ --eval_every 5 --train_data_dict hof3-nostep-conv12-f2.0-map-2024-12-27_train.pkl --eval_data_dict hof3-nostep-conv12-f2.0-map-2024-12-27_val.pkl --offline_scene_graph no --preprocess_workers 8 --log_dir EXPERIMENTS/models --log_tag _hof3-conv12-f2.0-map-2024-12-27 --train_epochs 10 --conf EXPERIMENTS/config.json --data_dir EXPERIMENTS/trajectron-data --map_encoding
|
||||
directory=%(here)s
|
||||
|
|
@ -94,3 +94,4 @@ stopwaitsecs=60
|
|||
[program:superfsmon]
|
||||
command=superfsmon trap/stage.py stage
|
||||
directory=%(here)s
|
||||
autostart=false
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -74,6 +74,7 @@ class DetectionState(IntFlag):
|
|||
Confirmed = 2 # after tentative
|
||||
Lost = 4 # lost when DeepsortTrack.time_since_update > 0 but not Deleted
|
||||
Interpolated = 8 # A position estimated through interpolation of adjecent detections
|
||||
# Interpolated = 8 # A position estimated through interpolation of adjecent detections
|
||||
|
||||
@classmethod
|
||||
def from_deepsort_track(cls, track: DeepsortTrack):
|
||||
|
|
@ -89,11 +90,13 @@ class DetectionState(IntFlag):
|
|||
def from_bytetrack_track(cls, track: ByteTrackTrack):
|
||||
if track.state == ByteTrackTrackState.New:
|
||||
return cls.Tentative
|
||||
if track.state == ByteTrackTrackState.Lost:
|
||||
if track.state == ByteTrackTrackState.Removed:
|
||||
return cls.Lost
|
||||
# if track.time_since_update > 0:
|
||||
if track.state == ByteTrackTrackState.Tracked:
|
||||
return cls.Confirmed
|
||||
if track.state == ByteTrackTrackState.Lost:
|
||||
return cls.Tentative
|
||||
raise RuntimeError("Should not run into Deleted entries here")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import pyglet
|
|||
import zmq
|
||||
from pyglet import shapes
|
||||
|
||||
from trap.base import Detection
|
||||
from trap.base import Detection, UndistortedCamera
|
||||
from trap.counter import CounterListerner
|
||||
from trap.frame_emitter import Frame, Track
|
||||
from trap.lines import load_lines_from_svg
|
||||
|
|
@ -151,10 +151,12 @@ class CvRenderer(Node):
|
|||
# logger.debug(f'new video frame {frame.index}')
|
||||
|
||||
|
||||
if self.frame is None:
|
||||
if self.frame is None and i < 100:
|
||||
# might need to wait a few iterations before first frame comes available
|
||||
time.sleep(.1)
|
||||
continue
|
||||
else:
|
||||
self.frame = Frame(i, np.zeros((1920,1080,3)), camera=UndistortedCamera(12))
|
||||
|
||||
try:
|
||||
prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
|
|
|
|||
|
|
@ -208,6 +208,7 @@ class Lidar(Node):
|
|||
|
||||
self.track_sock = self.pub(self.config.zmq_trajectory_addr)
|
||||
self.detection_sock = self.pub(self.config.zmq_detection_addr)
|
||||
self.lost_track_sock = self.pub(self.config.zmq_lost_addr)
|
||||
|
||||
calibration = vd.Calibration.read( "VLP-16.yaml")
|
||||
config = vd.Config(model=vd.Model.VLP16, calibration=calibration)
|
||||
|
|
@ -286,12 +287,12 @@ class Lidar(Node):
|
|||
self.map_outline_volume.bounding_polygon = o3d.utility.Vector3dVector(polygon_points)
|
||||
|
||||
|
||||
if self.config.smooth_tracks:
|
||||
# TODO)) make configurable
|
||||
logger.info(f"Smoother enabled, assuming {ASSUMED_FPS} fps")
|
||||
self.smoother = Smoother(window_len=int(ASSUMED_FPS*.6), convolution=True)
|
||||
else:
|
||||
logger.info("Smoother Disabled (enable with --smooth-tracks)")
|
||||
# if self.config.smooth_tracks:
|
||||
# # TODO)) make configurable
|
||||
# logger.info(f"Smoother enabled, assuming {ASSUMED_FPS} fps")
|
||||
# self.smoother = Smoother(window_len=int(ASSUMED_FPS*.6), convolution=True)
|
||||
# else:
|
||||
# logger.info("Smoother Disabled (enable with --smooth-tracks)")
|
||||
|
||||
self.remotes = {}
|
||||
|
||||
|
|
@ -487,7 +488,7 @@ class Lidar(Node):
|
|||
filtered_pcd = denoised_pcd
|
||||
|
||||
# down sample
|
||||
filtered_pcd = filtered_pcd.voxel_down_sample(voxel_size=0.04)
|
||||
filtered_pcd = filtered_pcd.voxel_down_sample(voxel_size=0.06)
|
||||
stat_downsampled = len(filtered_pcd.points)
|
||||
timers.append(('downsample', time.perf_counter()))
|
||||
|
||||
|
|
@ -527,7 +528,8 @@ class Lidar(Node):
|
|||
self.logger.debug(f"online tracks: {[t[4] for t in online_tracks]}")
|
||||
removed_tracks = self.tracker.removed_stracks
|
||||
# active_stracks = [track for track in self.tracker.tracked_stracks if track.is_activated]
|
||||
active_stracks = [track for track in self.tracker.tracked_stracks if track.is_activated]
|
||||
# we want both the lost, and currently visible, as otherwise, the predictor goes haywire
|
||||
active_stracks = [track for track in (self.tracker.tracked_stracks + self.tracker.lost_stracks) if track.is_activated]
|
||||
detections = [Detection.from_bytetrack(track, frame_idx) for track in active_stracks]
|
||||
|
||||
counter.set('detections', len(detections))
|
||||
|
|
@ -549,6 +551,7 @@ class Lidar(Node):
|
|||
for t in removed_tracks:
|
||||
if t.track_id in self.tracks:
|
||||
if t.is_activated:
|
||||
self.lost_track_sock.send_pyobj(self.tracks[t.track_id])
|
||||
self.logger.info(f"Lost track: {t.track_id}")
|
||||
del self.tracks[t.track_id]
|
||||
# TODO: fix this oddity:
|
||||
|
|
@ -567,8 +570,8 @@ class Lidar(Node):
|
|||
frame = Frame(frame_idx, None, time.time(), active_tracks,
|
||||
camera.H, camera)
|
||||
|
||||
if self.config.smooth_tracks:
|
||||
frame = self.smoother.smooth_frame_tracks(frame)
|
||||
# if self.config.smooth_tracks:
|
||||
# frame = self.smoother.smooth_frame_tracks(frame)
|
||||
|
||||
timers.append(('smooth', time.perf_counter()))
|
||||
|
||||
|
|
@ -669,6 +672,10 @@ class Lidar(Node):
|
|||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_traj")
|
||||
argparser.add_argument('--zmq-lost-addr',
|
||||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_lost")
|
||||
argparser.add_argument('--zmq-detection-addr',
|
||||
help='Manually specity communication addr for the detection messages',
|
||||
type=str,
|
||||
|
|
@ -698,9 +705,9 @@ class Lidar(Node):
|
|||
help='Maximum area (m2) of bounding boxes to consider them for tracking',
|
||||
type=float,
|
||||
default=2)
|
||||
argparser.add_argument("--smooth-tracks",
|
||||
help="Smooth the tracker tracks before sending them to the predictor",
|
||||
action='store_true')
|
||||
# argparser.add_argument("--smooth-tracks",
|
||||
# help="Smooth the tracker tracks before sending them to the predictor",
|
||||
# action='store_true')
|
||||
argparser.add_argument("--viz",
|
||||
help="Render pointclouds in open3d",
|
||||
action='store_true')
|
||||
|
|
@ -902,7 +909,7 @@ def get_cluster_boxes(clusters, min_area = 0, max_area=5):
|
|||
|
||||
area = (x_max-x_min) * (y_max - y_min)
|
||||
if area < min_area or area > max_area:
|
||||
logger.warning(f"Dropping box {area} ")
|
||||
logger.debug(f"Dropping box {area} ")
|
||||
continue
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1198,6 +1198,7 @@ class RotatingLine(LineAnimator):
|
|||
# find closest point to start from:
|
||||
|
||||
origin = target_line.points[0]
|
||||
# print(origin, target_line.points[-1])
|
||||
# closest_idx = StartFromClosestPoint.find_closest_to_point(origin.position, [p.position for p in self.drawn_points])
|
||||
# if closest_idx:
|
||||
# print('cut at', closest_idx)
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ class Node():
|
|||
def push(self, addr: str):
|
||||
"push-pull pair"
|
||||
sock = self.zmq_context.socket(zmq.PUSH)
|
||||
# sock.setsockopt(zmq.LINGER, 0)
|
||||
sock.connect(addr)
|
||||
return sock
|
||||
|
||||
|
|
@ -210,5 +211,7 @@ def setup_logging(config: Namespace):
|
|||
|
||||
logging.basicConfig(
|
||||
level=loglevel,
|
||||
handlers=log_handlers # [queue_handler]
|
||||
handlers=log_handlers, # [queue_handler]
|
||||
format="%(asctime)s %(levelname)s:%(name)s:%(message)s",
|
||||
datefmt="%H:%M:%S"
|
||||
)
|
||||
|
|
@ -13,6 +13,7 @@ from multiprocessing import Event
|
|||
|
||||
import dill
|
||||
import numpy as np
|
||||
import shapely
|
||||
import torch
|
||||
import zmq
|
||||
from trajectron.environment import Environment, Scene, GeometricMap
|
||||
|
|
@ -21,6 +22,7 @@ from trajectron.model.online.online_trajectron import OnlineTrajectron
|
|||
from trajectron.utils import prediction_output_to_trajectories
|
||||
|
||||
from trap.frame_emitter import DataclassJSONEncoder, Frame
|
||||
from trap.lines import load_lines_from_svg
|
||||
from trap.node import Node
|
||||
from trap.tracker import Smoother
|
||||
from trap.utils import ImageMap
|
||||
|
|
@ -52,14 +54,16 @@ def create_online_env(env, hyperparams, scene_idx, init_timestep):
|
|||
init_timestep + 1),
|
||||
state=hyperparams['state'])
|
||||
online_scene.robot = test_scene.robot
|
||||
online_scene.calculate_scene_graph(attention_radius=env.attention_radius,
|
||||
radius = {k: 0 for k,v in env.attention_radius.items()}
|
||||
|
||||
online_scene.calculate_scene_graph(attention_radius=radius,
|
||||
edge_addition_filter=hyperparams['edge_addition_filter'],
|
||||
edge_removal_filter=hyperparams['edge_removal_filter'])
|
||||
|
||||
return Environment(node_type_list=env.node_type_list,
|
||||
standardization=env.standardization,
|
||||
scenes=[online_scene],
|
||||
attention_radius=env.attention_radius,
|
||||
attention_radius=radius,
|
||||
robot_type=env.robot_type)
|
||||
|
||||
|
||||
|
|
@ -162,6 +166,15 @@ class PredictionServer(Node):
|
|||
self.prediction_socket = self.pub(self.config.zmq_prediction_addr)
|
||||
self.external_predictions = not self.config.zmq_prediction_addr.startswith("ipc://")
|
||||
|
||||
self.cutoff_shape = None
|
||||
if self.config.cutoff_map:
|
||||
|
||||
self.cutoff_line = load_lines_from_svg(self.config.cutoff_map, 100, '')[0]
|
||||
self.cutoff_shape = shapely.Polygon([p.position for p in self.cutoff_line.points])
|
||||
|
||||
logger.info(f"{self.cutoff_shape}")
|
||||
|
||||
|
||||
|
||||
def send_frame(self, frame: Frame):
|
||||
if self.external_predictions:
|
||||
|
|
@ -184,7 +197,8 @@ class PredictionServer(Node):
|
|||
# model_dir = 'models/models_04_Oct_2023_21_04_48_eth_vel_ar3'
|
||||
|
||||
# Load hyperparameters from json
|
||||
config_file = os.path.join(self.config.model_dir, self.config.conf)
|
||||
# config_file = os.path.join(self.config.model_dir, self.config.conf)
|
||||
config_file = self.config.conf
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError('Config json not found!')
|
||||
with open(config_file, 'r') as conf_json:
|
||||
|
|
@ -224,6 +238,9 @@ class PredictionServer(Node):
|
|||
logger.info(f"Basing online env on {eval_scene=} -- loaded from {self.config.eval_data_dict}")
|
||||
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
|
||||
|
||||
print("overriding attention radius")
|
||||
online_env.attention_radius = {(online_env.NodeType.PEDESTRIAN, online_env.NodeType.PEDESTRIAN): 0.1}
|
||||
|
||||
# auto-find highest iteration
|
||||
model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device)
|
||||
model_iterations = pathlib.Path(self.config.model_dir).glob('model_registrar-*.pt')
|
||||
|
|
@ -297,6 +314,7 @@ class PredictionServer(Node):
|
|||
|
||||
input_dict = {}
|
||||
for identifier, track in frame.tracks.items():
|
||||
|
||||
# if len(trajectory['history']) < 7:
|
||||
# # TODO: these trajectories should still be in the output, but without predictions
|
||||
# continue
|
||||
|
|
@ -313,7 +331,16 @@ class PredictionServer(Node):
|
|||
if len(track.history) < 2:
|
||||
continue
|
||||
|
||||
|
||||
|
||||
node = track.to_trajectron_node(frame.camera, online_env)
|
||||
|
||||
if self.cutoff_shape:
|
||||
position = shapely.Point(node.data.data[-1][:2])
|
||||
if not shapely.contains(self.cutoff_shape, position):
|
||||
# logger.debug(f"Skip position {position}")
|
||||
continue
|
||||
|
||||
# print(node.data.data[-1])
|
||||
input_dict[node] = np.array(object=node.data.data[-1])
|
||||
# print("history", node.data.data[-10:])
|
||||
|
|
@ -352,6 +379,7 @@ class PredictionServer(Node):
|
|||
# )
|
||||
|
||||
# input_dict[node] = np.array(object=[x[-1],y[-1],vx[-1],vy[-1],ax[-1],ay[-1]])
|
||||
# break # only on
|
||||
|
||||
# print(input_dict)
|
||||
|
||||
|
|
@ -368,9 +396,11 @@ class PredictionServer(Node):
|
|||
continue
|
||||
|
||||
maps = None
|
||||
start_maps = time.time()
|
||||
if hyperparams['use_map_encoding']:
|
||||
maps = get_maps_for_input(input_dict, eval_scene, hyperparams, device=self.config.eval_device)
|
||||
|
||||
|
||||
# print(maps)
|
||||
|
||||
# robot_present_and_future = None
|
||||
|
|
@ -398,7 +428,8 @@ class PredictionServer(Node):
|
|||
gmm_mode=self.config.gmm_mode, # "If True: The mode of the Gaussian Mixture Model (GMM) is sampled (see trajectron.model.mgcvae.py)"
|
||||
z_mode=self.config.z_mode # "Predictions from the model’s most-likely high-level latent behavior mode" (see trajecton.models.components.discrete_latent:sample_p(most_likely_z=z_mode))
|
||||
)
|
||||
|
||||
print(len(dists), len (preds))
|
||||
intermediate = time.time()
|
||||
# unsure what this bit from online_prediction.py does:
|
||||
# detailed_preds_dict = dict()
|
||||
# for node in eval_scene.nodes:
|
||||
|
|
@ -418,8 +449,8 @@ class PredictionServer(Node):
|
|||
|
||||
|
||||
end = time.time()
|
||||
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges -- init: %.2f s" % (end - start,
|
||||
1. / (end - start), len(trajectron.nodes),
|
||||
logger.debug("took %.2f s (= %.2f Hz), maps: %.2f, forward: %.2f w/ %d nodes and %d edges -- init: %.2f s" % (end - start,
|
||||
1. / (end - start), (start-start_maps)/(end - start), (intermediate-start)/(end - start), len(trajectron.nodes),
|
||||
trajectron.scene_graph.get_num_edges(), start-t_init))
|
||||
|
||||
# if self.config.center_data:
|
||||
|
|
@ -441,7 +472,7 @@ class PredictionServer(Node):
|
|||
futures_dict = futures_dict[ts_key]
|
||||
|
||||
response = {}
|
||||
logger.debug(f"{histories_dict=}")
|
||||
# logger.debug(f"{histories_dict=}")
|
||||
for node in histories_dict:
|
||||
history = histories_dict[node]
|
||||
# future = futures_dict[node] # ground truth dict
|
||||
|
|
@ -509,9 +540,9 @@ class PredictionServer(Node):
|
|||
# default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
|
||||
|
||||
inference_parser.add_argument("--conf",
|
||||
help="path to json config file for hyperparameters, relative to model_dir",
|
||||
type=str,
|
||||
default='config.json')
|
||||
help="path to json config file for hyperparameters",
|
||||
type=pathlib.Path,
|
||||
default='EXPERIMENTS/config.json')
|
||||
|
||||
# Model Parameters (hyperparameters)
|
||||
inference_parser.add_argument("--offline_scene_graph",
|
||||
|
|
@ -566,12 +597,12 @@ class PredictionServer(Node):
|
|||
inference_parser.add_argument('--batch_size',
|
||||
help='training batch size',
|
||||
type=int,
|
||||
default=256)
|
||||
default=512)
|
||||
|
||||
inference_parser.add_argument('--k_eval',
|
||||
help='how many samples to take during evaluation',
|
||||
type=int,
|
||||
default=25)
|
||||
default=1)
|
||||
|
||||
# Data Parameters
|
||||
inference_parser.add_argument("--eval_data_dict",
|
||||
|
|
@ -593,7 +624,7 @@ class PredictionServer(Node):
|
|||
inference_parser.add_argument("--eval_device",
|
||||
help="what device to use during inference",
|
||||
type=str,
|
||||
default="cpu")
|
||||
default="cuda:0")
|
||||
|
||||
|
||||
inference_parser.add_argument('--seed',
|
||||
|
|
@ -634,6 +665,11 @@ class PredictionServer(Node):
|
|||
help="Center data around cx and cy. Should also be used when processing data",
|
||||
action='store_true')
|
||||
|
||||
inference_parser.add_argument('--cutoff-map',
|
||||
help='specify a map (svg-file) that specifies projection boundaries. In here, degrade chance to be selectede',
|
||||
type=str,
|
||||
default="../DATASETS/hof-lidar/map_hof.svg")
|
||||
|
||||
|
||||
return inference_parser
|
||||
|
||||
|
|
|
|||
|
|
@ -107,6 +107,8 @@ def process_data(src_dir: Path, dst_dir: Path, name: str, smooth_tracks: bool, n
|
|||
if not map_img_path.exists():
|
||||
raise RuntimeError(f"Map image does not exists {map_img_path}")
|
||||
|
||||
print(f"Using map {map_img_path}")
|
||||
|
||||
type_map = {}
|
||||
# TODO)) For now, assume the map is a 100x scale of the world coordinates (i.e. 100px per meter)
|
||||
# thus when we do a homography of 5px per meter, scale down by 20
|
||||
|
|
@ -358,7 +360,8 @@ def process_data(src_dir: Path, dst_dir: Path, name: str, smooth_tracks: bool, n
|
|||
--train_epochs 100 \\
|
||||
--conf {target_config} \\
|
||||
--data_dir {dst_dir} \\
|
||||
{"--map_encoding" if map_img_path else ""}
|
||||
{"--map_encoding" if map_img_path else ""} \\
|
||||
--no_edge_encoding
|
||||
""")
|
||||
|
||||
return names
|
||||
|
|
|
|||
|
|
@ -75,9 +75,9 @@ class SceneInfo:
|
|||
|
||||
class ScenarioScene(Enum):
|
||||
DETECTED = SceneInfo(4, "First detection")
|
||||
SUBSTANTIAL = SceneInfo(6, "Multiple detections")
|
||||
FIRST_PREDICTION = SceneInfo(10, "Prediction is ready")
|
||||
CORRECTED_PREDICTION = SceneInfo(11, "Multiple predictions")
|
||||
TRACKED = SceneInfo(6, "Multiple detections")
|
||||
PREDICTION_AVAILABLE = SceneInfo(10, "Prediction is ready")
|
||||
UPDATED_PREDICTION = SceneInfo(11, "Multiple predictions")
|
||||
LOITERING = SceneInfo(7, "Foundto be loitering", takeover_possible=True, takeover_possible_after=10) # TODO: create "possible after"
|
||||
PLAY = SceneInfo(7, description="After many predictions; just fooling around", takeover_possible=True, takeover_possible_after=10)
|
||||
LOST = SceneInfo(-1, description="Track lost", takeover_possible=True, takeover_possible_after=0)
|
||||
|
|
@ -166,6 +166,7 @@ class Scenario(PrioritySlotItem):
|
|||
def get_priority(self) -> int:
|
||||
# newer higher prio
|
||||
distance = 0
|
||||
# todo: check if last point is within bounds
|
||||
if self.track and len(self.track.projected_history) > 5:
|
||||
distance = np.linalg.norm(self.track.projected_history[-1] - self.track.projected_history[0])
|
||||
return (self.scene.value.priority, distance)
|
||||
|
|
@ -261,17 +262,17 @@ class Scenario(PrioritySlotItem):
|
|||
def check_track(self):
|
||||
predictions = len(self.prediction_tracks)
|
||||
if predictions and self.running_for() < 20:
|
||||
self.set_scene(ScenarioScene.FIRST_PREDICTION)
|
||||
self.set_scene(ScenarioScene.PREDICTION_AVAILABLE)
|
||||
return True
|
||||
if predictions and self.running_for() > 120:
|
||||
if predictions and self.running_for() > 60 * 5:
|
||||
self.set_scene(ScenarioScene.PLAY)
|
||||
return True
|
||||
if predictions:
|
||||
self.set_scene(ScenarioScene.CORRECTED_PREDICTION)
|
||||
self.set_scene(ScenarioScene.UPDATED_PREDICTION)
|
||||
return True
|
||||
if self.track:
|
||||
if len(self.track.projected_history) > TRACK_ASSUMED_FPS * 3:
|
||||
self.set_scene(ScenarioScene.SUBSTANTIAL)
|
||||
if len(self.track.projected_history) > TRACK_ASSUMED_FPS * 2:
|
||||
self.set_scene(ScenarioScene.TRACKED)
|
||||
else:
|
||||
self.set_scene(ScenarioScene.DETECTED)
|
||||
return True
|
||||
|
|
@ -360,7 +361,7 @@ class DrawnScenario(Scenario):
|
|||
self.line_history = LineAnimationStack(history)
|
||||
self.line_history.add(AppendableLineAnimator(self.line_history.tail, draw_decay_speed=120, transition_in_on_init=False))
|
||||
self.line_history.add(CropLine(self.line_history.tail, self.MAX_HISTORY))
|
||||
self.line_history.add(SimplifyLine(self.line_history.tail, 0.003)) # Simplify before effects, so they don't distort
|
||||
self.line_history.add(SimplifyLine(self.line_history.tail, 0.002)) # Simplify before effects, so they don't distort
|
||||
self.line_history.add(FadedTailLine(self.line_history.tail, TRACK_FADE_AFTER_DURATION * TRACK_ASSUMED_FPS, TRACK_END_FADE))
|
||||
self.line_history.add(NoiseLine(self.line_history.tail, amplitude=0, t_factor=.3))
|
||||
self.line_history.add(FadeOutJitterLine(self.line_history.tail, frequency=5, t_factor=.5))
|
||||
|
|
@ -373,9 +374,9 @@ class DrawnScenario(Scenario):
|
|||
self.line_prediction.get(StartFromClosestPoint).skip=True
|
||||
self.line_prediction.add(RotatingLine(self.line_prediction.tail, decay_speed=16))
|
||||
self.line_prediction.get(RotatingLine).skip = False
|
||||
self.line_prediction.add(SegmentLine(self.line_prediction.tail, duration=7, anim_f=SegmentLine.anim_follow_in_front))
|
||||
self.line_prediction.add(SegmentLine(self.line_prediction.tail, duration=7 / 3, anim_f=SegmentLine.anim_follow_in_front))
|
||||
self.line_prediction.get(SegmentLine).skip = False
|
||||
self.line_prediction.add(SimplifyLine(self.line_prediction.tail, 0.003)) # Simplify before effects, so they don't distort
|
||||
self.line_prediction.add(SimplifyLine(self.line_prediction.tail, 0.002)) # Simplify before effects, so they don't distort
|
||||
GAP_DURATION = 5
|
||||
def dash_len(dt, t):
|
||||
t=min(1, t/GAP_DURATION)
|
||||
|
|
@ -573,7 +574,10 @@ class DrawnScenario(Scenario):
|
|||
original = self.scene.name
|
||||
changed = super().set_scene(scene)
|
||||
if changed:
|
||||
self.stage.log_sock.send_string(f"Change {self.track_id}: {original} -> {self.scene.name}")
|
||||
try:
|
||||
self.stage.log_sock.send_string(f"Visitor {self.track_id}: {original} -> {self.scene.name}", zmq.NOBLOCK)
|
||||
except Exception as e:
|
||||
logger.warning("Not sent the scene change message, broken socket?")
|
||||
return changed
|
||||
|
||||
class NoTracksScenario(PrioritySlotItem):
|
||||
|
|
@ -622,12 +626,23 @@ class DebugDrawer():
|
|||
def __init__(self, stage: Stage):
|
||||
self.stage = stage
|
||||
|
||||
def to_renderable_lines(self, dt: DeltaT):
|
||||
def positions_to_renderable_lines(self, dt: DeltaT):
|
||||
lines = RenderableLines([], CoordinateSpace.WORLD)
|
||||
past_color = SrgbaColor(1,0,1,1)
|
||||
future_color = SrgbaColor(0,1,0,1)
|
||||
current_color = SrgbaColor(1,0,0,.6)
|
||||
for scenario in self.stage.scenarios.values():
|
||||
lines.append(StaticLine(scenario.track.projected_history, past_color).as_renderable_line(dt))
|
||||
# lines.append(StaticLine(scenario.track.projected_history, past_color).as_renderable_line(dt).as_simplified(factor=.005))
|
||||
center = scenario.track.projected_history[-1]
|
||||
|
||||
lines.append(StaticLine([[center[0], center[1]-.2], [center[0], center[1]+.2]], current_color).as_renderable_line(dt))
|
||||
lines.append(StaticLine([[center[0]-.2, center[1]], [center[0]+.2, center[1]]], current_color).as_renderable_line(dt))
|
||||
return lines
|
||||
|
||||
def predictions_to_renderable_lines(self, dt: DeltaT):
|
||||
lines = RenderableLines([], CoordinateSpace.WORLD)
|
||||
future_color = SrgbaColor(0,1,0,.6)
|
||||
for scenario in self.stage.scenarios.values():
|
||||
# lines.append(StaticLine(scenario.track.projected_history, past_color).as_renderable_line(dt).as_simplified(factor=.005))
|
||||
if scenario.active_ptrack:
|
||||
lines.append(StaticLine(scenario.active_ptrack._track.predictions[0], future_color).as_renderable_line(dt))
|
||||
return lines
|
||||
|
|
@ -639,7 +654,8 @@ class DatasetDrawer():
|
|||
|
||||
line_color = SrgbaColor(0,1,1,1)
|
||||
self.track_line = LineAnimationStack(StaticLine([], line_color))
|
||||
self.track_line.add(SimplifyLine(self.track_line.tail, 0.004)) # Simplify before cropping, to get less noodling
|
||||
# self.track_line.add(SimplifyLine(self.track_line.tail, 0.004)) # Simplify before cropping, to get less noodling
|
||||
self.track_line.add(SimplifyLine(self.track_line.tail, 0.002)) # no laser in dortmund
|
||||
self.track_line.add(CropAnimationLine(self.track_line.tail, 50, assume_fps=TRACK_ASSUMED_FPS*20)) # speed up
|
||||
|
||||
# self.track_line.add(DashedLine(self.track_line.tail, t_factor=4, loop_offset=True))
|
||||
|
|
@ -653,11 +669,14 @@ class DatasetDrawer():
|
|||
def to_renderable_lines(self, dt: DeltaT):
|
||||
lines = RenderableLines([], CoordinateSpace.WORLD)
|
||||
if not self.track_line.is_running():
|
||||
# print('update')
|
||||
track_id = random.choice(list(self.stage.history.state.tracks.keys()))
|
||||
# print('track_id', track_id)
|
||||
positions = self.stage.history.state.track_histories[track_id]
|
||||
self.track_line.root.points = positions
|
||||
self.track_line.start()
|
||||
# else:
|
||||
# print('-')
|
||||
|
||||
lines.lines.append(
|
||||
self.track_line.as_renderable_line(dt)
|
||||
|
|
@ -795,9 +814,11 @@ class Stage(Node):
|
|||
# TODO: sometimes very slow!
|
||||
t1 = time.perf_counter()
|
||||
training_lines = self.auxilary.to_renderable_lines(dt)
|
||||
all_active_tracks = self.debug_drawer.to_renderable_lines(dt)
|
||||
|
||||
t2 = time.perf_counter()
|
||||
active_positions = self.debug_drawer.positions_to_renderable_lines(dt)
|
||||
all_predictions = self.debug_drawer.predictions_to_renderable_lines(dt)
|
||||
|
||||
t2b = time.perf_counter()
|
||||
|
||||
timings = []
|
||||
for scenario in self.active_scenarios:
|
||||
|
|
@ -807,7 +828,7 @@ class Stage(Node):
|
|||
if not len(self.active_scenarios):
|
||||
lines = training_lines
|
||||
|
||||
t2b = time.perf_counter()
|
||||
t2c = time.perf_counter()
|
||||
# rl_scenario = lines.as_simplified(SimplifyMethod.RDP, .003) # or segmentise (see shapely)
|
||||
# rl_training = training_lines.as_simplified(SimplifyMethod.RDP, .003) # or segmentise (see shapely)
|
||||
self.counter.set("stage.lines", len(lines.lines))
|
||||
|
|
@ -820,7 +841,8 @@ class Stage(Node):
|
|||
1: lines,
|
||||
2: self.debug_lines,
|
||||
3: training_lines,
|
||||
4: all_active_tracks,
|
||||
4: active_positions,
|
||||
5: all_predictions,
|
||||
}
|
||||
|
||||
t4 = time.perf_counter()
|
||||
|
|
@ -836,7 +858,7 @@ class Stage(Node):
|
|||
|
||||
t6 = time.perf_counter()
|
||||
|
||||
t = (t2-t1, t3-t2b, t2b-t2, t4-t3, t5-t4, t6-t5)
|
||||
t = (t2-t1, t2b-t2, t2c-t2b, t3-t2c, t2b-t2, t4-t3, t5-t4, t6-t5)
|
||||
if sum(t) > .1:
|
||||
print(t)
|
||||
print(len(lines.lines))
|
||||
|
|
@ -883,6 +905,10 @@ class Stage(Node):
|
|||
help='specify a map (svg-file) from which to load lines which will be overlayed',
|
||||
type=str,
|
||||
default="../DATASETS/hof-lidar/map_hof.svg")
|
||||
argparser.add_argument('--cutoff-map',
|
||||
help='specify a map (svg-file) that specifies projection boundaries. In here, degrade chance to be selectede',
|
||||
type=str,
|
||||
default="../DATASETS/hof-lidar/map_hof.svg")
|
||||
argparser.add_argument('--max-active-scenarios',
|
||||
help='Maximum number of active scenarios that can be drawn at once (to not overlod the laser)',
|
||||
type=int,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from argparse import ArgumentParser
|
||||
from collections import deque
|
||||
import math
|
||||
import re
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import pyglet
|
||||
|
|
@ -11,7 +12,7 @@ import zmq
|
|||
from trap.lines import RenderableLayers, message_to_layers
|
||||
from trap.node import Node
|
||||
|
||||
BG_COLOR = (0,0,0)
|
||||
BG_COLOR = (0,0,255)
|
||||
class StageRenderer(Node):
|
||||
def setup(self):
|
||||
# self.prediction_sock = self.sub(self.config.zmq_prediction_addr)
|
||||
|
|
@ -49,6 +50,9 @@ class StageRenderer(Node):
|
|||
# self.window.set_size(1080, 1920)
|
||||
|
||||
window_size = self.window.get_size()
|
||||
|
||||
padding = 40
|
||||
|
||||
print(window_size)
|
||||
self.window.set_handler('on_draw', self.on_draw)
|
||||
# self.window.set_handler('on_close', self.on_close)
|
||||
|
|
@ -68,14 +72,14 @@ class StageRenderer(Node):
|
|||
self.text = pyglet.text.document.FormattedDocument("")
|
||||
self.text_batch = pyglet.graphics.Batch()
|
||||
self.text_layout = pyglet.text.layout.TextLayout(
|
||||
self.text, 20, 350,
|
||||
width=self.window.get_size()[1],
|
||||
height=self.window.get_size()[0] // 3,
|
||||
self.text, padding, (self.window.get_size()[0]-padding*2) // 2 - 100,
|
||||
width=self.window.get_size()[1] - 2*padding,
|
||||
height=(self.window.get_size()[0] - padding) // 2,
|
||||
multiline=True, wrap_lines=False, batch=self.text_batch)
|
||||
|
||||
max_len = 30
|
||||
max_len = 31
|
||||
self.log_msgs = deque([], maxlen=max_len)
|
||||
self.log_msgs.extend(["..."] * max_len)
|
||||
self.log_msgs.extend(["-"] * max_len)
|
||||
|
||||
|
||||
translate = (10,-400)
|
||||
|
|
@ -86,7 +90,6 @@ class StageRenderer(Node):
|
|||
max_y = 14.3
|
||||
scale = min(smallest_dimension / max_x, smallest_dimension/max_y)
|
||||
|
||||
padding = 40
|
||||
|
||||
self.logger.info(f"Use {scale=}")
|
||||
|
||||
|
|
@ -111,6 +114,8 @@ class StageRenderer(Node):
|
|||
self.clear_transparent = pyglet.shapes.Rectangle(0, window_size[1]-clear_area, window_size[0], clear_area, color=(*BG_COLOR,255//70))
|
||||
self.clear_fully= pyglet.shapes.Rectangle(0, 0, window_size[0], window_size[1]-clear_area, color=(*BG_COLOR,255))
|
||||
|
||||
self.window.clear()
|
||||
|
||||
|
||||
def check_running(self, dt):
|
||||
if not self.run_loop():
|
||||
|
|
@ -178,25 +183,23 @@ class StageRenderer(Node):
|
|||
color = (p2.color.as_array()*255).astype(int)
|
||||
|
||||
if i < len(self.lines):
|
||||
print('reuse')
|
||||
shape = self.lines[i]
|
||||
shape.x = pos1[0]
|
||||
shape.y = pos1[1]
|
||||
shape.x2 = pos2[0]
|
||||
shape.y2 = pos2[1]
|
||||
shape.color = color
|
||||
else:
|
||||
self.lines.append(pyglet.shapes.Line(pos1[0], pos1[1],
|
||||
pos2[0],
|
||||
pos2[1],
|
||||
3,
|
||||
color,
|
||||
batch=self.lines_batch))
|
||||
|
||||
self.lines.append(pyglet.shapes.Line(pos1[0], pos1[1],
|
||||
pos2[0],
|
||||
pos2[1],
|
||||
3,
|
||||
color,
|
||||
batch=self.lines_batch))
|
||||
|
||||
print(len(self.lines), i)
|
||||
too_many = len(self.lines) - 1 - i
|
||||
if too_many > 0:
|
||||
print('del', too_many)
|
||||
for j in reversed(range(i, i+too_many)):
|
||||
self.lines[i].delete()
|
||||
del self.lines[i]
|
||||
|
|
@ -212,18 +215,36 @@ class StageRenderer(Node):
|
|||
))
|
||||
|
||||
|
||||
colorsmap = {
|
||||
'ANOMALOUS': (255, 0, 0, 255),
|
||||
'LOITERING': (255, 255, 0, 255),
|
||||
'DETECTED': (255, 0, 255, 255),
|
||||
'SUBSTANTIAL': (255, 0, 255, 255),
|
||||
'LOST': (0, 0, 0, 255),
|
||||
}
|
||||
|
||||
matchtext = "".join(self.log_msgs) # find no newlines
|
||||
for state,color in colorsmap.items():
|
||||
for match in re.finditer(state, matchtext):
|
||||
self.text.set_style(match.start(), match.end(), dict(
|
||||
color=color
|
||||
))
|
||||
|
||||
|
||||
|
||||
|
||||
def on_draw(self):
|
||||
self.receive(.1)
|
||||
self.window.clear()
|
||||
# self.clear_transparent.color = (*BG_COLOR, int(255*self.get_setting('stagerenderer.fade', .27)))
|
||||
# self.clear_transparent.draw()
|
||||
# self.clear_fully.draw()
|
||||
# self.window.clear()
|
||||
self.clear_transparent.color = (*BG_COLOR, int(3))
|
||||
self.clear_transparent.draw()
|
||||
self.clear_fully.draw()
|
||||
self.fps_display.draw()
|
||||
|
||||
# self.bg_sprite.draw()
|
||||
self.bg_sprite.draw()
|
||||
|
||||
# self.lines_batch.draw()
|
||||
# self.text_batch.draw()
|
||||
self.lines_batch.draw()
|
||||
self.text_batch.draw()
|
||||
|
||||
|
||||
@classmethod
|
||||
|
|
@ -248,8 +269,8 @@ class StageRenderer(Node):
|
|||
type=str,
|
||||
default="SETTINGS/2025-11-dortmund/space/floorplan.png")
|
||||
render_parser.add_argument('--monitor',
|
||||
help='Specify a screen on which to output (eg. HDMI-1)',
|
||||
help='Specify a screen on which to output (eg. HDMI-0)',
|
||||
type=str,
|
||||
default="HDMI-1")
|
||||
default="HDMI-0")
|
||||
return render_parser
|
||||
|
||||
|
|
|
|||
|
|
@ -6,18 +6,20 @@ from pathlib import Path
|
|||
|
||||
import zmq
|
||||
|
||||
from trap.base import Track
|
||||
from trap.frame_emitter import Frame
|
||||
from trap.node import Node
|
||||
from trap.tracker import TrainingDataWriter
|
||||
from trap.tracker import TrainingDataWriter, TrainingTrackWriter
|
||||
|
||||
|
||||
class TrackWriter(Node):
|
||||
def setup(self):
|
||||
self.track_sock = self.sub(self.config.zmq_trajectory_addr)
|
||||
self.track_sock = self.sub(self.config.zmq_lost_addr)
|
||||
self.log_sock = self.push(self.config.zmq_log_addr)
|
||||
|
||||
|
||||
def run(self):
|
||||
with TrainingDataWriter(self.config.output_dir) as writer:
|
||||
with TrainingTrackWriter(self.config.output_dir) as writer:
|
||||
try:
|
||||
while self.run_loop():
|
||||
zmq_ev = self.track_sock.poll(timeout=1000)
|
||||
|
|
@ -26,10 +28,20 @@ class TrackWriter(Node):
|
|||
continue
|
||||
|
||||
try:
|
||||
frame: Frame = self.track_sock.recv_pyobj()
|
||||
writer.add(frame, frame.tracks.values())
|
||||
track: Track = self.track_sock.recv_pyobj()
|
||||
|
||||
self.logger.debug(f"write frame {frame.time:.3f} with {len(frame.tracks)} tracks")
|
||||
if len(track.history) < 20:
|
||||
self.logger.debug(f"ignore short track {len(track.history)}")
|
||||
continue
|
||||
|
||||
writer.add(track)
|
||||
|
||||
self.logger.info(f"Added track {track.track_id}")
|
||||
|
||||
try:
|
||||
self.log_sock.send_string(f"Added track {track.track_id} to dataset, {len(track.history)} datapoints", zmq.NOBLOCK)
|
||||
except Exception as e:
|
||||
self.logger.warning("Not sent the message, broken socket?")
|
||||
|
||||
except zmq.ZMQError as e:
|
||||
|
||||
|
|
@ -44,10 +56,14 @@ class TrackWriter(Node):
|
|||
@classmethod
|
||||
def arg_parser(cls):
|
||||
argparser = ArgumentParser()
|
||||
argparser.add_argument('--zmq-trajectory-addr',
|
||||
argparser.add_argument('--zmq-log-addr',
|
||||
help='Manually specity communication addr for the log messages',
|
||||
type=str,
|
||||
default="tcp://0.0.0.0:99188")
|
||||
argparser.add_argument('--zmq-lost-addr',
|
||||
help='Manually specity communication addr for the trajectory messages',
|
||||
type=str,
|
||||
default="ipc:///tmp/feeds_traj")
|
||||
default="ipc:///tmp/feeds_lost")
|
||||
argparser.add_argument("--output-dir",
|
||||
help="Directory to save the video in",
|
||||
required=True,
|
||||
|
|
|
|||
|
|
@ -127,13 +127,36 @@ class TrackReader:
|
|||
def __init__(self, path: Path, fps: int, include_blacklisted = False, exclude_whitelisted = False):
|
||||
self.blacklist_file = path / "blacklist.jsonl"
|
||||
self.whitelist_file = path / "whitelist.jsonl" # for skipping
|
||||
self.tracks_file = path / "tracks.pkl"
|
||||
# self.tracks_file = path / "tracks.pkl"
|
||||
self.tracks_files = path.glob('tracks*.pkl')
|
||||
|
||||
# with self.tracks_file.open('r') as fp:
|
||||
# tracks_dict: dict = json.load(fp)
|
||||
|
||||
with self.tracks_file.open('rb') as fp:
|
||||
tracks: dict = pickle.load(fp)
|
||||
tracks: Dict[str, Track] = {}
|
||||
for tracks_file in self.tracks_files:
|
||||
logger.info(f"Read {tracks_file}")
|
||||
with tracks_file.open('rb') as fp:
|
||||
while True:
|
||||
# multiple tracks can be pickled separately
|
||||
try:
|
||||
trackset: Dict[str, Track] = pickle.load(fp)
|
||||
for track_id, track in trackset.items():
|
||||
if len(tracks) < 1:
|
||||
max_item = 0
|
||||
else:
|
||||
max_item = max([int(t) for t in tracks.keys()])
|
||||
|
||||
if int(track.track_id) < max_item:
|
||||
track_id = str(max_item+1)
|
||||
else:
|
||||
track_id = track.track_id
|
||||
|
||||
track.track_id = track_id
|
||||
tracks[track.track_id] = track
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
|
||||
|
||||
if self.blacklist_file.exists():
|
||||
|
|
@ -257,6 +280,48 @@ class TrainingDataWriter:
|
|||
rewrite_raw_track_files(self.path)
|
||||
|
||||
|
||||
class TrainingTrackWriter:
|
||||
"""
|
||||
Supersedes TrainingDataWriter, by writing full tracks"""
|
||||
def __init__(self, training_path: Optional[Path]):
|
||||
if training_path is None:
|
||||
self.path = None
|
||||
return
|
||||
|
||||
if not isinstance(training_path, Path):
|
||||
raise ValueError("save-for-training should be a path")
|
||||
if not training_path.exists():
|
||||
logger.info(f"Making path for training data: {training_path}")
|
||||
training_path.mkdir(parents=True, exist_ok=False)
|
||||
else:
|
||||
logger.warning(f"Path for training-data exists: {training_path}. Continuing assuming that's ok.")
|
||||
|
||||
self.path = training_path
|
||||
|
||||
def __enter__(self):
|
||||
if self.path:
|
||||
d = datetime.now().isoformat(timespec="minutes")
|
||||
self.training_fp = open(self.path / f'tracks-{d}.pcl', 'wb')
|
||||
logger.debug(f"Writing tracker data to {self.training_fp.name}")
|
||||
# following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
|
||||
# self.csv = csv.DictWriter(self.training_fp, fieldnames=FIELDNAMES, delimiter='\t', quoting=csv.QUOTE_NONE)
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def add(self, track: Track):
|
||||
self.count += 1;
|
||||
pickle.dump(track, self.training_fp)
|
||||
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||
# ... ignore exception (type, value, traceback)
|
||||
if not self.path:
|
||||
return
|
||||
|
||||
self.training_fp.close()
|
||||
# rewrite_raw_track_files(self.path)
|
||||
|
||||
|
||||
|
||||
def rewrite_raw_track_files(path: Path):
|
||||
source_files = list(sorted(path.glob("*.txt"))) # we loop twice, so need a list instead of generator
|
||||
|
|
|
|||
Loading…
Reference in a new issue