diff --git a/trap/config.py b/trap/config.py index 8471c04..a92e8bc 100644 --- a/trap/config.py +++ b/trap/config.py @@ -1,9 +1,25 @@ import argparse from pathlib import Path +import types from pyparsing import Optional -parser = argparse.ArgumentParser() +class LambdaParser(argparse.ArgumentParser): + """Execute lambda functions + """ + def parse_args(self, args=None, namespace=None): + args = super().parse_args(args, namespace) + + for key in vars(args): + f = args.__dict__[key] + if type(f) == types.LambdaType: + print(f'Getting default value for {key}') + args.__dict__[key] = f() + + return args + +parser = LambdaParser() +# parser.parse_args() parser.add_argument( @@ -34,7 +50,8 @@ render_parser = parser.add_argument_group('Renderer') inference_parser.add_argument("--model_dir", help="directory with the model to use for inference", type=str, # TODO: make into Path - default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3') + default='../Trajectron-plus-plus/experiments/trap/models/models_18_Oct_2023_19_56_22_virat_vel_ar3/') + # default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3') inference_parser.add_argument("--conf", help="path to json config file for hyperparameters, relative to model_dir", @@ -129,6 +146,10 @@ inference_parser.add_argument('--seed', type=int, default=123) +inference_parser.add_argument('--predict_training_data', + help='Ignore tracker and predict data from the training dataset', + action='store_true') + # Internal connections. @@ -167,7 +188,13 @@ connection_parser.add_argument('--bypass-prediction', frame_emitter_parser.add_argument("--video-src", help="source video to track from", type=Path, - default='../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4') + nargs='+', + default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4'))) +#TODO: camera as source + +frame_emitter_parser.add_argument("--video-no-loop", + help="By default it emitter will run indefiniately. This prevents that and plays every video only once.", + action='store_true') #TODO: camera as source @@ -177,11 +204,14 @@ tracker_parser.add_argument("--homography", help="File with homography params", type=Path, default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt') +tracker_parser.add_argument("--save-for-training", + help="Specify the path in which to save", + type=Path, + default=None) # Renderer -# render_parser.add_argument("--output-dir", -# help="Target image dir", -# type=Optional[Path], -# default=None) +render_parser.add_argument("--render-preview", + help="Render a video file previewing the prediction, and its delay compared to the current frame", + action='store_true') diff --git a/trap/frame_emitter.py b/trap/frame_emitter.py index dcd477c..832e302 100644 --- a/trap/frame_emitter.py +++ b/trap/frame_emitter.py @@ -1,5 +1,6 @@ from argparse import Namespace from dataclasses import dataclass, field +from itertools import cycle import logging from multiprocessing import Event import pickle @@ -28,39 +29,57 @@ class FrameEmitter: self.is_running = is_running context = zmq.Context() + # TODO: to make things faster, a multiprocessing.Array might be a tad faster: https://stackoverflow.com/a/65201859 self.frame_sock = context.socket(zmq.PUB) self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind self.frame_sock.bind(config.zmq_frame_addr) + logger.info(f"Connection socket {config.zmq_frame_addr}") + if self.config.video_no_loop: + self.video_srcs = self.config.video_src + else: + self.video_srcs = cycle(self.config.video_src) + + def emit_video(self): - video = cv2.VideoCapture(str(self.config.video_src)) - fps = video.get(cv2.CAP_PROP_FPS) - frame_duration = 1./fps - logger.info(f"Emit frames at {fps} fps") + for video_path in self.video_srcs: + logger.info(f"Play from '{str(video_path)}'") + video = cv2.VideoCapture(str(video_path)) + fps = video.get(cv2.CAP_PROP_FPS) + frame_duration = 1./fps + logger.info(f"Emit frames at {fps} fps") - prev_time = time.time() - while self.is_running.is_set(): - ret, img = video.read() - - # seek to 0 if video has finished. Infinite loop - if not ret: - video.set(cv2.CAP_PROP_POS_FRAMES, 0) + prev_time = time.time() + while self.is_running.is_set(): ret, img = video.read() - assert ret is not False # not really error proof... - frame = Frame(img=img) - # TODO: this is very dirty, need to find another way. - # perhaps multiprocessing queue? - self.frame_sock.send(pickle.dumps(frame)) - # defer next loop - new_frame_time = time.time() - time_diff = (new_frame_time - prev_time) - if time_diff < frame_duration: - time.sleep(frame_duration - time_diff) - new_frame_time += frame_duration - time_diff - else: - prev_time = new_frame_time + # seek to 0 if video has finished. Infinite loop + if not ret: + # now loading multiple files + break + # video.set(cv2.CAP_PROP_POS_FRAMES, 0) + # ret, img = video.read() + # assert ret is not False # not really error proof... + frame = Frame(img=img) + # TODO: this is very dirty, need to find another way. + # perhaps multiprocessing Array? + self.frame_sock.send(pickle.dumps(frame)) + + # defer next loop + new_frame_time = time.time() + time_diff = (new_frame_time - prev_time) + if time_diff < frame_duration: + time.sleep(frame_duration - time_diff) + new_frame_time += frame_duration - time_diff + else: + prev_time = new_frame_time + + if not self.is_running.is_set(): + # if not running, also break out of infinite generator loop + break + + logger.info("Stopping")