Option to save tracker output as training data

This commit is contained in:
Ruben van de Ven 2023-10-20 13:27:51 +02:00
parent 2171dd459a
commit 06c5bde997

View file

@ -1,7 +1,9 @@
from argparse import Namespace
import csv
import json
import logging
from multiprocessing import Event
from pathlib import Path
import pickle
import time
import numpy as np
@ -18,6 +20,12 @@ from trap.frame_emitter import Frame
Detection = [int, int, int, int, float, int]
Detections = [Detection]
# This is the dt that is also used by the scene.
# as this needs to be rather stable, try to adhere
# to it by waiting when we are faster. Value chosen based
# on a rough estimate of tracker duration
TARGET_DT = .1
logger = logging.getLogger("trap.tracker")
class Tracker:
@ -56,11 +64,36 @@ class Tracker:
def track(self):
prev_run_time = 0
training_fp = None
training_csv = None
training_frames = 0
if self.config.save_for_training is not None:
if not isinstance(self.config.save_for_training, Path):
raise ValueError("save-for-training should be a path")
if not self.config.save_for_training.exists():
logger.info(f"Making path for training data: {self.config.save_for_training}")
self.config.save_for_training.mkdir(parents=True, exist_ok=False)
else:
logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
training_fp = open(self.config.save_for_training / 'all.txt', 'w')
# following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
frame_i = 0
while self.is_running.is_set():
this_run_time = time.time()
# logger.debug(f'test {prev_run_time - this_run_time}')
time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
prev_run_time = time.time()
msg = self.frame_sock.recv()
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
start_time = time.time()
detections = self.detect_persons(frame.img)
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
@ -91,6 +124,34 @@ class Tracker:
#TODO calculate fps (also for other loops to see asynchonity)
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
if training_csv:
training_csv.writerows([{
'frame_id': round(frame_i * 10., 1), # not really time
'track_id': t['id'],
'x': t['history'][-1]['x'],
'y': t['history'][-1]['y'],
} for t in trajectories.values()])
training_frames += len(trajectories)
frame_i += 1
if training_fp:
training_fp.close()
lines = {
'train': int(training_frames * .8),
'val': int(training_frames * .12),
'test': int(training_frames * .08),
}
logger.info(f"Splitting gathered data from {training_fp.name}")
with open(training_fp.name, 'r') as source_fp:
for name, line_nrs in lines.items():
dir_path = self.config.save_for_training / name
dir_path.mkdir(exist_ok=True)
file = dir_path / 'tracked.txt'
logger.debug(f"- Write {line_nrs} lines to {file}")
with file.open('w') as target_fp:
for i in range(line_nrs):
target_fp.write(source_fp.readline())
logger.info('Stopping')