Frame emitter loops over files in folder

This commit is contained in:
Ruben van de Ven 2023-10-20 13:24:43 +02:00
parent 3a64c438eb
commit c903d07b49
2 changed files with 80 additions and 31 deletions

View file

@ -1,9 +1,25 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
import types
from pyparsing import Optional from pyparsing import Optional
parser = argparse.ArgumentParser() class LambdaParser(argparse.ArgumentParser):
"""Execute lambda functions
"""
def parse_args(self, args=None, namespace=None):
args = super().parse_args(args, namespace)
for key in vars(args):
f = args.__dict__[key]
if type(f) == types.LambdaType:
print(f'Getting default value for {key}')
args.__dict__[key] = f()
return args
parser = LambdaParser()
# parser.parse_args()
parser.add_argument( parser.add_argument(
@ -34,7 +50,8 @@ render_parser = parser.add_argument_group('Renderer')
inference_parser.add_argument("--model_dir", inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference", help="directory with the model to use for inference",
type=str, # TODO: make into Path type=str, # TODO: make into Path
default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3') default='../Trajectron-plus-plus/experiments/trap/models/models_18_Oct_2023_19_56_22_virat_vel_ar3/')
# default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
inference_parser.add_argument("--conf", inference_parser.add_argument("--conf",
help="path to json config file for hyperparameters, relative to model_dir", help="path to json config file for hyperparameters, relative to model_dir",
@ -129,6 +146,10 @@ inference_parser.add_argument('--seed',
type=int, type=int,
default=123) default=123)
inference_parser.add_argument('--predict_training_data',
help='Ignore tracker and predict data from the training dataset',
action='store_true')
# Internal connections. # Internal connections.
@ -167,7 +188,13 @@ connection_parser.add_argument('--bypass-prediction',
frame_emitter_parser.add_argument("--video-src", frame_emitter_parser.add_argument("--video-src",
help="source video to track from", help="source video to track from",
type=Path, type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4') nargs='+',
default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')))
#TODO: camera as source
frame_emitter_parser.add_argument("--video-no-loop",
help="By default it emitter will run indefiniately. This prevents that and plays every video only once.",
action='store_true')
#TODO: camera as source #TODO: camera as source
@ -177,11 +204,14 @@ tracker_parser.add_argument("--homography",
help="File with homography params", help="File with homography params",
type=Path, type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt') default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
tracker_parser.add_argument("--save-for-training",
help="Specify the path in which to save",
type=Path,
default=None)
# Renderer # Renderer
# render_parser.add_argument("--output-dir", render_parser.add_argument("--render-preview",
# help="Target image dir", help="Render a video file previewing the prediction, and its delay compared to the current frame",
# type=Optional[Path], action='store_true')
# default=None)

View file

@ -1,5 +1,6 @@
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import cycle
import logging import logging
from multiprocessing import Event from multiprocessing import Event
import pickle import pickle
@ -28,13 +29,23 @@ class FrameEmitter:
self.is_running = is_running self.is_running = is_running
context = zmq.Context() context = zmq.Context()
# TODO: to make things faster, a multiprocessing.Array might be a tad faster: https://stackoverflow.com/a/65201859
self.frame_sock = context.socket(zmq.PUB) self.frame_sock = context.socket(zmq.PUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
self.frame_sock.bind(config.zmq_frame_addr) self.frame_sock.bind(config.zmq_frame_addr)
logger.info(f"Connection socket {config.zmq_frame_addr}") logger.info(f"Connection socket {config.zmq_frame_addr}")
if self.config.video_no_loop:
self.video_srcs = self.config.video_src
else:
self.video_srcs = cycle(self.config.video_src)
def emit_video(self): def emit_video(self):
video = cv2.VideoCapture(str(self.config.video_src)) for video_path in self.video_srcs:
logger.info(f"Play from '{str(video_path)}'")
video = cv2.VideoCapture(str(video_path))
fps = video.get(cv2.CAP_PROP_FPS) fps = video.get(cv2.CAP_PROP_FPS)
frame_duration = 1./fps frame_duration = 1./fps
logger.info(f"Emit frames at {fps} fps") logger.info(f"Emit frames at {fps} fps")
@ -45,12 +56,14 @@ class FrameEmitter:
# seek to 0 if video has finished. Infinite loop # seek to 0 if video has finished. Infinite loop
if not ret: if not ret:
video.set(cv2.CAP_PROP_POS_FRAMES, 0) # now loading multiple files
ret, img = video.read() break
assert ret is not False # not really error proof... # video.set(cv2.CAP_PROP_POS_FRAMES, 0)
# ret, img = video.read()
# assert ret is not False # not really error proof...
frame = Frame(img=img) frame = Frame(img=img)
# TODO: this is very dirty, need to find another way. # TODO: this is very dirty, need to find another way.
# perhaps multiprocessing queue? # perhaps multiprocessing Array?
self.frame_sock.send(pickle.dumps(frame)) self.frame_sock.send(pickle.dumps(frame))
# defer next loop # defer next loop
@ -62,6 +75,12 @@ class FrameEmitter:
else: else:
prev_time = new_frame_time prev_time = new_frame_time
if not self.is_running.is_set():
# if not running, also break out of infinite generator loop
break
logger.info("Stopping")
def run_frame_emitter(config: Namespace, is_running: Event): def run_frame_emitter(config: Namespace, is_running: Event):