Tools for blacklisting tracks

This commit is contained in:
Ruben van de Ven 2024-11-28 16:08:55 +01:00
parent 30648b9bb8
commit a590a0dc35
10 changed files with 388 additions and 96 deletions

1
.gitignore vendored
View file

@ -1,6 +1,7 @@
.idea/
OUT/
EXPERIMENTS/
runs/
## Core latex/pdflatex auxiliary files:
*.aux

16
poetry.lock generated
View file

@ -1015,6 +1015,20 @@ files = [
[package.extras]
dev = ["hypothesis"]
[[package]]
name = "jsonlines"
version = "4.0.0"
description = "Library with helpers for the jsonlines file format"
optional = false
python-versions = ">=3.8"
files = [
{file = "jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55"},
{file = "jsonlines-4.0.0.tar.gz", hash = "sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74"},
]
[package.dependencies]
attrs = ">=19.2.0"
[[package]]
name = "jsonpointer"
version = "2.4"
@ -3725,4 +3739,4 @@ watchdog = ["watchdog (>=2.3)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10,<3.12,"
content-hash = "bf4feafd4afa6ceb39a1c599e3e7cdc84afbe11ab1672b49e5de99ad44568b08"
content-hash = "71868029d1943c412082bcb705dd76711f105afa8a93326a45e129230de8ffa9"

View file

@ -10,6 +10,8 @@ trapserv = "trap.plumber:start"
tracker = "trap.tools:tracker_preprocess"
compare = "trap.tools:tracker_compare"
process_data = "trap.process_data:main"
blacklist = "trap.tools:blacklist_tracks"
rewrite_tracks = "trap.tools:rewrite_raw_track_files"
[tool.poetry.dependencies]
@ -39,6 +41,7 @@ pyglet-cornerpin = "^0.3.0"
opencv-python = {file="./opencv_python-4.10.0.84-cp310-cp310-linux_x86_64.whl"}
setproctitle = "^1.3.3"
bytetracker = { git = "https://github.com/rubenvandeven/bytetrack-pip" }
jsonlines = "^4.0.0"
[build-system]
requires = ["poetry-core"]

View file

@ -86,7 +86,7 @@ class CameraAction(argparse.Action):
# 'camera_matrix': np.array(data['camera_matrix']),
# 'dist_coeff': np.array(data['dist_coeff']),
# }
camera = Camera(np.array(data['camera_matrix']), np.array(data['dist_coeff']), data['dim']['width'], data['dim']['height'], namespace.H)
camera = Camera(np.array(data['camera_matrix']), np.array(data['dist_coeff']), data['dim']['width'], data['dim']['height'], namespace.H, namespace.camera_fps)
setattr(namespace, 'camera', camera)
@ -276,6 +276,10 @@ frame_emitter_parser.add_argument("--video-loop",
# Tracker
tracker_parser.add_argument("--camera-fps",
help="Camera FPS",
type=int,
default=12)
tracker_parser.add_argument("--homography",
help="File with homography params",
type=Path,
@ -314,6 +318,9 @@ tracker_parser.add_argument("--smooth-tracks",
# Renderer
# render_parser.add_argument("--disable-renderer",
# help="Disable the renderer all together. Usefull when using an external renderer",
# action="store_true")
render_parser.add_argument("--render-file",
help="Render a video file previewing the prediction, and its delay compared to the current frame",

View file

@ -1,7 +1,11 @@
from __future__ import annotations
from argparse import Namespace
from dataclasses import dataclass, field
import dataclasses
from enum import IntFlag
from itertools import cycle
import json
import logging
from multiprocessing import Event
from pathlib import Path
@ -11,21 +15,36 @@ import time
from typing import Iterable, List, Optional
import numpy as np
import cv2
import pandas as pd
import zmq
import os
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
from bytetracker.byte_tracker import STrack as ByteTrackTrack
from bytetracker.basetrack import TrackState as ByteTrackTrackState
from trajectron.environment import Environment, Node, Scene
from urllib.parse import urlparse
from trap.utils import lerp
logger = logging.getLogger('trap.frame_emitter')
class DataclassJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, np.ndarray):
return o.tolist()
if dataclasses.is_dataclass(o):
d = dataclasses.asdict(o)
if isinstance(o, Frame):
# Don't send images over JSON
del d['img']
return d
return super().default(o)
class UrlOrPath():
def __init__(self, str):
self.url = urlparse(str)
def __init__(self, string):
self.url = urlparse(str(string))
def __str__(self) -> str:
return self.url.geturl()
@ -64,14 +83,29 @@ class DetectionState(IntFlag):
return cls.Confirmed
raise RuntimeError("Should not run into Deleted entries here")
@dataclass
class Camera:
def __init__(self, mtx, dist, w, h, H):
self.mtx = mtx
self.dist = dist
self.w = w
self.h = h
self.newcameramtx, self.roi = cv2.getOptimalNewCameraMatrix(mtx, dist, (w,h), 1, (w,h))
self.H = H # homography
mtx: cv2.Mat
dist: cv2.Mat
w: float
h: float
H: cv2.Mat # homography
newcameramtx: cv2.Mat = field(init=False)
roi: cv2.typing.Rect = field(init=False)
fps: float
def __post_init__(self):
self.newcameramtx, self.roi = cv2.getOptimalNewCameraMatrix(self.mtx, self.dist, (self.w,self.h), 1, (self.w,self.h))
# def __init__(self, mtx, dist, w, h, H):
# self.mtx = mtx
# self.dist = dist
# self.w = w
# self.h = h
# self.newcameramtx, self.roi = cv2.getOptimalNewCameraMatrix(mtx, dist, (w,h), 1, (w,h))
# self.H = H # homography
@dataclass
@ -131,6 +165,7 @@ class Track:
history: List[Detection] = field(default_factory=lambda: [])
predictor_history: Optional[list] = None # in image space
predictions: Optional[list] = None
fps: int = 12
def get_projected_history(self, H, camera: Optional[Camera]= None) -> np.array:
foot_coordinates = [d.get_foot_coords() for d in self.history]
@ -138,7 +173,7 @@ class Track:
if len(foot_coordinates):
if camera:
coords = cv2.undistortPoints(np.array([foot_coordinates]).astype('float32'), camera.mtx, camera.dist, None, camera.newcameramtx)
coords = cv2.perspectiveTransform(np.array(coords),H)
coords = cv2.perspectiveTransform(np.array(coords),camera.H)
return coords.reshape((coords.shape[0],2))
else:
coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
@ -149,7 +184,63 @@ class Track:
coords = self.get_projected_history(H, camera)
return [{"x":c[0], "y":c[1]} for c in coords]
def get_with_interpolated_history(self) -> Track:
# new_history = [Detection(d.track_id, l, t, w, h, d.conf, d.state, d.frame_nr, d.det_class) for l, t, w, h, d in zip(ls,ts,ws,hs, track.history)]
# new_track = Track(track.track_id, new_history, track.predictor_history, track.predictions)
new_history = []
for j in range(len(self.history)-1):
a = self.history[j]
b = self.history[j+1]
gap = b.frame_nr - a.frame_nr
new_history.append(Detection(a.track_id, a.l, a.t, a.w, a.h, a.conf, a.state, a.frame_nr, a.det_class))
if gap < 1:
logger.error(f"WARNING, gap between frames {a.frame_nr} -> {b.frame_nr} is negative?")
if gap > 1:
for g in range(1, gap):
l = lerp(a.l, b.l, g/gap)
t = lerp(a.t, b.t, g/gap)
w = lerp(a.w, b.w, g/gap)
h = lerp(a.h, b.h, g/gap)
conf = 0
state = DetectionState.Lost
frame_nr = a.frame_nr + g
new_history.append(Detection(a.track_id, l, t, w, h, conf, state, frame_nr, a.det_class))
return Track(
self.track_id,
new_history,
self.predictor_history,
self.predictions,
self.fps)
def to_trajectron_node(self, camera: Camera, env: Environment) -> Node:
positions = self.get_projected_history(None, camera)
velocity = np.gradient(positions, self.fps, axis=0)
acceleration = np.gradient(velocity, self.fps, axis=0)
new_first_idx = self.history[0].frame_nr
data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
# vx = derivative_of(x, scene.dt)
# vy = derivative_of(y, scene.dt)
# ax = derivative_of(vx, scene.dt)
# ay = derivative_of(vy, scene.dt)
data_dict = {
('position', 'x'): positions[:,0],
('position', 'y'): positions[:,1],
('velocity', 'x'): velocity[:,0],
('velocity', 'y'): velocity[:,1],
('acceleration', 'x'): acceleration[:,0],
('acceleration', 'y'): acceleration[:,1]
}
node_data = pd.DataFrame(data_dict, columns=data_columns)
return Node(node_type=env.NodeType.PEDESTRIAN, node_id=self.track_id, data=node_data, first_timestep=new_first_idx)

View file

@ -86,7 +86,7 @@ def start():
# instantiating process with arguments
procs = [
ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
# ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
]

View file

@ -26,7 +26,7 @@ import matplotlib.pyplot as plt
import zmq
from trap.frame_emitter import Frame
from trap.frame_emitter import DataclassJSONEncoder, Frame
from trap.tracker import Track, Smoother
logger = logging.getLogger("trap.prediction")
@ -160,8 +160,16 @@ class PredictionServer:
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
self.external_predictions = not self.config.zmq_prediction_addr.startswith("ipc://")
# print(self.prediction_socket)
def send_frame(self, frame: Frame):
if self.external_predictions:
# data = json.dumps(frame, cls=DataclassJSONEncoder)
self.prediction_socket.send_json(frame, cls=DataclassJSONEncoder)
else:
self.prediction_socket.send_pyobj(frame)
def run(self):
if self.config.seed is not None:
@ -275,12 +283,14 @@ class PredictionServer:
if self.config.predict_training_data:
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
else:
# print('await', self.config.zmq_trajectory_addr)
zmq_ev = self.trajectory_socket.poll(timeout=2000)
if not zmq_ev:
# on no data loop so that is_running is checked
continue
data = self.trajectory_socket.recv()
# print('recv tracker frame')
frame: Frame = pickle.loads(data)
# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
# trajectory_data = json.loads(data)
@ -330,7 +340,7 @@ class PredictionServer:
first_timestep=timestep
)
input_dict[node] = np.array([x[-1],y[-1],vx[-1],vy[-1],ax[-1],ay[-1]])
input_dict[node] = np.array(object=[x[-1],y[-1],vx[-1],vy[-1],ax[-1],ay[-1]])
# print(input_dict)
@ -340,7 +350,7 @@ class PredictionServer:
# And want to update the network
# data = json.dumps({})
self.prediction_socket.send_pyobj(frame)
self.send_frame(frame)
continue
@ -444,7 +454,7 @@ class PredictionServer:
if self.config.smooth_predictions:
frame = self.smoother.smooth_frame_predictions(frame)
self.prediction_socket.send_pyobj(frame)
self.send_frame(frame)
time.sleep(.5)
logger.info('Stopping')

View file

@ -1,15 +1,17 @@
from argparse import Namespace
import json
import math
from pathlib import Path
import pickle
from tempfile import mktemp
import jsonlines
import numpy as np
import pandas as pd
import trap.tracker
from trap.config import parser
from trap.frame_emitter import Detection, DetectionState, video_src_from_config, Frame
from trap.tracker import DETECTOR_YOLOv8, Smoother, _yolov8_track, Track, TrainingDataWriter, Tracker
from trap.tracker import DETECTOR_YOLOv8, Smoother, _yolov8_track, Track, TrainingDataWriter, Tracker, read_tracks_json
from collections import defaultdict
import logging
@ -20,6 +22,8 @@ from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult
import tqdm
from trap.utils import lerp
logger = logging.getLogger('tools')
@ -98,6 +102,7 @@ def tracker_preprocess():
total = 0
frames = FrameGenerator(config)
total_tracks = set()
for frame in frames:
bar.update()
@ -105,7 +110,6 @@ def tracker_preprocess():
total += len(detections)
# detections = _yolov8_track(frame, model, imgsz=1440, classes=[0])
bar.set_description(f"{frames.video_nr}/{len(frames.video_srcs)} [{frames.frame_idx}/{frames.frame_count}] {marquee_string(str(frames.video_path), 10, frames.n//2)} | dets {len(detections)}: {[d.track_id for d in detections]} (∑{total})")
for detection in detections:
track = tracks[detection.track_id]
@ -114,6 +118,9 @@ def tracker_preprocess():
active_track_ids = [d.track_id for d in detections]
active_tracks = {t.track_id: t for t in tracks.values() if t.track_id in active_track_ids}
total_tracks.update(active_track_ids)
bar.set_description(f"{frames.video_nr}/{len(frames.video_srcs)} [{frames.frame_idx}/{frames.frame_count}] {marquee_string(str(frames.video_path), 10, frames.n//2)} | dets {len(detections)}: {[d.track_id for d in detections]} (∑{total}{len(total_tracks)})")
writer.add(frame, active_tracks.values())
@ -122,12 +129,13 @@ def tracker_preprocess():
bgr_colors = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
# (0, 0, 255),# red used for missing waypoints
(0, 255, 255),
]
def detection_color(detection: Detection, i):
return bgr_colors[i % len(bgr_colors)] if detection.state != DetectionState.Lost else (100,100,100)
def detection_color(detection: Detection, i, prev_detection: Optional[Detection] = None):
vague = detection.state == DetectionState.Lost or (prev_detection and detection.frame_nr - prev_detection.frame_nr > 1)
return bgr_colors[i % len(bgr_colors)] if not vague else (0,0,255)
def to_point(coord):
return (int(coord[0]), int(coord[1]))
@ -164,13 +172,7 @@ def tracker_compare():
for i, (tracker, detections) in enumerate(trackers_detections):
for track_id in tracker.tracks:
history = tracker.tracks[track_id].history
cv2.putText(frame.img, f"{track_id}", to_point(history[0].get_foot_coords()), cv2.FONT_HERSHEY_DUPLEX, 1, color=bgr_colors[i % len(bgr_colors)])
for j in range(len(history)-1):
a = history[j]
b = history[j+1]
color = detection_color(b, i)
cv2.line(frame.img, to_point(a.get_foot_coords()), to_point(b.get_foot_coords()), color, 1)
draw_track(frame.img, tracker.tracks[track_id], i)
for detection in detections:
color = color = detection_color(detection, i)
l, t, r, b = detection.to_ltrb()
@ -184,6 +186,97 @@ def tracker_compare():
bar.set_description(f"[{frames.video_nr}/{len(frames.video_srcs)}] [{frames.frame_idx}/{frames.frame_count}] {str(frames.video_path)}")
def draw_track(img: cv2.Mat, track: Track, color_index: int):
history = track.history
cv2.putText(img, f"{track.track_id} ({len(history)})", to_point(history[0].get_foot_coords()), cv2.FONT_HERSHEY_DUPLEX, 1, color=bgr_colors[color_index % len(bgr_colors)])
point_color = detection_color(history[0], color_index)
cv2.circle(img, to_point(history[0].get_foot_coords()), 3, point_color, 2)
for j in range(len(history)-1):
a = history[j]
b = history[j+1]
# TODO)) replace with Track.get_with_interpolated_history()
gap = b.frame_nr - a.frame_nr - 1
if gap < 0:
print(f"WARNING, gap between frames {a.frame_nr} -> {b.frame_nr} is negative?")
if gap > 0:
for g in range(gap):
p1 = a.get_foot_coords()
p2 = b.get_foot_coords()
point = (lerp(p1[0], p2[0], g/gap), lerp(p1[1], p2[1], g/gap))
cv2.circle(img, to_point(point), 3, (0,0,255), 1)
color = detection_color(b, color_index, a)
cv2.line(img, to_point(a.get_foot_coords()), to_point(b.get_foot_coords()), color, 1)
point_color = detection_color(b, color_index)
cv2.circle(img, to_point(b.get_foot_coords()), 3, point_color, 2)
def blacklist_tracks():
config = parser.parse_args()
cv2.namedWindow("frame", cv2.WND_PROP_FULLSCREEN)
cv2.setWindowProperty("frame",cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
backdrop = cv2.imread('../DATASETS/hof3/output.png')
blacklist = []
path: Path = config.save_for_training
blacklist_file = path / "blacklist.jsonl"
whitelist_file = path / "whitelist.jsonl" # for skipping
tracks_file = path / "tracks.json"
FPS = 12 # TODO)) From config
if whitelist_file.exists():
# with whitelist_file.open('r') as fp:
with jsonlines.open(whitelist_file, 'r') as reader:
whitelist = [l for l in reader.iter(type=str)]
else:
whitelist = []
try:
for track in read_tracks_json(tracks_file, blacklist_file, FPS):
if track.track_id in whitelist:
logger.info(f'skip whitelisted {track.track_id}')
continue
img = backdrop.copy()
draw_track(img, track, 0)
imgS = cv2.resize(img, (1920, 1080))
cv2.imshow('frame', imgS)
while True:
k = cv2.waitKey(0)
if k==27: # Esc key to stop
raise StopIteration
elif k == ord('s'):
break # skip for now
elif k == ord('y'):
with jsonlines.open(whitelist_file, mode='a') as writer:
# skip next time around
writer.write(track.track_id)
break
elif k == ord('n'):
print('blacklist', track.track_id)
logger.info(f"Append {len(blacklist)} items to {str(blacklist_file)}")
with jsonlines.open(blacklist_file, mode='a') as writer:
writer.write(track.track_id)
break
else:
# ignore all other keypresses
print(k) # else print its value
continue
except StopIteration as e:
pass
def rewrite_raw_track_files():
logging.basicConfig(level=logging.DEBUG)
config = parser.parse_args()
trap.tracker.rewrite_raw_track_files(config.save_for_training)
def interpolate_missing_frames(data: pd.DataFrame):
missing=0

View file

@ -4,11 +4,13 @@ import csv
from dataclasses import dataclass, field
import json
import logging
from math import nan
from multiprocessing import Event
from pathlib import Path
import pickle
import time
from typing import Optional, List
import jsonlines
import numpy as np
import torch
import torchvision
@ -28,7 +30,7 @@ from bytetracker import BYTETracker
from tsmoothie.smoother import KalmanSmoother, ConvolutionSmoother
import tsmoothie.smoother
from datetime import datetime
from datetime import datetime, timedelta
# Detection = [int, int, int, int, float, int]
# Detections = [Detection]
@ -89,6 +91,43 @@ class Multifile():
def readline(self):
return self.g.__next__()
FIELDNAMES = ['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state']
def read_tracks_json(path: Path, blacklist_path: Path, fps):
"""
Reader for tracks.json produced by TrainingDataWriter
"""
with path.open('r') as fp:
tracks_dict: dict = json.load(fp)
if blacklist_path.exists():
with jsonlines.open(blacklist_path, 'r') as reader:
blacklist = [track_id for track_id in reader.iter(type=str)]
else:
blacklist = []
for track_id, detection_values in tracks_dict.items():
if track_id in blacklist:
continue
history = []
# for detection_values in
for detection_items in detection_values:
d = dict(zip(FIELDNAMES, detection_items))
history.append(Detection(
d['track_id'],
d['l'],
d['t'],
d['w'],
d['h'],
nan,
d['state'],
d['frame_id'],
1
))
yield Track(track_id, history, fps=fps)
class TrainingDataWriter:
def __init__(self, training_path: Optional[Path]):
@ -114,7 +153,7 @@ class TrainingDataWriter:
self.training_fp = open(self.path / f'all-{d}.txt', 'w')
logger.debug(f"Writing tracker data to {self.training_fp.name}")
# following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
self.csv = csv.DictWriter(self.training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)
self.csv = csv.DictWriter(self.training_fp, fieldnames=FIELDNAMES, delimiter='\t', quoting=csv.QUOTE_NONE)
self.count = 0
return self
@ -146,8 +185,12 @@ class TrainingDataWriter:
return
self.training_fp.close()
rewrite_raw_track_files(self.path)
source_files = list(self.path.glob("*.txt")) # we loop twice, so need a list instead of generator
def rewrite_raw_track_files(path: Path):
source_files = list(sorted(path.glob("*.txt"))) # we loop twice, so need a list instead of generator
total = 0
sources = Multifile(source_files)
for line in sources:
@ -155,36 +198,49 @@ class TrainingDataWriter:
total += 1
lines = {
destinations = {
'train': int(total * .8),
'val': int(total * .12),
'test': int(total * .08),
}
logger.info(f"Splitting gathered data from {sources.name}")
logger.info(f"Splitting gathered data from {source_files}")
# for source_file in source_files:
tracks_file = self.path / 'tracks.json'
tracks_file = path / 'tracks.json'
tracks = defaultdict(lambda: [])
for name, line_nrs in lines.items():
dir_path = self.path / name
offset = 0
max_track_id = 0
prev_file = None
# all-2024-11-12T13:30.txt
file_date = None
for name, line_nrs in destinations.items():
dir_path = path / name
dir_path.mkdir(exist_ok=True)
file = dir_path / 'tracked.txt'
logger.debug(f"- Write {line_nrs} lines to {file}")
with file.open('w') as target_fp:
max_track_id = 0
offset = 0
prev_file = None
for i in range(line_nrs):
line = sources.readline()
current_file = sources.current_file
if prev_file != current_file:
offset = max_track_id
offset: int = max_track_id
logger.debug(f'{name} - update offset {offset} ({sources.current_file})')
logger.info(f'{name} - update offset {offset} ({sources.current_file})')
prev_file = current_file
file_date = datetime.strptime(current_file.name, 'all-%Y-%m-%dT%H:%M.txt')
if file_date:
frame_date = file_date + timedelta(seconds = int(parts[0])//10)
else:
frame_date = None
parts = line.split('\t')
track_id = int(parts[1]) + offset
@ -193,13 +249,18 @@ class TrainingDataWriter:
parts[1] = str(track_id)
target_fp.write("\t".join(parts))
tracks[track_id].append(parts)
parts = [float(p) for p in parts]
tracks[track_id].append([
int(parts[0] / 10),
track_id,
] + parts[2:8] + [int(parts[8])])
with tracks_file.open('w') as fp:
logger.info(f"Write {len(tracks)} tracks to {str(tracks_file)}")
json.dump(tracks, fp)
class TrackerWrapper():
def __init__(self, tracker):
self.tracker = tracker
@ -317,6 +378,7 @@ class Tracker:
for detection in detections:
track = self.tracks[detection.track_id]
track.track_id = detection.track_id # for new tracks
track.fps = self.config.camera.fps # for new tracks
track.history.append(detection) # add to history
@ -441,6 +503,7 @@ class Tracker:
if self.config.smooth_tracks:
frame = self.smoother.smooth_frame_tracks(frame)
# print(f"send to {self.trajectory_socket}, {self.config.zmq_trajectory_addr}")
self.trajectory_socket.send_pyobj(frame)
end_time = time.time()
@ -550,10 +613,7 @@ class Smoother:
self.smoother.smooth(points)
return self.smoother.smooth_data[0]
def smooth_frame_tracks(self, frame: Frame) -> Frame:
new_tracks = []
for track in frame.tracks.values():
def smooth_track(self, track: Track) -> Track:
ls = [d.l for d in track.history]
ts = [d.t for d in track.history]
ws = [d.w for d in track.history]
@ -567,7 +627,12 @@ class Smoother:
self.smoother.smooth(hs)
hs = self.smoother.smooth_data[0]
new_history = [Detection(d.track_id, l, t, w, h, d.conf, d.state, d.frame_nr, d.det_class) for l, t, w, h, d in zip(ls,ts,ws,hs, track.history)]
new_track = Track(track.track_id, new_history, track.predictor_history, track.predictions)
new_track = Track(track.track_id, new_history, track.predictor_history, track.predictions, track.fps)
def smooth_frame_tracks(self, frame: Frame) -> Frame:
new_tracks = []
for track in frame.tracks.values():
new_track = self.smooth_track(track)
new_tracks.append(new_track)
frame.tracks = {t.track_id: t for t in new_tracks}
return frame

8
trap/utils.py Normal file
View file

@ -0,0 +1,8 @@
def lerp(a: float, b: float, t: float) -> float:
"""Linear interpolate on the scale given by a to b, using t as the point on that scale.
Examples
--------
50 == lerp(0, 100, 0.5)
4.2 == lerp(1, 5, 0.8)
"""
return (1 - t) * a + t * b