# trap/trap/laser_renderer.py

# used for "Forward Referencing of type annotations"
from __future__ import annotations
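
# Receives tracker and prediction frames over ZMQ and renders the resulting
# trajectories as HeliosPoint lists on a Helios laser DAC (trap.helios).
# Intended to run in its own process via run_laser_renderer() at the bottom
# of this file.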
import time
import ffmpeg
from argparse import Namespace
import datetime
import logging
from multiprocessing import Event
from multiprocessing.synchronize import Event as BaseEvent
import cv2
import numpy as np
import json
import pyglet
import pyglet.event
import zmq
import tempfile
from pathlib import Path
import shutil
import math
from typing import Dict, Iterable, Optional
from pyglet import shapes
from PIL import Image
from trap.frame_emitter import DetectionState, Frame, Track, Camera
from trap.helios import HeliosDAC, HeliosPoint
from trap.preview_renderer import FrameWriter
from trap.tools import draw_track, draw_track_predictions, draw_track_projected, draw_trackjectron_history, to_point, track_predictions_to_lines
from trap.utils import convert_world_points_to_img_points, convert_world_space_to_img_space
logger = logging.getLogger("trap.laser_renderer")


class LaserRenderer:
    def __init__(self, config: Namespace, is_running: BaseEvent):
        self.config = config
        self.is_running = is_running

        context = zmq.Context()
        self.prediction_sock = context.socket(zmq.SUB)
        self.prediction_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame. NB: make sure this comes BEFORE connect, otherwise it is ignored!
        self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
        # self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
        self.prediction_sock.connect(config.zmq_prediction_addr)

        self.tracker_sock = context.socket(zmq.SUB)
        self.tracker_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame. NB: make sure this comes BEFORE connect, otherwise it is ignored!
        self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.tracker_sock.connect(config.zmq_trajectory_addr)

        self.H = self.config.H
        self.inv_H = np.linalg.pinv(self.H)

        # TODO: get FPS from frame_emitter
        # self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
        self.fps = 60
        self.frame_size = (self.config.camera.w, self.config.camera.h)

        self.first_time: float | None = None
        self.frame: Frame | None = None
        self.tracker_frame: Frame | None = None
        self.prediction_frame: Frame | None = None

        self.tracks: Dict[str, Track] = {}
        self.predictions: Dict[str, Track] = {}

        self.dac = HeliosDAC(debug=False)
        logger.info(f"{self.dac.dev}")
        logger.info(f"{self.dac.GetName()}")
        logger.info(f"Helios version: {self.dac.getHWVersion()}")

        # self.init_shapes()
        # self.init_labels()
    # NOTE: this method appears to be leftover from the pyglet-based preview
    # renderer: it reads from self.frame_sock and draws into self.batch_bg,
    # neither of which is initialised in this class, and it is never called
    # from run().
    def check_frames(self, dt):
        new_tracks = False
        try:
            self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
            if not self.first_time:
                self.first_time = self.frame.time
            img = cv2.GaussianBlur(self.frame.img, (15, 15), 0)
            img = cv2.flip(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), 0)
            img = pyglet.image.ImageData(self.frame_size[0], self.frame_size[1], 'RGB', img.tobytes())
            # don't draw in batch, so that it is the background
            self.video_sprite = pyglet.sprite.Sprite(img=img, batch=self.batch_bg)
            self.video_sprite.opacity = 100
        except zmq.ZMQError as e:
            # idx = frame.index if frame else "NONE"
            # logger.debug(f"reuse video frame {idx}")
            pass
        try:
            self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
            new_tracks = True
        except zmq.ZMQError as e:
            pass
        try:
            self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
            new_tracks = True
        except zmq.ZMQError as e:
            pass
    def run(self, timer_counter):
        prediction_frame = None
        tracker_frame = None

        i = 0
        first_time = None

        while self.is_running.is_set():
            i += 1
            with timer_counter.get_lock():
                timer_counter.value += 1

            try:
                prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
                for track_id, track in prediction_frame.tracks.items():
                    prediction_id = f"{track_id}-{track.history[-1].frame_nr}"
                    self.predictions[prediction_id] = track
            except zmq.ZMQError as e:
                logger.debug('reuse prediction')

            try:
                tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
                for track_id, track in tracker_frame.tracks.items():
                    self.tracks[track_id] = track
            except zmq.ZMQError as e:
                logger.debug('reuse tracks')

            if tracker_frame is None:
                # might need to wait a few iterations before the first frame comes available
                time.sleep(.1)
                continue

            if first_time is None:
                first_time = tracker_frame.time

            pointlist = render_frame_to_dac(self.dac, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.config.render_clusters)
            self.dac.newFrame(50000, pointlist)  # 50000 is presumably the scan rate in points per second

            # clear out old tracks & predictions:
            for track_id, track in list(self.tracks.items()):
                # TODO)) Migrate to using time() instead of framenr, to detach the two
                if get_animation_position(track, tracker_frame) == 1:
                    self.tracks.pop(track_id)
            for prediction_id, track in list(self.predictions.items()):
                if get_animation_position(track, tracker_frame) == 1:
                    self.predictions.pop(prediction_id)

        logger.info('Stopping')

# colorset = itertools.product([0,255], repeat=3)  # but remove white
# colorset = [
#     (0, 0, 0),
#     (0, 0, 255),
#     (0, 255, 0),
#     (0, 255, 255),
#     (255, 0, 0),
#     (255, 0, 255),
#     (255, 255, 0),
# ]
# NOTE: colorset is not referenced anywhere in this file.
colorset = [
    (255, 255, 100),
    (255, 100, 255),
    (100, 255, 255),
]
# colorset = [
#     (0, 0, 0),
# ]


def get_animation_position(track: Track, current_frame: Frame) -> float:
    # Fade position in [0, 1]: 0 while the track is fresh, 1 once it was last
    # seen fps * 3 frames (three seconds) ago. run() drops tracks that reach 1.
    fade_duration = current_frame.camera.fps * 3
    diff = current_frame.index - track.history[-1].frame_nr
    return max(0, min(1, diff / fade_duration))
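
# A quick worked example, assuming a hypothetical 12 fps camera: a track last
# seen 18 frames ago gives fade_duration = 36 and diff = 18, so the animation
# position is 18 / 36 = 0.5 (half faded); from 36 frames onward it clamps to 1.
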
# Deprecated
def render_frame_to_dac(dac: HeliosDAC, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], as_clusters=True) -> list[HeliosPoint]:
    # TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
    # or https://github.com/pygobject/pycairo?tab=readme-ov-file
    # or https://pyglet.readthedocs.io/en/latest/programming_guide/shapes.html
    # and use http://code.astraw.com/projects/motmot/pygarrayimage.html or https://gist.github.com/nkymut/1cb40ea6ae4de0cf9ded7332f1ca0d55
    # or https://api.arcade.academy/en/stable/index.html (supports gradient color in line -- "Arcade is built on top of Pyglet and OpenGL.")
    pointlist = []
    # pointlist.append(HeliosPoint(x, y, dac.palette[cindex], blank=blank))

    # all not working:
    # if i == 1:
    #     # thanks to GpG for fixing scaling issue: https://stackoverflow.com/a/39668864
    #     scale_factor = 1./20  # from 10m to 1000px
    #     S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
    #     new_H = S * self.H * np.linalg.inv(S)
    #     warpedFrame = cv2.warpPerspective(img, new_H, (1000, 1000))
    #     cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)
    # cv2.rectangle(img, (0,0), (img.shape[1],25), (0,0,0), -1)

    c = dac.palette[4]  # green
    pointlist.append(HeliosPoint(10, 10, c, blank=False))
    pointlist.append(HeliosPoint(10, 100, c, blank=False))
    pointlist.append(HeliosPoint(100, 100, c, blank=False))
    pointlist.append(HeliosPoint(100, 10, c, blank=False))
    pointlist.append(HeliosPoint(10, 10, c, blank=True))
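
    # The five points above trace a green calibration square: four lit corners
    # plus a final blanked point, presumably so the jump back to the start is
    # made with the laser off. The same pattern repeats below as "waiting"
    # markers; a hypothetical helper could fold the repetition:
    #
    # def square(x0, y0, x1, y1, c) -> list[HeliosPoint]:
    #     return [
    #         HeliosPoint(x0, y0, c, blank=False),
    #         HeliosPoint(x0, y1, c, blank=False),
    #         HeliosPoint(x1, y1, c, blank=False),
    #         HeliosPoint(x1, y0, c, blank=False),
    #         HeliosPoint(x0, y0, c, blank=True),
    #     ]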
    if not tracker_frame:
        c = dac.palette[3]  # yellow
        pointlist.append(HeliosPoint(110, 10, c, blank=False))
        pointlist.append(HeliosPoint(110, 100, c, blank=False))
        pointlist.append(HeliosPoint(200, 100, c, blank=False))
        pointlist.append(HeliosPoint(200, 10, c, blank=False))
        pointlist.append(HeliosPoint(110, 10, c, blank=True))
    else:
        for track_id, track in tracks.items():
            history = track.get_projected_history(camera=config.camera)
            history = convert_world_points_to_img_points(history)
            # point_color = bgr_colors[color_index % len(bgr_colors)]
            points = np.rint(history.reshape((-1, 1, 2))).astype(np.int32)
            for i, point in enumerate(points):
                blank = i + 1 == len(points)  # last point blank
                pointlist.append(HeliosPoint(point[0][0], point[0][1], dac.palette[2], blank=blank))
            # draw_track_projected(img, track, int(track_id), config.camera, convert_world_points_to_img_points)

    if not prediction_frame:
        c = dac.palette[7]  # magenta
        pointlist.append(HeliosPoint(210, 10, c, blank=False))
        pointlist.append(HeliosPoint(210, 100, c, blank=False))
        pointlist.append(HeliosPoint(300, 100, c, blank=False))
        pointlist.append(HeliosPoint(300, 10, c, blank=False))
        pointlist.append(HeliosPoint(210, 10, c, blank=True))
        # cv2.putText(img, f"Waiting for prediction...", (500,17), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
    else:
        for track_id, track in predictions.items():
            # For debugging:
            # draw_trackjectron_history(img, track, int(track.track_id), convert_world_points_to_img_points)
            anim_position = get_animation_position(track, tracker_frame)
            lines = track_predictions_to_lines(track, config.camera, anim_position)
            if not lines:
                continue

            lines = [convert_world_points_to_img_points(points) for points in lines]
            # cv2 only draws to integer coordinates
            lines = [np.rint(points).astype(int) for points in lines]
            for line in lines:
                for i, point in enumerate(line):
                    blank = i + 1 == len(line)  # last point of each line blank
                    pointlist.append(HeliosPoint(point[0], point[1], dac.palette[4], blank=blank))
            # draw_track_predictions(img, track, int(track.track_id)+1, config.camera, convert_world_points_to_img_points, anim_position=anim_position, as_clusters=as_clusters)
            # cv2.putText(img, f"{len(track.predictor_history) if track.predictor_history else 'none'}", to_point(track.history[0].get_foot_coords()), cv2.FONT_HERSHEY_COMPLEX, 1, (255,255,255), 1)
    return pointlist


def run_laser_renderer(config: Namespace, is_running: BaseEvent, timer_counter):
    renderer = LaserRenderer(config, is_running)
    renderer.run(timer_counter)
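

# A minimal usage sketch (not from the original file). The Namespace fields
# (H, camera, zmq_prediction_addr, zmq_trajectory_addr, render_clusters) are
# normally supplied by the project's config parser; the code below only shows
# how the pieces fit together and will not run with placeholder values.
#
# if __name__ == "__main__":
#     from multiprocessing import Process, Value
#
#     config = Namespace()  # fill in H, camera, ZMQ addresses, render_clusters
#     is_running = Event()
#     is_running.set()
#     timer_counter = Value('i', 0)  # run() increments this once per loop
#
#     proc = Process(target=run_laser_renderer, args=(config, is_running, timer_counter))
#     proc.start()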