From 06c5bde9978f807b5cfc1c16bf0e69bf15ed3012 Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Fri, 20 Oct 2023 13:27:51 +0200 Subject: [PATCH] Option to save tracker output as training data --- trap/tracker.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/trap/tracker.py b/trap/tracker.py index 658e245..7b0a869 100644 --- a/trap/tracker.py +++ b/trap/tracker.py @@ -1,7 +1,9 @@ from argparse import Namespace +import csv import json import logging from multiprocessing import Event +from pathlib import Path import pickle import time import numpy as np @@ -18,6 +20,12 @@ from trap.frame_emitter import Frame Detection = [int, int, int, int, float, int] Detections = [Detection] +# This is the dt that is also used by the scene. +# as this needs to be rather stable, try to adhere +# to it by waiting when we are faster. Value chosen based +# on a rough estimate of tracker duration +TARGET_DT = .1 + logger = logging.getLogger("trap.tracker") class Tracker: @@ -56,11 +64,36 @@ class Tracker: def track(self): + prev_run_time = 0 + + training_fp = None + training_csv = None + training_frames = 0 + + if self.config.save_for_training is not None: + if not isinstance(self.config.save_for_training, Path): + raise ValueError("save-for-training should be a path") + if not self.config.save_for_training.exists(): + logger.info(f"Making path for training data: {self.config.save_for_training}") + self.config.save_for_training.mkdir(parents=True, exist_ok=False) + else: + logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.") + training_fp = open(self.config.save_for_training / 'all.txt', 'w') + # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py + training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE) + + frame_i = 0 while self.is_running.is_set(): + this_run_time = time.time() + # logger.debug(f'test {prev_run_time - this_run_time}') + time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT)) + prev_run_time = time.time() + msg = self.frame_sock.recv() frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s # logger.info(f"Frame delivery delay = {time.time()-frame.time}s") start_time = time.time() + detections = self.detect_persons(frame.img) tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img) @@ -91,6 +124,34 @@ class Tracker: #TODO calculate fps (also for other loops to see asynchonity) # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display + if training_csv: + training_csv.writerows([{ + 'frame_id': round(frame_i * 10., 1), # not really time + 'track_id': t['id'], + 'x': t['history'][-1]['x'], + 'y': t['history'][-1]['y'], + } for t in trajectories.values()]) + training_frames += len(trajectories) + frame_i += 1 + + if training_fp: + training_fp.close() + lines = { + 'train': int(training_frames * .8), + 'val': int(training_frames * .12), + 'test': int(training_frames * .08), + } + logger.info(f"Splitting gathered data from {training_fp.name}") + with open(training_fp.name, 'r') as source_fp: + for name, line_nrs in lines.items(): + dir_path = self.config.save_for_training / name + dir_path.mkdir(exist_ok=True) + file = dir_path / 'tracked.txt' + logger.debug(f"- Write {line_nrs} lines to {file}") + with file.open('w') as target_fp: + for i in range(line_nrs): + target_fp.write(source_fp.readline()) + logger.info('Stopping')