Compare commits
3 commits: 3d8cb7ef70 ... 9284ce8849

| Author | SHA1 | Date |
| --- | --- | --- |
|  | 9284ce8849 |  |
|  | a0c63c4929 |  |
|  | 2e2bd76b05 |  |

10 changed files with 857 additions and 175 deletions
poetry.lock (generated, 22 changes)
@@ -2290,15 +2290,29 @@ files = [

 [[package]]
 name = "pyglet"
-version = "2.0.15"
+version = "2.0.18"
 description = "pyglet is a cross-platform games and multimedia package."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pyglet-2.0.15-py3-none-any.whl", hash = "sha256:9e4cc16efc308106fd3a9ff8f04e7a6f4f6a807c6ac8a331375efbbac8be85af"},
-    {file = "pyglet-2.0.15.tar.gz", hash = "sha256:42085567cece0c7f1c14e36eef799938cbf528cfbb0150c484b984f3ff1aa771"},
+    {file = "pyglet-2.0.18-py3-none-any.whl", hash = "sha256:e592952ae0297e456c587b6486ed8c3e5f9d0c3519d517bb92dde5fdf4c26b41"},
+    {file = "pyglet-2.0.18.tar.gz", hash = "sha256:7cf9238d70082a2da282759679f8a011cc979753a32224a8ead8ed80e48f99dc"},
 ]
+
+[[package]]
+name = "pyglet-cornerpin"
+version = "0.2.0"
+description = "Add a corner pin transform to a pyglet window"
+optional = false
+python-versions = "<4.0,>=3.10"
+files = [
+    {file = "pyglet_cornerpin-0.2.0-py3-none-any.whl", hash = "sha256:1e1cf4f2e86929fb74e89939be8f7ebdb110f65bf0923e51466e8fbd44773dc5"},
+    {file = "pyglet_cornerpin-0.2.0.tar.gz", hash = "sha256:8fe8a7618c11f93ac3b3c8b89b71e4398bf1223eea9ac3ea744e9d36031a44f9"},
+]
+
+[package.dependencies]
+pyglet = ">=2.0.18,<3.0.0"

 [[package]]
 name = "pygments"
 version = "2.17.2"

@@ -3528,4 +3542,4 @@ watchdog = ["watchdog (>=2.3)"]

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10,<3.12"
-content-hash = "5154a99d490755a68e51595424649b5269fcd17ef14094c6285f5de7f972f110"
+content-hash = "bffa0878a620996b47aa5623b951f09ab010c267880c6dcd5a53741f244e675a"
pyproject.toml

@@ -7,6 +7,7 @@ readme = "README.md"

 [tool.poetry.scripts]
 trapserv = "trap.plumber:start"
+tracker = "trap.tools:tracker_preprocess"


 [tool.poetry.dependencies]

@@ -32,6 +33,7 @@ gdown = "^4.7.1"

 pandas-helper-calc = {git = "https://github.com/scls19fr/pandas-helper-calc"}
 tsmoothie = "^1.0.5"
 pyglet = "^2.0.15"
+pyglet-cornerpin = "^0.2.0"

 [build-system]
 requires = ["poetry-core"]
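The new pyglet-cornerpin dependency is what trap/animation_renderer.py below uses for projection mapping. As a rough sketch of the integration (only PygletCornerPin(window), window.push_handlers(pins) and pins.draw() are confirmed by this diff; the rest is standard pyglet boilerplate):

    import pyglet
    from pyglet_cornerpin import PygletCornerPin

    window = pyglet.window.Window(fullscreen=True)
    pins = PygletCornerPin(window)   # draggable corner-pin transform for the window
    window.push_handlers(pins)       # let the pin handles receive mouse events

    @window.event
    def on_draw():
        window.clear()
        # ... draw the scene here ...
        pins.draw()                  # draw the corner handles on top

    pyglet.app.run()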
trap/animation_renderer.py (new file, 451 lines)

@@ -0,0 +1,451 @@
# used for "Forward Referencing of type annotations"
from __future__ import annotations

import time
import ffmpeg
from argparse import Namespace
import datetime
import logging
from multiprocessing import Event
from multiprocessing.synchronize import Event as BaseEvent
import cv2
import numpy as np

import pyglet
import pyglet.event
import zmq
import tempfile
from pathlib import Path
import shutil
import math

from pyglet import shapes

from PIL import Image

from trap.frame_emitter import DetectionState, Frame, Track
from trap.preview_renderer import DrawnTrack, PROJECTION_IMG, PROJECTION_MAP


logger = logging.getLogger("trap.renderer")


class AnimationRenderer:
    def __init__(self, config: Namespace, is_running: BaseEvent):
        self.config = config
        self.is_running = is_running

        context = zmq.Context()
        self.prediction_sock = context.socket(zmq.SUB)
        self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
        self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.prediction_sock.connect(config.zmq_prediction_addr)

        self.tracker_sock = context.socket(zmq.SUB)
        self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
        self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.tracker_sock.connect(config.zmq_trajectory_addr)

        self.frame_sock = context.socket(zmq.SUB)
        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.frame_sock.connect(config.zmq_frame_addr)

        self.H = self.config.H

        self.inv_H = np.linalg.pinv(self.H)

        # TODO: get FPS from frame_emitter
        # self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
        self.fps = 60
        self.frame_size = (self.config.frame_width, self.config.frame_height)
        self.hide_stats = False
        self.out_writer = None # self.start_writer() if self.config.render_file else None
        self.streaming_process = None # self.start_streaming() if self.config.render_url else None

        if self.config.render_window:
            pass
            # cv2.namedWindow("frame", cv2.WND_PROP_FULLSCREEN)
            # cv2.setWindowProperty("frame", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
        else:
            pyglet.options["headless"] = True

        config = pyglet.gl.Config(sample_buffers=1, samples=4)
        # , fullscreen=self.config.render_window

        display = pyglet.canvas.get_display()
        screen = display.get_screens()[0]

        # self.window = pyglet.window.Window(width=self.frame_size[0], height=self.frame_size[1], config=config, fullscreen=False, screen=screens[1])
        self.window = pyglet.window.Window(width=screen.width, height=screen.height, config=config, fullscreen=True, screen=screen)
        self.window.set_handler('on_draw', self.on_draw)
        self.window.set_handler('on_refresh', self.on_refresh)
        self.window.set_handler('on_close', self.on_close)

        # don't know why, but importing this before window leads to "x connection to :1 broken (explicit kill or server shutdown)"
        from pyglet_cornerpin import PygletCornerPin

        self.pins = PygletCornerPin(self.window)
        self.window.push_handlers(self.pins)

        pyglet.gl.glClearColor(0, 0, 0, 0)
        self.fps_display = pyglet.window.FPSDisplay(window=self.window, color=(255,255,255,255))
        self.fps_display.label.x = self.window.width - 50
        self.fps_display.label.y = self.window.height - 17
        self.fps_display.label.bold = False
        self.fps_display.label.font_size = 10

        self.drawn_tracks: dict[str, DrawnTrack] = {}

        self.first_time: float|None = None
        self.frame: Frame|None = None
        self.tracker_frame: Frame|None = None
        self.prediction_frame: Frame|None = None

        self.batch_bg = pyglet.graphics.Batch()
        self.batch_overlay = pyglet.graphics.Batch()
        self.batch_anim = pyglet.graphics.Batch()

        self.debug_lines = [
            pyglet.shapes.Line(1380, self.config.camera.h, 1380, 670, 2, (255,255,255,255), batch=self.batch_overlay),
            pyglet.shapes.Line(0, 660, 1380, 670, 2, (255,255,255,255), batch=self.batch_overlay),
            pyglet.shapes.Line(1140, 760, 1140, 675, 2, (255,255,255,255), batch=self.batch_overlay),
            pyglet.shapes.Line(0, 750, 1380, 760, 2, (255,255,255,255), batch=self.batch_overlay),
        ]

        self.init_shapes()

        self.init_labels()

    def init_shapes(self):
        '''
        Due to an error when running headless, we need to configure options before extending the shapes class
        '''
        class GradientLine(shapes.Line):
            def __init__(self, x, y, x2, y2, width=1, color1=[255,255,255], color2=[255,255,255], batch=None, group=None):
                # print('colors!', colors)
                # assert len(colors) == 6

                r, g, b, *a = color1
                self._rgba1 = (r, g, b, a[0] if a else 255)
                r, g, b, *a = color2
                self._rgba2 = (r, g, b, a[0] if a else 255)

                # print('rgba', self._rgba)

                super().__init__(x, y, x2, y2, width, color1, batch=None, group=None)
                # <pyglet.graphics.vertexdomain.VertexList
                # pyglet.graphics.vertexdomain
                # print(self._vertex_list)

            def _create_vertex_list(self):
                '''
                copy of super()._create_vertex_list but with additional colors'''
                self._vertex_list = self._group.program.vertex_list(
                    6, self._draw_mode, self._batch, self._group,
                    position=('f', self._get_vertices()),
                    colors=('Bn', self._rgba1 + self._rgba2 + self._rgba2 + self._rgba1 + self._rgba2 + self._rgba1),
                    translation=('f', (self._x, self._y) * self._num_verts))

            def _update_colors(self):
                self._vertex_list.colors[:] = self._rgba1 + self._rgba2 + self._rgba2 + self._rgba1 + self._rgba2 + self._rgba1

            def color1(self, color):
                r, g, b, *a = color
                self._rgba1 = (r, g, b, a[0] if a else 255)
                self._update_colors()

            def color2(self, color):
                r, g, b, *a = color
                self._rgba2 = (r, g, b, a[0] if a else 255)
                self._update_colors()

        self.gradientLine = GradientLine

    def init_labels(self):
        base_color = (255,)*4
        color_predictor = (255,255,0, 255)
        color_info = (255,0, 255, 255)
        color_tracker = (0,255, 255, 255)

        options = []
        for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
            options.append(f"{option}: {self.config.__dict__[option]}")

        self.labels = {
            'waiting': pyglet.text.Label("Waiting for prediction"),
            'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
            'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
            'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
            'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
            'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay),
            'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
            'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
            'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
            'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
            'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
        }

    def refresh_labels(self, dt: float):
        """Every frame"""

        if self.frame:
            self.labels['frame_idx'].text = f"{self.frame.index:06d}"
            self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s"
            self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"

        if self.tracker_frame:
            self.labels['tracker_idx'].text = f"{self.tracker_frame.index - self.frame.index}"
            self.labels['tracker_time'].text = f"{self.tracker_frame.time - time.time():.3f}s"
            self.labels['track_len'].text = f"{len(self.tracker_frame.tracks)} tracks"

        if self.prediction_frame:
            self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
            self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
            # self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"

        # cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
        # cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)

        # if prediction_frame:
        #     # render Δt and Δ frames
        #     cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
        #     cv2.putText(img, f"{prediction_frame.time - time.time():.2f}s", (200,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
        #     cv2.putText(img, f"{len(prediction_frame.tracks)} tracks", (500,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
        #     cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()]):.2f}", (580,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
        #     cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()]):.2f}", (660,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
        #     cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()]):.2f}", (740,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)

        # options = []
        # for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
        #     options.append(f"{option}: {config.__dict__[option]}")

        # cv2.putText(img, options.pop(-1), (20,img.shape[0]-30), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
        # cv2.putText(img, " | ".join(options), (20,img.shape[0]-10), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)

    def check_frames(self, dt):
        new_tracks = False
        try:
            self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
            if not self.first_time:
                self.first_time = self.frame.time
            img = self.frame.img
            # newcameramtx, roi = cv2.getOptimalNewCameraMatrix(self.config.camera.mtx, self.config.camera.dist, (self.frame.img.shape[1], self.frame.img.shape[0]), 1, (self.frame.img.shape[1], self.frame.img.shape[0]))
            img = cv2.undistort(img, self.config.camera.mtx, self.config.camera.dist, None, self.config.camera.newcameramtx)
            img = cv2.warpPerspective(img, self.H, (self.frame.img.shape[1], self.frame.img.shape[0]))
            img = cv2.GaussianBlur(img, (15, 15), 0)
            img = cv2.flip(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), 0)
            img = pyglet.image.ImageData(self.frame_size[0], self.frame_size[1], 'RGB', img.tobytes())
            # don't draw in batch, so that it is the background
            self.video_sprite = pyglet.sprite.Sprite(img=img, batch=self.batch_bg)
            self.video_sprite.opacity = 30
        except zmq.ZMQError as e:
            # idx = frame.index if frame else "NONE"
            # logger.debug(f"reuse video frame {idx}")
            pass
        try:
            self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
            new_tracks = True
        except zmq.ZMQError as e:
            pass
        try:
            self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
            new_tracks = True
        except zmq.ZMQError as e:
            pass

        if new_tracks:
            self.update_tracks()

    def update_tracks(self):
        """Updates the track objects and shapes. Called after setting `prediction_frame`
        """

        # clean up
        # for track_id in list(self.drawn_tracks.keys()):
        #     if track_id not in self.prediction_frame.tracks.keys():
        #         # TODO fade out
        #         del self.drawn_tracks[track_id]

        if self.prediction_frame:
            for track_id, track in self.prediction_frame.tracks.items():
                if track_id not in self.drawn_tracks:
                    self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H, PROJECTION_MAP, self.config.camera)
                else:
                    self.drawn_tracks[track_id].set_track(track)

        # clean up
        for track_id in list(self.drawn_tracks.keys()):
            # TODO make delay configurable
            if self.drawn_tracks[track_id].update_at < time.time() - 5:
                # TODO fade out
                del self.drawn_tracks[track_id]

    def on_key_press(self, symbol, modifiers):
        print('A key was pressed, use f to hide')
        if symbol == ord('f'):
            self.window.set_fullscreen(not self.window.fullscreen)
        if symbol == ord('h'):
            self.hide_stats = not self.hide_stats

    def check_running(self, dt):
        if not self.is_running.is_set():
            self.window.close()
            self.event_loop.exit()

    def on_close(self):
        self.is_running.clear()

    def on_refresh(self, dt: float):
        # update shapes
        # self.bg =
        for track_id, track in self.drawn_tracks.items():
            track.update_drawn_positions(dt)

        self.refresh_labels(dt)

        # self.shape1 = shapes.Circle(700, 150, 100, color=(50, 0, 30), batch=self.batch_anim)
        # self.shape3 = shapes.Circle(800, 150, 100, color=(100, 225, 30), batch=self.batch_anim)
        pass

    def on_draw(self):
        self.window.clear()

        self.batch_bg.draw()

        for track in self.drawn_tracks.values():
            for shape in track.shapes:
                shape.draw() # for some reason the batches don't work
        for track in self.drawn_tracks.values():
            for shapes in track.pred_shapes:
                for shape in shapes:
                    shape.draw()
        # self.batch_anim.draw()
        self.batch_overlay.draw()
        self.pins.draw()

        # pyglet.graphics.draw(3, pyglet.gl.GL_LINE, ("v2i", (100,200, 600,800)), ('c3B', (255,255,255, 255,255,255)))

        if not self.hide_stats:
            self.fps_display.draw()

        # if streaming, capture buffer and send
        try:
            if self.streaming_process or self.out_writer:
                buf = pyglet.image.get_buffer_manager().get_color_buffer()
                img_data = buf.get_image_data()
                data = img_data.get_data() # alternative: .get_data("RGBA", image_data.pitch)
                img = np.asanyarray(data).reshape((img_data.height, img_data.width, 4))
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
                img = np.flip(img, 0)
                # img = cv2.flip(img, cv2.0)

                # cv2.imshow('frame', img)
                # cv2.waitKey(1)
                if self.streaming_process:
                    self.streaming_process.stdin.write(img.tobytes())
                if self.out_writer:
                    self.out_writer.write(img)
        except Exception as e:
            logger.exception(e)

    def run(self):
        frame = None
        prediction_frame = None
        tracker_frame = None

        i = 0
        first_time = None

        self.event_loop = pyglet.app.EventLoop()
        pyglet.clock.schedule_interval(self.check_running, 0.1)
        pyglet.clock.schedule(self.check_frames)
        self.event_loop.run()

        # while self.is_running.is_set():
        #     i += 1

        #     # zmq_ev = self.frame_sock.poll(timeout=2000)
        #     # if not zmq_ev:
        #     #     # when no data comes in, loop so that is_running is checked
        #     #     continue

        #     try:
        #         frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
        #     except zmq.ZMQError as e:
        #         # idx = frame.index if frame else "NONE"
        #         # logger.debug(f"reuse video frame {idx}")
        #         pass
        #     # else:
        #     #     logger.debug(f'new video frame {frame.index}')

        #     if frame is None:
        #         # might need to wait a few iterations before first frame comes available
        #         time.sleep(.1)
        #         continue

        #     try:
        #         prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
        #     except zmq.ZMQError as e:
        #         logger.debug(f'reuse prediction')

        #     if first_time is None:
        #         first_time = frame.time

        #     img = decorate_frame(frame, prediction_frame, first_time, self.config)

        #     img_path = (self.config.output_dir / f"{i:05d}.png").resolve()

        #     logger.debug(f"write frame {frame.time - first_time:.3f}s")
        #     if self.out_writer:
        #         self.out_writer.write(img)
        #     if self.streaming_process:
        #         self.streaming_process.stdin.write(img.tobytes())
        #     if self.config.render_window:
        #         cv2.imshow('frame', img)
        #         cv2.waitKey(1)
        logger.info('Stopping')
        logger.info(f'used corner pins {self.pins.corners}')

        # if i>2:
        if self.streaming_process:
            self.streaming_process.stdin.close()
        if self.out_writer:
            self.out_writer.release()
        if self.streaming_process:
            # oddly wrapped, because both close and release() take time.
            self.streaming_process.wait()


# colorset = itertools.product([0,255], repeat=3) # but remove white
colorset = [(0, 0, 0),
            (0, 0, 255),
            (0, 255, 0),
            (0, 255, 255),
            (255, 0, 0),
            (255, 0, 255),
            (255, 255, 0)
            ]


def run_animation_renderer(config: Namespace, is_running: BaseEvent):
    renderer = AnimationRenderer(config, is_running)
    renderer.run()
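trap/plumber.py (further below) starts this renderer as a subprocess next to the preview renderer. A minimal sketch of the same wiring using only the stdlib (trap uses its own ExceptionHandlingProcess wrapper; a plain Process stands in for it here):

    from argparse import Namespace
    from multiprocessing import Event, Process

    from trap.animation_renderer import run_animation_renderer

    config = Namespace()   # in trap this comes from trap.config's argparse parser
    is_running = Event()
    is_running.set()

    proc = Process(target=run_animation_renderer,
                   kwargs={'config': config, 'is_running': is_running},
                   name='map_renderer')
    proc.start()
    # clearing the event makes check_running() close the window and exit the loop
    is_running.clear()
    proc.join()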
trap/config.py

@@ -1,8 +1,11 @@
 import argparse
 from pathlib import Path
 import types
+import numpy as np
+import json

 from trap.tracker import DETECTORS
+from trap.frame_emitter import Camera

 from pyparsing import Optional


@@ -49,6 +52,43 @@ frame_emitter_parser = parser.add_argument_group('Frame emitter')
 tracker_parser = parser.add_argument_group('Tracker')
 render_parser = parser.add_argument_group('Renderer')

+class HomographyAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        if nargs is not None:
+            raise ValueError("nargs not allowed")
+        super().__init__(option_strings, dest, **kwargs)
+    def __call__(self, parser, namespace, values: Path, option_string=None):
+        if values.suffix == '.json':
+            with values.open('r') as fp:
+                H = np.array(json.load(fp))
+        else:
+            H = np.loadtxt(values, delimiter=',')
+
+        setattr(namespace, self.dest, values)
+        setattr(namespace, 'H', H)
+
+class CameraAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        if nargs is not None:
+            raise ValueError("nargs not allowed")
+        super().__init__(option_strings, dest, **kwargs)
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values is None:
+            setattr(namespace, self.dest, None)
+        else:
+            values = Path(values)
+            with values.open('r') as fp:
+                data = json.load(fp)
+                # print(data)
+                # print(data['camera_matrix'])
+                # camera = {
+                #     'camera_matrix': np.array(data['camera_matrix']),
+                #     'dist_coeff': np.array(data['dist_coeff']),
+                # }
+                camera = Camera(np.array(data['camera_matrix']), np.array(data['dist_coeff']), namespace.frame_width, namespace.frame_height)
+
+            setattr(namespace, 'camera', camera)
+
 inference_parser.add_argument("--model_dir",
                     help="directory with the model to use for inference",
                     type=str, # TODO: make into Path

@@ -234,7 +274,13 @@ frame_emitter_parser.add_argument("--video-loop",
 tracker_parser.add_argument("--homography",
                     help="File with homography params",
                     type=Path,
-                    default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
+                    default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt',
+                    action=HomographyAction)
+tracker_parser.add_argument("--calibration",
+                    help="File with camera intrinsics and lens distortion params (calibration.json)",
+                    # type=Path,
+                    default=None,
+                    action=CameraAction)
 tracker_parser.add_argument("--save-for-training",
                     help="Specify the path in which to save",
                     type=Path,

@@ -246,6 +292,15 @@ tracker_parser.add_argument("--detector",
 tracker_parser.add_argument("--smooth-tracks",
                     help="Smooth the tracker tracks before sending them to the predictor",
                     action='store_true')
+tracker_parser.add_argument("--frame-width",
+                    help="width of the frames",
+                    type=int,
+                    default=1280)
+tracker_parser.add_argument("--frame-height",
+                    help="height of the frames",
+                    type=int,
+                    default=720)
+

 # Renderer
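Both new actions piggyback extra attributes onto the parsed namespace: --homography keeps the file path under its own dest and stores the loaded matrix as args.H, while --calibration builds a Camera and stores it as args.camera. One caveat: argparse does not invoke an action for a default value, so args.H only exists when --homography is actually passed on the command line, which is why trap/tools.py below guards with hasattr(config, "H"). A minimal sketch of the pattern with a hypothetical --matrix flag (H.txt must exist for the load to succeed):

    import argparse
    import numpy as np

    class LoadMatrixAction(argparse.Action):
        """Store the file path AND the parsed matrix on the namespace."""
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, values)                        # the raw path
            setattr(namespace, 'M', np.loadtxt(values, delimiter=','))   # the side effect

    parser = argparse.ArgumentParser()
    parser.add_argument("--matrix", action=LoadMatrixAction)
    args = parser.parse_args(["--matrix", "H.txt"])
    # args.matrix -> 'H.txt'; args.M -> the loaded ndarray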
trap/frame_emitter.py

@@ -32,6 +32,14 @@ class DetectionState(IntFlag):
             return cls.Confirmed
         raise RuntimeError("Should not run into Deleted entries here")

+class Camera:
+    def __init__(self, mtx, dist, w, h):
+        self.mtx = mtx
+        self.dist = dist
+        self.w = w
+        self.h = h
+        self.newcameramtx, self.roi = cv2.getOptimalNewCameraMatrix(mtx, dist, (w,h), 1, (w,h))
+

 @dataclass
 class Detection:

@@ -83,22 +91,28 @@ class Track:
     predictor_history: Optional[list] = None # in image space
     predictions: Optional[list] = None

-    def get_projected_history(self, H) -> np.array:
+    def get_projected_history(self, H, camera: Optional[Camera] = None) -> np.array:
         foot_coordinates = [d.get_foot_coords() for d in self.history]

+        # TODO)) Undistort points before perspective transform
         if len(foot_coordinates):
-            coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
-            return coords[0]
+            if camera:
+                coords = cv2.undistortPoints(np.array([foot_coordinates]).astype('float32'), camera.mtx, camera.dist, None, camera.newcameramtx)
+                coords = cv2.perspectiveTransform(np.array(coords),H)
+                return coords.reshape((coords.shape[0],2))
+            else:
+                coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
+                return coords[0]
         return np.array([])

-    def get_projected_history_as_dict(self, H) -> dict:
-        coords = self.get_projected_history(H)
+    def get_projected_history_as_dict(self, H, camera: Optional[Camera] = None) -> dict:
+        coords = self.get_projected_history(H, camera)
         return [{"x":c[0], "y":c[1]} for c in coords]


 @dataclass
 class Frame:
     index: int

@@ -106,6 +120,7 @@ class Frame:
     time: float = field(default_factory=lambda: time.time())
     tracks: Optional[dict[str, Track]] = None
     H: Optional[np.array] = None
+    camera: Optional[Camera] = None

     def aslist(self) -> [dict]:
         return { t.track_id:

@@ -120,6 +135,13 @@ class Frame:
             } for t in self.tracks.values()
         }

+def video_src_from_config(config):
+    if config.video_loop:
+        video_srcs: Iterable[Path] = cycle(config.video_src)
+    else:
+        video_srcs: Iterable[Path] = config.video_src
+    return video_srcs
+

 class FrameEmitter:
     '''
     Emit frames in a separate thread so they can be throttled,

@@ -137,10 +159,7 @@ class FrameEmitter:

         logger.info(f"Connection socket {config.zmq_frame_addr}")

-        if self.config.video_loop:
-            self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
-        else:
-            self.video_srcs: [Path] = self.config.video_src
+        self.video_srcs = video_src_from_config(self.config)


     def emit_video(self):

@@ -151,8 +170,8 @@ class FrameEmitter:
             # numeric input is a CV camera
             video = cv2.VideoCapture(int(str(video_path)))
             # TODO: make config variables
-            video.set(cv2.CAP_PROP_FRAME_WIDTH, int(1280))
-            video.set(cv2.CAP_PROP_FRAME_HEIGHT, int(720))
+            video.set(cv2.CAP_PROP_FRAME_WIDTH, int(self.config.frame_width))
+            video.set(cv2.CAP_PROP_FRAME_HEIGHT, int(self.config.frame_height))
             print("exposure!", video.get(cv2.CAP_PROP_AUTO_EXPOSURE))
             video.set(cv2.CAP_PROP_FPS, 5)
         else:

@@ -198,7 +217,7 @@ class FrameEmitter:
             # hack to mask out area
             cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)

-            frame = Frame(index=i, img=img, H=video_H)
+            frame = Frame(index=i, img=img, H=self.config.H, camera=self.config.camera)
             # TODO: this is very dirty, need to find another way.
             # perhaps multiprocessing Array?
            self.frame_sock.send(pickle.dumps(frame))
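The projection pipeline introduced here is two steps: first undo lens distortion, then map image points to world coordinates through the homography. A minimal, self-contained sketch of the same mapping (the matrices are placeholders, not values from the repo; the P= keyword makes the target projection matrix explicit):

    import cv2
    import numpy as np

    # placeholder intrinsics, distortion, and homography
    mtx = np.array([[1000., 0., 640.], [0., 1000., 360.], [0., 0., 1.]])
    dist = np.zeros(5)
    H = np.eye(3)  # img -> world
    newcameramtx, _ = cv2.getOptimalNewCameraMatrix(mtx, dist, (1280, 720), 1, (1280, 720))

    # pixel positions of feet, shaped (N, 1, 2) as OpenCV expects
    pts = np.array([[650., 400.], [660., 410.]], dtype=np.float32).reshape(-1, 1, 2)

    # 1) undo lens distortion, re-projecting through the optimal new camera matrix
    undistorted = cv2.undistortPoints(pts, mtx, dist, P=newcameramtx)
    # 2) map the undistorted pixels to world coordinates via the homography
    world = cv2.perspectiveTransform(undistorted, H)
    print(world.reshape(-1, 2))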
trap/plumber.py

@@ -9,7 +9,8 @@ import time
 from trap.config import parser
 from trap.frame_emitter import run_frame_emitter
 from trap.prediction_server import run_prediction_server
-from trap.renderer import run_renderer
+from trap.preview_renderer import run_preview_renderer
+from trap.animation_renderer import run_animation_renderer
 from trap.socket_forwarder import run_ws_forwarder
 from trap.tracker import run_tracker


@@ -75,7 +76,10 @@ def start():

     if args.render_file or args.render_url or args.render_window:
         procs.append(
-            ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
+            ExceptionHandlingProcess(target=run_preview_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
         )
+        procs.append(
+            ExceptionHandlingProcess(target=run_animation_renderer, kwargs={'config': args, 'is_running': isRunning}, name='map_renderer')
+        )

     if not args.bypass_prediction:
trap/prediction_server.py

@@ -269,7 +269,7 @@ class PredictionServer:

                # TODO: modify this into a mapping function between JS data and the expected Node format
                # node = FakeNode(online_env.NodeType.PEDESTRIAN)
-               history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H)]
+               history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H, self.config.camera)]
                history = np.array(history)
                x = history[:, 0]
                y = history[:, 1]
trap/preview_renderer.py

@@ -10,7 +10,7 @@ from multiprocessing import Event
 from multiprocessing.synchronize import Event as BaseEvent
 import cv2
 import numpy as np
-
+import json
 import pyglet
 import pyglet.event
 import zmq

@@ -18,15 +18,17 @@ import tempfile
 from pathlib import Path
 import shutil
 import math
+from typing import Optional


 from pyglet import shapes
 from PIL import Image

-from trap.frame_emitter import DetectionState, Frame, Track
+from trap.frame_emitter import DetectionState, Frame, Track, Camera



-logger = logging.getLogger("trap.renderer")
+logger = logging.getLogger("trap.preview")

 class FrameAnimation:
     def __init__(self, frame: Frame):

@@ -55,32 +57,42 @@ def relativePointToPolar(origin, point) -> tuple[float, float]:
 def relativePolarToPoint(origin, r, angle) -> tuple[float, float]:
     return r * np.cos(angle) + origin[0], r * np.sin(angle) + origin[1]

+PROJECTION_IMG = 0
+PROJECTION_UNDISTORT = 1
+PROJECTION_MAP = 2
+PROJECTION_PROJECTOR = 4
+
 class DrawnTrack:
-    def __init__(self, track_id, track: Track, renderer: Renderer, H):
+    def __init__(self, track_id, track: Track, renderer: PreviewRenderer, H, draw_projection = PROJECTION_IMG, camera: Optional[Camera] = None):
         # self.created_at = time.time()
+        self.draw_projection = draw_projection
         self.update_at = self.created_at = time.time()
         self.track_id = track_id
         self.renderer = renderer
+        self.camera = camera
+        self.H = H # TODO)) Move H to Camera object
         self.set_track(track, H)
         self.drawn_positions = []
         self.drawn_predictions = []
         self.shapes: list[pyglet.shapes.Line] = []
         self.pred_shapes: list[list[pyglet.shapes.Line]] = []

-    def set_track(self, track: Track, H):
+    def set_track(self, track: Track, H = None):
         self.update_at = time.time()

         self.track = track
-        self.H = H
-        self.coords = [d.get_foot_coords() for d in track.history]
+        # self.H = H
+        self.coords = [d.get_foot_coords() for d in track.history] if self.draw_projection == PROJECTION_IMG else track.get_projected_history(self.H, self.camera)

         # perhaps only do in constructor:
         self.inv_H = np.linalg.pinv(self.H)

         pred_coords = []
-        for pred_i, pred in enumerate(track.predictions):
-            pred_coords.append(cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0].tolist())
+        if self.draw_projection == PROJECTION_IMG:
+            for pred_i, pred in enumerate(track.predictions):
+                pred_coords.append(cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0].tolist())
+        elif self.draw_projection == PROJECTION_MAP:
+            pred_coords = [pred for pred in track.predictions]

         self.pred_coords = pred_coords
         # color = (128,0,128) if pred_i else (128,

@@ -232,7 +244,7 @@ class FrameWriter:



-class Renderer:
+class PreviewRenderer:
     def __init__(self, config: Namespace, is_running: BaseEvent):
         self.config = config
         self.is_running = is_running

@@ -241,7 +253,8 @@ class PreviewRenderer:
         self.prediction_sock = context.socket(zmq.SUB)
         self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
         self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
-        self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
+        # self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
+        self.prediction_sock.connect(config.zmq_prediction_addr)

         self.tracker_sock = context.socket(zmq.SUB)
         self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!

@@ -253,14 +266,23 @@ class PreviewRenderer:
         self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
         self.frame_sock.connect(config.zmq_frame_addr)

-        self.H = np.loadtxt(self.config.homography, delimiter=',')
+        # TODO)) Move loading H to config.py
+        # if self.config.homography.suffix == '.json':
+        #     with self.config.homography.open('r') as fp:
+        #         self.H = np.array(json.load(fp))
+        # else:
+        #     self.H = np.loadtxt(self.config.homography, delimiter=',')
+        # print('h', self.config.H)
+        self.H = self.config.H
+

         self.inv_H = np.linalg.pinv(self.H)

         # TODO: get FPS from frame_emitter
         # self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
         self.fps = 60
-        self.frame_size = (1280,720)
+        self.frame_size = (self.config.frame_width,self.config.frame_height)
         self.hide_stats = False
         self.out_writer = self.start_writer() if self.config.render_file else None
         self.streaming_process = self.start_streaming() if self.config.render_url else None

@@ -772,6 +794,6 @@ def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, con
     return img


-def run_renderer(config: Namespace, is_running: BaseEvent):
-    renderer = Renderer(config, is_running)
+def run_preview_renderer(config: Namespace, is_running: BaseEvent):
+    renderer = PreviewRenderer(config, is_running)
     renderer.run()
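DrawnTrack can now draw in more than one coordinate space, selected by the PROJECTION_* constants. PROJECTION_IMG keeps history and predictions in image space as before; PROJECTION_MAP takes the predictions as-is in world space and projects the history through the homography. Both call styles, as implied by the new signature (track, renderer, H and camera as elsewhere in this diff):

    # image-space drawing, the old behaviour (draw_projection defaults to PROJECTION_IMG)
    drawn = DrawnTrack(track_id, track, renderer, H)

    # map-space drawing, as AnimationRenderer above uses it
    drawn = DrawnTrack(track_id, track, renderer, H, PROJECTION_MAP, camera)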
trap/tools.py (new file, 72 lines)

@@ -0,0 +1,72 @@
from trap.config import parser
from trap.frame_emitter import video_src_from_config, Frame
from trap.tracker import DETECTOR_YOLOv8, _yolov8_track, Track, TrainingDataWriter
from collections import defaultdict

import logging
import cv2
from typing import List, Iterable

from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult
import tqdm

config = parser.parse_args()

logger = logging.getLogger('tools')

def tracker_preprocess():
    video_srcs = video_src_from_config(config)
    if not hasattr(config, "H"):
        print("Set homography file with --homography param")
        return

    if config.detector != DETECTOR_YOLOv8:
        print("Only YOLO for now...")
        return

    model = YOLO('EXPERIMENTS/yolov8x.pt')

    with TrainingDataWriter(config.save_for_training) as writer:
        for video_nr, video_path in enumerate(video_srcs):
            logger.info(f"Play from '{str(video_path)}'")
            video = cv2.VideoCapture(str(video_path))
            fps = video.get(cv2.CAP_PROP_FPS)
            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
            i = 0
            if config.video_offset:
                logger.info(f"Start at frame {config.video_offset}")
                video.set(cv2.CAP_PROP_POS_FRAMES, config.video_offset)
                i = config.video_offset

            bar = tqdm.tqdm()
            tracks = defaultdict(lambda: Track())

            while True:
                bar.update()
                ret, img = video.read()
                i += 1

                # seek to 0 if video has finished. Infinite loop
                if not ret:
                    # now loading multiple files
                    break

                frame = Frame(index=bar.n, img=img, H=config.H, camera=config.camera)

                detections = _yolov8_track(frame, model, classes=[0])
                # detections = _yolov8_track(frame, model, imgsz=1440, classes=[0])

                bar.set_description(f"[{video_nr}/{len(video_srcs)}] [{i}/{frame_count}] {str(video_path)} -- Detections {len(detections)}: {[d.conf for d in detections]}")

                for detection in detections:
                    track = tracks[detection.track_id]
                    track.track_id = detection.track_id # for new tracks
                    track.history.append(detection) # add to history

                active_track_ids = [d.track_id for d in detections]
                active_tracks = {t.track_id: t for t in tracks.values() if t.track_id in active_track_ids}

                writer.add(frame, active_tracks.values())

    logger.info("Done!")
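Together with the tracker entry added to [tool.poetry.scripts] above, this preprocessing pass becomes runnable as poetry run tracker. The console script resolves to the same plain call:

    from trap.tools import tracker_preprocess
    tracker_preprocess()   # identical to what the `tracker` script entry invokes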
trap/tracker.py (291 changes)
@@ -8,7 +8,7 @@ from multiprocessing import Event
 from pathlib import Path
 import pickle
 import time
-from typing import Optional
+from typing import Optional, List
 import numpy as np
 import torch
 import zmq

@@ -47,6 +47,87 @@ DETECTOR_YOLOv8 = 'ultralytics'
 DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]

+
+def _yolov8_track(frame: Frame, model: YOLO, **kwargs) -> List[Detection]:
+
+    results: List[YOLOResult] = list(model.track(frame.img, persist=True, tracker="bytetrack.yaml", verbose=False, **kwargs))
+    if results[0].boxes is None or results[0].boxes.id is None:
+        # work around https://github.com/ultralytics/ultralytics/issues/5968
+        return []
+
+    return [Detection(track_id, bbox[0]-.5*bbox[2], bbox[1]-.5*bbox[3], bbox[2], bbox[3], 1, DetectionState.Confirmed, frame.index) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
+
+
+class TrainingDataWriter:
+    def __init__(self, training_path: Optional[Path] = None):
+        if training_path is None:
+            self.path = None
+            return
+
+        if not isinstance(training_path, Path):
+            raise ValueError("save-for-training should be a path")
+        if not training_path.exists():
+            logger.info(f"Making path for training data: {training_path}")
+            training_path.mkdir(parents=True, exist_ok=False)
+        else:
+            logger.warning(f"Path for training-data exists: {training_path}. Continuing assuming that's ok.")
+
+        # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
+        self.path = training_path
+
+    def __enter__(self):
+        if self.path:
+            self.training_fp = open(self.path / 'all.txt', 'w')
+            # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
+            self.csv = csv.DictWriter(self.training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)
+            self.count = 0
+        return self
+
+    def add(self, frame: Frame, tracks: List[Track]):
+        if not self.path:
+            # skip if disabled
+            return
+
+        self.csv.writerows([{
+            'frame_id': round(frame.index * 10., 1), # not really time
+            'track_id': t.track_id,
+            'l': float(t.history[-1].l), # to float, so we're sure it's not a torch.tensor()
+            't': float(t.history[-1].t),
+            'w': float(t.history[-1].w),
+            'h': float(t.history[-1].h),
+            'x': t.get_projected_history(frame.H, frame.camera)[-1][0],
+            'y': t.get_projected_history(frame.H, frame.camera)[-1][1],
+            'state': t.history[-1].state.value
+            # only keep _actual_detections, no lost entries
+        } for t in tracks
+            # if t.history[-1].state != DetectionState.Lost
+        ])
+        self.count += len(tracks)
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        # ... ignore exception (type, value, traceback)
+        if not self.path:
+            return
+
+        self.training_fp.close()
+        lines = {
+            'train': int(self.count * .8),
+            'val': int(self.count * .12),
+            'test': int(self.count * .08),
+        }
+        logger.info(f"Splitting gathered data from {self.training_fp.name}")
+        with open(self.training_fp.name, 'r') as source_fp:
+            for name, line_nrs in lines.items():
+                dir_path = self.path / name
+                dir_path.mkdir(exist_ok=True)
+                file = dir_path / 'tracked.txt'
+                logger.debug(f"- Write {line_nrs} lines to {file}")
+                with file.open('w') as target_fp:
+                    for i in range(line_nrs):
+                        target_fp.write(source_fp.readline())
+
+
 class Tracker:
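TrainingDataWriter.__exit__ above splits the collected lines sequentially into train/val/test at roughly 80/12/8 percent. A quick check of that arithmetic (note the int() truncation can leave a few trailing lines out of all three splits):

    count = 1003
    lines = {'train': int(count * .8), 'val': int(count * .12), 'test': int(count * .08)}
    print(lines, sum(lines.values()))
    # {'train': 802, 'val': 120, 'test': 80} 1002  -> one line unassigned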
@@ -98,14 +179,14 @@ class Tracker:
                 # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
             )
         elif self.config.detector == DETECTOR_YOLOv8:
-            self.model = YOLO('EXPERIMENTS/yolov8x.pt')
+            self.model = YOLO('EXPERIMENTS/yolov8x.pt', classes=0)
         else:
             raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")


         # homography = list(source.glob('*img2world.txt'))[0]

-        self.H = np.loadtxt(self.config.homography, delimiter=',')
+        self.H = self.config.H

         if self.config.smooth_tracks:
             logger.info("Smoother enabled")
@@ -120,155 +201,117 @@ class Tracker:
     def track(self):
         prev_run_time = 0

-        training_fp = None
-        training_csv = None
-        training_frames = 0
+        # training_fp = None
+        # training_csv = None
+        # training_frames = 0

-        if self.config.save_for_training is not None:
-            if not isinstance(self.config.save_for_training, Path):
-                raise ValueError("save-for-training should be a path")
-            if not self.config.save_for_training.exists():
-                logger.info(f"Making path for training data: {self.config.save_for_training}")
-                self.config.save_for_training.mkdir(parents=True, exist_ok=False)
-            else:
-                logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
-            training_fp = open(self.config.save_for_training / 'all.txt', 'w')
-            # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
-            training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)
+        # if self.config.save_for_training is not None:
+        #     if not isinstance(self.config.save_for_training, Path):
+        #         raise ValueError("save-for-training should be a path")
+        #     if not self.config.save_for_training.exists():
+        #         logger.info(f"Making path for training data: {self.config.save_for_training}")
+        #         self.config.save_for_training.mkdir(parents=True, exist_ok=False)
+        #     else:
+        #         logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
+        #     training_fp = open(self.config.save_for_training / 'all.txt', 'w')
+        #     # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
+        #     training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)

         prev_frame_i = -1

-        while self.is_running.is_set():
-            # this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
-            # skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
-            # so for now, timing should move to emitter
-            # this_run_time = time.time()
-            # # logger.debug(f'test {prev_run_time - this_run_time}')
-            # time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
-            # prev_run_time = time.time()
-
-            zmq_ev = self.frame_sock.poll(timeout=2000)
-            if not zmq_ev:
-                logger.warn('skip poll after 2000ms')
-                # when there's no data after timeout, loop so that is_running is checked
-                continue
-
-            start_time = time.time()
-            frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
-
-            if frame.index > (prev_frame_i+1):
-                logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
-
-            prev_frame_i = frame.index
-            # load homography into frame (TODO: should this be done in emitter?)
-            if frame.H is None:
-                # logger.warning('Falling back to default H')
-                # fallback: load configured H
-                frame.H = self.H
-
-            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
-
-            if self.config.detector == DETECTOR_YOLOv8:
-                detections: [Detection] = self._yolov8_track(frame)
-            else:
-                detections: [Detection] = self._resnet_track(frame.img, scale = 1)
-
-            # Store detections into tracklets
-            projected_coordinates = []
-            for detection in detections:
-                track = self.tracks[detection.track_id]
-                track.track_id = detection.track_id # for new tracks
-
-                track.history.append(detection) # add to history
-                # projected_coordinates.append(track.get_projected_history(self.H)) # then get full history
-
-                # TODO: handle occlusions, and disappearance
-                # if len(track.history) > 30: # retain 90 tracks for 90 frames
-                #     track.history.pop(0)
-
-            # trajectories = {}
-            # for detection in detections:
-            #     tid = str(detection.track_id)
-            #     track = self.tracks[detection.track_id]
-            #     coords = track.get_projected_history(self.H) # get full history
-            #     trajectories[tid] = {
-            #         "id": tid,
-            #         "det_conf": detection.conf,
-            #         "bbox": detection.to_ltwh(),
-            #         "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubly nested, fine for test
-            #     }
-            active_track_ids = [d.track_id for d in detections]
-            active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
-            # logger.info(f"{trajectories}")
-            frame.tracks = active_tracks
-
-            # if self.config.bypass_prediction:
-            #     self.trajectory_socket.send_string(json.dumps(trajectories))
-            # else:
-            #     self.trajectory_socket.send(pickle.dumps(frame))
-            if self.config.smooth_tracks:
-                frame = self.smoother.smooth_frame_tracks(frame)
-
-            self.trajectory_socket.send_pyobj(frame)
-
-            current_time = time.time()
-            logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
-
-            # self.trajectory_socket.send_string(json.dumps(trajectories))
-            # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
-            # TODO: provide a track object that actually keeps history (unlike tracker)
-
-            #TODO calculate fps (also for other loops to see asynchronicity)
-            # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
-            if training_csv:
-                training_csv.writerows([{
-                    'frame_id': round(frame.index * 10., 1), # not really time
-                    'track_id': t.track_id,
-                    'l': t.history[-1].l,
-                    't': t.history[-1].t,
-                    'w': t.history[-1].w,
-                    'h': t.history[-1].h,
-                    'x': t.get_projected_history(frame.H)[-1][0],
-                    'y': t.get_projected_history(frame.H)[-1][1],
-                    'state': t.history[-1].state.value
-                    # only keep _actual_detections, no lost entries
-                } for t in active_tracks.values()
-                    # if t.history[-1].state != DetectionState.Lost
-                ])
-                training_frames += len(active_tracks)
-            # print(time.time() - start_time)
-
-        if training_fp:
-            training_fp.close()
-            lines = {
-                'train': int(training_frames * .8),
-                'val': int(training_frames * .12),
-                'test': int(training_frames * .08),
-            }
-            logger.info(f"Splitting gathered data from {training_fp.name}")
-            with open(training_fp.name, 'r') as source_fp:
-                for name, line_nrs in lines.items():
-                    dir_path = self.config.save_for_training / name
-                    dir_path.mkdir(exist_ok=True)
-                    file = dir_path / 'tracked.txt'
-                    logger.debug(f"- Write {line_nrs} lines to {file}")
-                    with file.open('w') as target_fp:
-                        for i in range(line_nrs):
-                            target_fp.write(source_fp.readline())
+        with TrainingDataWriter(self.config.save_for_training) as writer:
+            while self.is_running.is_set():
+                # this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
+                # skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
+                # so for now, timing should move to emitter
+                # this_run_time = time.time()
+                # # logger.debug(f'test {prev_run_time - this_run_time}')
+                # time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
+                # prev_run_time = time.time()
+
+                zmq_ev = self.frame_sock.poll(timeout=2000)
+                if not zmq_ev:
+                    logger.warn('skip poll after 2000ms')
+                    # when there's no data after timeout, loop so that is_running is checked
+                    continue
+
+                start_time = time.time()
+                frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
+
+                if frame.index > (prev_frame_i+1):
+                    logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
+
+                prev_frame_i = frame.index
+                # load homography into frame (TODO: should this be done in emitter?)
+                if frame.H is None:
+                    # logger.warning('Falling back to default H')
+                    # fallback: load configured H
+                    frame.H = self.H
+
+                # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
+
+                if self.config.detector == DETECTOR_YOLOv8:
+                    detections: [Detection] = _yolov8_track(frame, self.model)
+                else:
+                    detections: [Detection] = self._resnet_track(frame.img, scale = 1)
+
+                # Store detections into tracklets
+                projected_coordinates = []
+                for detection in detections:
+                    track = self.tracks[detection.track_id]
+                    track.track_id = detection.track_id # for new tracks
+
+                    track.history.append(detection) # add to history
+                    # projected_coordinates.append(track.get_projected_history(self.H)) # then get full history
+
+                    # TODO: handle occlusions, and disappearance
+                    # if len(track.history) > 30: # retain 90 tracks for 90 frames
+                    #     track.history.pop(0)
+
+                # trajectories = {}
+                # for detection in detections:
+                #     tid = str(detection.track_id)
+                #     track = self.tracks[detection.track_id]
+                #     coords = track.get_projected_history(self.H) # get full history
+                #     trajectories[tid] = {
+                #         "id": tid,
+                #         "det_conf": detection.conf,
+                #         "bbox": detection.to_ltwh(),
+                #         "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubly nested, fine for test
+                #     }
+                active_track_ids = [d.track_id for d in detections]
+                active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
+                # logger.info(f"{trajectories}")
+                frame.tracks = active_tracks
+
+                # if self.config.bypass_prediction:
+                #     self.trajectory_socket.send_string(json.dumps(trajectories))
+                # else:
+                #     self.trajectory_socket.send(pickle.dumps(frame))
+                if self.config.smooth_tracks:
+                    frame = self.smoother.smooth_frame_tracks(frame)
+
+                self.trajectory_socket.send_pyobj(frame)
+
+                current_time = time.time()
+                logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
+
+                # self.trajectory_socket.send_string(json.dumps(trajectories))
+                # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
+                # TODO: provide a track object that actually keeps history (unlike tracker)
+
+                #TODO calculate fps (also for other loops to see asynchronicity)
+                # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
+                writer.add(frame, active_tracks.values())

         logger.info('Stopping')

-    def _yolov8_track(self, frame: Frame,) -> [Detection]:
-        results: [YOLOResult] = self.model.track(frame.img, persist=True, tracker="bytetrack.yaml", verbose=False)
-        if results[0].boxes is None or results[0].boxes.id is None:
-            # work around https://github.com/ultralytics/ultralytics/issues/5968
-            return []
-        return [Detection(track_id, bbox[0]-.5*bbox[2], bbox[1]-.5*bbox[3], bbox[2], bbox[3], 1, DetectionState.Confirmed, frame.index) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
-
     def _resnet_track(self, img, scale: float = 1) -> [Detection]:
         if scale != 1: