Compare commits

..

No commits in common. "7c05c060c345068b639ac9378db346832373cd95" and "cd8a06a53bc0da129a2fba26a1b56b93d7fa85fb" have entirely different histories.

11 changed files with 363 additions and 1550 deletions

410
.gitignore vendored
View file

@ -1,410 +0,0 @@
.idea/
OUT/
EXPERIMENTS/
## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2
.*.lb
## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf
## Generated if empty string is given at "Please type another file name for output:"
.pdf
## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml
## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync
## Build tool directories for auxiliary files
# latexrun
latex.out/
## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa
# achemso
acs-*.bib
# amsthm
*.thm
# beamer
*.nav
*.pre
*.snm
*.vrb
# changes
*.soc
# comment
*.cut
# cprotect
*.cpt
# elsarticle (documentclass of Elsevier journals)
*.spl
# endnotes
*.ent
# fixme
*.lox
# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm
#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R
# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs
# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist
# gnuplottex
*-gnuplottex-*
# gregoriotex
*.gaux
*.gtex
# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref
# hyperref
*.brf
# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary
# listings
*.lol
# luatexja-ruby
*.ltjruby
# makeidx
*.idx
*.ilg
*.ind
# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*
# minted
_minted*
*.pyg
# morewrites
*.mw
# nomencl
*.nlg
*.nlo
*.nls
# pax
*.pax
# pdfpcnotes
*.pdfpc
# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd
# scrwfile
*.wrt
# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/
# pdfcomment
*.upa
*.upb
# pythontex
*.pytxcode
pythontex-files-*/
# tcolorbox
*.listing
# thmtools
*.loe
# TikZ & PGF
*.dpth
*.md5
*.auxlock
# todonotes
*.tdo
# vhistory
*.hst
*.ver
# easy-todo
*.lod
# xcolor
*.xcp
# xmpincl
*.xmpi
# xindy
*.xdy
# xypic precompiled matrices and outlines
*.xyc
*.xyd
# endfloat
*.ttt
*.fff
# Latexian
TSWLatexianTemp*
## Editors:
# WinEdt
*.bak
*.sav
# Texpad
.texpadtmp
# LyX
*.lyx~
# Kile
*.backup
# gummi
.*.swp
# KBibTeX
*~[0-9]*
# TeXnicCenter
*.tps
# auto folder when using emacs and auctex
./auto/*
*.el
# expex forward references with \gathertags
*-tags.tex
# standalone packages
*.sta
# Makeindex log files
*.lpz
logs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

55
poetry.lock generated
View file

@ -1817,9 +1817,9 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
]
[[package]]
@ -1939,8 +1939,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -1970,21 +1970,6 @@ sql-other = ["SQLAlchemy (>=1.4.36)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.8.0)"]
[[package]]
name = "pandas_helper_calc"
version = "0.0.1"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.source]
type = "git"
url = "https://github.com/scls19fr/pandas-helper-calc"
reference = "HEAD"
resolved_reference = "22df480f09c0fa96548833f9dee8f9128512641b"
[[package]]
name = "pandocfilters"
version = "1.5.0"
@ -2976,24 +2961,6 @@ all = ["numpy", "pytest", "pytest-cov"]
test = ["pytest", "pytest-cov"]
vectorized = ["numpy"]
[[package]]
name = "simdkalman"
version = "1.0.4"
description = "Kalman filters vectorized as Single Instruction, Multiple Data"
optional = false
python-versions = "*"
files = [
{file = "simdkalman-1.0.4-py2.py3-none-any.whl", hash = "sha256:fc2c6b9e540e0a26b39d087e78623d3c1e8c6677abf5d91111f5d49e328e1668"},
]
[package.dependencies]
numpy = ">=1.9.0"
[package.extras]
dev = ["check-manifest"]
docs = ["sphinx"]
test = ["pylint"]
[[package]]
name = "six"
version = "1.16.0"
@ -3333,22 +3300,6 @@ tqdm = "^4.65.0"
type = "directory"
url = "../Trajectron-plus-plus"
[[package]]
name = "tsmoothie"
version = "1.0.5"
description = "A python library for timeseries smoothing and outlier detection in a vectorized way."
optional = false
python-versions = ">=3"
files = [
{file = "tsmoothie-1.0.5-py3-none-any.whl", hash = "sha256:dedf8d8e011562824abe41783bf33e1b9ee1424bc572853bb82408743316a90e"},
{file = "tsmoothie-1.0.5.tar.gz", hash = "sha256:d83fa0ccae32bde7b904d9581ebf137e8eb18629cc3563d7379ca5f92461f6f5"},
]
[package.dependencies]
numpy = "*"
scipy = "*"
simdkalman = "*"
[[package]]
name = "types-python-dateutil"
version = "2.8.19.14"
@ -3517,4 +3468,4 @@ watchdog = ["watchdog (>=2.3)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10,<3.12,"
content-hash = "66f062f9db921cfa83e576288d09fd9b959780eb189d95765934ae9a6769f200"
content-hash = "c9d4fe6a1d054a835a689cee011753b900b696aa8a06b81aa7a10afc24a8bc70"

View file

@ -29,8 +29,6 @@ ultralytics = "^8.0.200"
ffmpeg-python = "^0.2.0"
torchreid = "^0.2.5"
gdown = "^4.7.1"
pandas-helper-calc = {git = "https://github.com/scls19fr/pandas-helper-calc"}
tsmoothie = "^1.0.5"
[build-system]
requires = ["poetry-core"]

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,6 +1,5 @@
from argparse import Namespace
from dataclasses import dataclass, field
from enum import IntFlag
from itertools import cycle
import logging
from multiprocessing import Event
@ -13,25 +12,9 @@ import numpy as np
import cv2
import zmq
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
logger = logging.getLogger('trap.frame_emitter')
class DetectionState(IntFlag):
Tentative = 1 # state before n_init (see DeepsortTrack)
Confirmed = 2 # after tentative
Lost = 4 # lost when DeepsortTrack.time_since_update > 0 but not Deleted
@classmethod
def from_deepsort_track(cls, track: DeepsortTrack):
if track.state == DeepsortTrackState.Tentative:
return cls.Tentative
if track.state == DeepsortTrackState.Confirmed:
if track.time_since_update > 0:
return cls.Lost
return cls.Confirmed
raise RuntimeError("Should not run into Deleted entries here")
@dataclass
class Detection:
@ -41,27 +24,13 @@ class Detection:
w: int # width - image space
h: int # height - image space
conf: float # object detector probablity
state: DetectionState
def get_foot_coords(self):
return [self.l + 0.5 * self.w, self.t+self.h]
@classmethod
def from_deepsort(cls, dstrack: DeepsortTrack):
return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf, DetectionState.from_deepsort_track(dstrack))
def get_scaled(self, scale: float = 1):
if scale == 1:
return self
return Detection(
self.track_id,
self.l*scale,
self.t*scale,
self.w*scale,
self.h*scale,
self.conf,
self.state)
return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
def to_ltwh(self):
return (int(self.l), int(self.t), int(self.w), int(self.h))
@ -70,7 +39,6 @@ class Detection:
return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))
@dataclass
class Track:
"""A bit of an haphazardous wrapper around the 'real' tracker to provide
@ -95,7 +63,6 @@ class Track:
return [{"x":c[0], "y":c[1]} for c in coords]
@dataclass
class Frame:
index: int
@ -104,19 +71,6 @@ class Frame:
tracks: Optional[dict[str, Track]] = None
H: Optional[np.array] = None
def aslist(self) -> [dict]:
return { t.track_id:
{
'id': t.track_id,
'history': t.get_projected_history(self.H).tolist(),
'det_conf': t.history[-1].conf,
# 'det_conf': trajectory_data[node.id]['det_conf'],
# 'bbox': trajectory_data[node.id]['bbox'],
# 'history': history.tolist(),
'predictions': t.predictions
} for t in self.tracks.values()
}
class FrameEmitter:
'''
Emit frame in a separate threat so they can be throttled,
@ -141,29 +95,15 @@ class FrameEmitter:
def emit_video(self):
i = 0
for video_path in self.video_srcs:
logger.info(f"Play from '{str(video_path)}'")
video = cv2.VideoCapture(str(video_path))
fps = video.get(cv2.CAP_PROP_FPS)
target_frame_duration = 1./fps
frame_duration = 1./fps
logger.info(f"Emit frames at {fps} fps")
if '-' in video_path.stem:
path_stem = video_path.stem[:video_path.stem.rfind('-')]
else:
path_stem = video_path.stem
path_stem += "-homography"
homography_path = video_path.with_stem(path_stem).with_suffix('.txt')
logger.info(f'check homography file {homography_path}')
if homography_path.exists():
logger.info(f'Found custom homography file! Using {homography_path}')
video_H = np.loadtxt(homography_path, delimiter=',')
else:
video_H = None
prev_time = time.time()
i = 0
while self.is_running.is_set():
ret, img = video.read()
@ -180,19 +120,19 @@ class FrameEmitter:
# hack to mask out area
cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)
frame = Frame(index=i, img=img, H=video_H)
frame = Frame(index=i, img=img)
# TODO: this is very dirty, need to find another way.
# perhaps multiprocessing Array?
self.frame_sock.send(pickle.dumps(frame))
# defer next loop
now = time.time()
time_diff = (now - prev_time)
if time_diff < target_frame_duration:
time.sleep(target_frame_duration - time_diff)
now += target_frame_duration - time_diff
prev_time = now
new_frame_time = time.time()
time_diff = (new_frame_time - prev_time)
if time_diff < frame_duration:
time.sleep(frame_duration - time_diff)
new_frame_time += frame_duration - time_diff
else:
prev_time = new_frame_time
i += 1

View file

@ -243,7 +243,7 @@ class PredictionServer:
if self.config.predict_training_data:
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
else:
zmq_ev = self.trajectory_socket.poll(timeout=2000)
zmq_ev = self.trajectory_socket.poll(timeout=3)
if not zmq_ev:
# on no data loop so that is_running is checked
continue
@ -252,7 +252,7 @@ class PredictionServer:
frame: Frame = pickle.loads(data)
# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
# trajectory_data = json.loads(data)
# logger.debug(f"Receive {frame.index}")
logger.debug(f"Receive {frame.index}")
# class FakeNode:
# def __init__(self, node_type: NodeType):
@ -276,12 +276,12 @@ class PredictionServer:
ax = derivative_of(vx, 0.1)
ay = derivative_of(vy, 0.1)
data_dict = {('position', 'x'): x[:], # [-10:-1]
('position', 'y'): y[:], # [-10:-1]
('velocity', 'x'): vx[:], # [-10:-1]
('velocity', 'y'): vy[:], # [-10:-1]
('acceleration', 'x'): ax[:], # [-10:-1]
('acceleration', 'y'): ay[:]} # [-10:-1]
data_dict = {('position', 'x'): x[:],
('position', 'y'): y[:],
('velocity', 'x'): vx[:],
('velocity', 'y'): vy[:],
('acceleration', 'x'): ax[:],
('acceleration', 'y'): ay[:]}
data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
node_data = pd.DataFrame(data_dict, columns=data_columns)
@ -301,7 +301,7 @@ class PredictionServer:
# TODO: we want to send out empty result...
# And want to update the network
# data = json.dumps({})
data = json.dumps({})
self.prediction_socket.send_pyobj(frame)
continue
@ -325,7 +325,7 @@ class PredictionServer:
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
dists, preds = trajectron.incremental_forward(input_dict,
maps,
prediction_horizon=125, # TODO: make variable
prediction_horizon=25, # TODO: make variable
num_samples=5, # TODO: make variable
robot_present_and_future=robot_present_and_future,
full_dist=True)

View file

@ -1,4 +1,3 @@
import time
import ffmpeg
from argparse import Namespace
import datetime
@ -9,7 +8,7 @@ import numpy as np
import zmq
from trap.frame_emitter import DetectionState, Frame
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.renderer")
@ -85,7 +84,7 @@ class Renderer:
while self.is_running.is_set():
i+=1
zmq_ev = self.frame_sock.poll(timeout=2000)
zmq_ev = self.frame_sock.poll(timeout=3)
if not zmq_ev:
# when no data comes in, loop so that is_running is checked
continue
@ -96,32 +95,6 @@ class Renderer:
except zmq.ZMQError as e:
logger.debug(f'reuse prediction')
if first_time is None:
first_time = frame.time
decorate_frame(frame, prediction_frame, first_time)
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
# cv2.imwrite(str(img_path), img)
logger.debug(f"write frame {frame.time - first_time:.3f}s")
if self.out_writer:
self.out_writer.write(img)
if self.streaming_process:
self.streaming_process.stdin.write(img.tobytes())
logger.info('Stopping')
if i>2:
if self.streaming_process:
self.streaming_process.stdin.close()
if self.out_writer:
self.out_writer.release()
if self.streaming_process:
# oddly wrapped, because both close and release() take time.
self.streaming_process.wait()
def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.array:
img = frame.img
# all not working:
@ -135,31 +108,27 @@ def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.arra
if not prediction_frame:
cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
# continue
continue
else:
inv_H = np.linalg.pinv(prediction_frame.H)
for track_id, track in prediction_frame.tracks.items():
if not len(track.history):
continue
# coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
coords = [d.get_foot_coords() for d in track.history]
confirmations = [d.state == DetectionState.Confirmed for d in track.history]
# logger.warning(f"{coords=}")
for ci in range(1, len(coords)):
start = [int(p) for p in coords[ci-1]]
end = [int(p) for p in coords[ci]]
color = (255,255,255) if confirmations[ci] else (100,100,100)
cv2.line(img, start, end, color, 2, lineType=cv2.LINE_AA)
cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
if not track.predictions or not len(track.predictions):
continue
for pred_i, pred in enumerate(track.predictions):
pred_coords = cv2.perspectiveTransform(np.array([pred]), inv_H)[0]
color = (0,0,255) if pred_i else (100,100,100)
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
color = (0,0,255) if pred_i == 1 else (100,100,100)
for ci in range(1, len(pred_coords)):
start = [int(p) for p in pred_coords[ci-1]]
end = [int(p) for p in pred_coords[ci]]
@ -179,20 +148,36 @@ def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.arra
cv2.rectangle(img, p1, p2, (255,0,0), 1)
cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
if first_time is None:
first_time = frame.time
cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
if prediction_frame:
# render Δt and Δ frames
cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"{prediction_frame.time - time.time():.2f}s", (200,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"{len(prediction_frame.tracks)} tracks", (500,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()])}", (580, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()])}", (660, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()])}", (740, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
return img
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
# cv2.imwrite(str(img_path), img)
logger.info(f"write frame {frame.time - first_time:.3f}s")
if self.out_writer:
self.out_writer.write(img)
if self.streaming_process:
self.streaming_process.stdin.write(img.tobytes())
logger.info('Stopping')
if i>2:
if self.streaming_process:
self.streaming_process.stdin.close()
if self.out_writer:
self.out_writer.release()
if self.streaming_process:
# oddly wrapped, because both close and release() take time.
self.streaming_process.wait()
def run_renderer(config: Namespace, is_running: Event):

View file

@ -1,9 +1,7 @@
from argparse import Namespace
import asyncio
import dataclasses
import errno
import json
import logging
from multiprocessing import Event
import subprocess
@ -17,8 +15,6 @@ import tornado.websocket
import zmq
import zmq.asyncio
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.forwarder")
@ -28,7 +24,7 @@ class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
self.zmq_socket = zmq_socket
async def on_message(self, message):
logger.debug(f"receive msg")
logger.debug(f"recieve msg")
try:
await self.zmq_socket.send_string(message)
@ -116,13 +112,11 @@ class WsRouter:
context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB)
logger.info(f'Publish trajectories on {config.zmq_trajectory_addr}')
self.trajectory_socket.bind(config.zmq_trajectory_addr)
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
self.prediction_socket = context.socket(zmq.SUB)
self.prediction_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.prediction_socket.connect(config.zmq_prediction_addr)
self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_socket.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
self.application = tornado.web.Application(
[
@ -172,16 +166,11 @@ class WsRouter:
logger.info("Starting prediction forwarder")
while self.is_running.is_set():
# timeout so that if no events occur, loop can still stop on is_running
has_event = await self.prediction_socket.poll(timeout=1000)
has_event = await self.prediction_socket.poll(timeout=1)
if has_event:
try:
frame: Frame = await self.prediction_socket.recv_pyobj()
# tacks = [dataclasses.asdict(h) for t in frame.tracks.values() for t.history in t]
msg = json.dumps(frame.aslist())
msg = await self.prediction_socket.recv_string()
logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg)
except Exception as e:
logger.exception(e)
# die together:
self.evt_loop.stop()

View file

@ -22,7 +22,7 @@ from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult
from trap.frame_emitter import DetectionState, Frame, Detection, Track
from trap.frame_emitter import Frame, Detection, Track
# Detection = [int, int, int, int, float, int]
# Detections = [Detection]
@ -66,9 +66,6 @@ class Tracker:
# TODO: support removal
self.tracks = defaultdict(lambda: Track())
logger.debug(f"Load tracker: {self.config.detector}")
if self.config.detector == DETECTOR_RETINANET:
# weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
# self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
@ -79,7 +76,7 @@ class Tracker:
self.model.eval()
# Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device)
self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9,
self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
)
elif self.config.detector == DETECTOR_MASKRCNN:
@ -90,7 +87,7 @@ class Tracker:
self.model.eval()
# Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device)
self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9,
self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
)
elif self.config.detector == DETECTOR_YOLOv8:
@ -123,7 +120,7 @@ class Tracker:
logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
training_fp = open(self.config.save_for_training / 'all.txt', 'w')
# following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)
training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
prev_frame_i = -1
@ -136,12 +133,6 @@ class Tracker:
# time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
# prev_run_time = time.time()
zmq_ev = self.frame_sock.poll(timeout=2000)
if not zmq_ev:
logger.warn('skip poll after 2000ms')
# when there's no data after timeout, loop so that is_running is checked
continue
start_time = time.time()
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
@ -151,9 +142,6 @@ class Tracker:
prev_frame_i = frame.index
# load homography into frame (TODO: should this be done in emitter?)
if frame.H is None:
# logger.warning('Falling back to default H')
# fallback: load configured H
frame.H = self.H
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
@ -162,7 +150,7 @@ class Tracker:
if self.config.detector == DETECTOR_YOLOv8:
detections: [Detection] = self._yolov8_track(frame.img)
else :
detections: [Detection] = self._resnet_track(frame.img, scale = 1)
detections: [Detection] = self._resnet_track(frame.img)
# Store detections into tracklets
@ -213,18 +201,10 @@ class Tracker:
if training_csv:
training_csv.writerows([{
'frame_id': round(frame.index * 10., 1), # not really time
'track_id': t.track_id,
'l': t.history[-1].l,
't': t.history[-1].t,
'w': t.history[-1].w,
'h': t.history[-1].h,
'x': t.get_projected_history(frame.H)[-1][0],
'y': t.get_projected_history(frame.H)[-1][1],
'state': t.history[-1].state.value
# only keep _actual_detections, no lost entries
} for t in active_tracks.values()
# if t.history[-1].state != DetectionState.Lost
])
'track_id': t['id'],
'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
} for t in active_tracks.values()])
training_frames += len(active_tracks)
# print(time.time() - start_time)
@ -256,13 +236,10 @@ class Tracker:
return []
return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
def _resnet_track(self, img, scale: float = 1) -> [Detection]:
if scale != 1:
dsize = (int(img.shape[1] * scale), int(img.shape[0] * scale))
img = cv2.resize(img, dsize)
def _resnet_track(self, img) -> [Detection]:
detections = self._resnet_detect_persons(img)
tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
return [Detection.from_deepsort(t).get_scaled(1/scale) for t in tracks]
return [Detection.from_deepsort(t) for t in tracks]
def _resnet_detect_persons(self, frame) -> [Detection]:
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

View file

@ -25,13 +25,12 @@
<script>
// minified https://github.com/joewalnes/reconnecting-websocket
!function (a, b) { "function" == typeof define && define.amd ? define([], b) : "undefined" != typeof module && module.exports ? module.exports = b() : a.ReconnectingWebSocket = b() }(this, function () { function a(b, c, d) { function l(a, b) { var c = document.createEvent("CustomEvent"); return c.initCustomEvent(a, !1, !1, b), c } var e = { debug: !1, automaticOpen: !0, reconnectInterval: 1e3, maxReconnectInterval: 3e4, reconnectDecay: 1.5, timeoutInterval: 2e3 }; d || (d = {}); for (var f in e) this[f] = "undefined" != typeof d[f] ? d[f] : e[f]; this.url = b, this.reconnectAttempts = 0, this.readyState = WebSocket.CONNECTING, this.protocol = null; var h, g = this, i = !1, j = !1, k = document.createElement("div"); k.addEventListener("open", function (a) { g.onopen(a) }), k.addEventListener("close", function (a) { g.onclose(a) }), k.addEventListener("connecting", function (a) { g.onconnecting(a) }), k.addEventListener("message", function (a) { g.onmessage(a) }), k.addEventListener("error", function (a) { g.onerror(a) }), this.addEventListener = k.addEventListener.bind(k), this.removeEventListener = k.removeEventListener.bind(k), this.dispatchEvent = k.dispatchEvent.bind(k), this.open = function (b) { h = new WebSocket(g.url, c || []), b || k.dispatchEvent(l("connecting")), (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "attempt-connect", g.url); var d = h, e = setTimeout(function () { (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "connection-timeout", g.url), j = !0, d.close(), j = !1 }, g.timeoutInterval); h.onopen = function () { clearTimeout(e), (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onopen", g.url), g.protocol = h.protocol, g.readyState = WebSocket.OPEN, g.reconnectAttempts = 0; var d = l("open"); d.isReconnect = b, b = !1, k.dispatchEvent(d) }, h.onclose = function (c) { if (clearTimeout(e), h = null, i) g.readyState = WebSocket.CLOSED, k.dispatchEvent(l("close")); else { g.readyState = WebSocket.CONNECTING; var d = l("connecting"); d.code = c.code, d.reason = c.reason, d.wasClean = c.wasClean, k.dispatchEvent(d), b || j || ((g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onclose", g.url), k.dispatchEvent(l("close"))); var e = g.reconnectInterval * Math.pow(g.reconnectDecay, g.reconnectAttempts); setTimeout(function () { g.reconnectAttempts++, g.open(!0) }, e > g.maxReconnectInterval ? g.maxReconnectInterval : e) } }, h.onmessage = function (b) { (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onmessage", g.url, b.data); var c = l("message"); c.data = b.data, k.dispatchEvent(c) }, h.onerror = function (b) { (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onerror", g.url, b), k.dispatchEvent(l("error")) } }, 1 == this.automaticOpen && this.open(!1), this.send = function (b) { if (h) return (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "send", g.url, b), h.send(b); throw "INVALID_STATE_ERR : Pausing to reconnect websocket" }, this.close = function (a, b) { "undefined" == typeof a && (a = 1e3), i = !0, h && h.close(a, b) }, this.refresh = function () { h && h.close() } } return a.prototype.onopen = function () { }, a.prototype.onclose = function () { }, a.prototype.onconnecting = function () { }, a.prototype.onmessage = function () { }, a.prototype.onerror = function () { }, a.debugAll = !1, a.CONNECTING = WebSocket.CONNECTING, a.OPEN = WebSocket.OPEN, a.CLOSING = WebSocket.CLOSING, a.CLOSED = WebSocket.CLOSED, a });
!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
</script>
<script>
// map the field to coordinates of our dummy tracker
// see test_homography.ipynb for the logic behind these values
const field_range = { x: [-13.092, 15.37], y: [-4.66, 10.624] }
const field_range = { x: [-30, 10], y: [-10, 10] }
// Create WebSocket connection.
const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
@ -97,17 +96,17 @@
let current_pos = null;
function appendAndSendPositions() {
if (is_moving && current_pos !== null) {
function appendAndSendPositions(){
if(is_moving && current_pos!==null){
// throttled update of tracker on movement
tracker[person_counter].addToHistory(current_pos);
}
for (const person_id in tracker) {
if (person_id != person_counter) { // compare int/str
for(const person_id in tracker){
if(person_id != person_counter){ // compare int/str
// fade out old tracks
tracker[person_id].history.shift()
if (!tracker[person_id].history.length) {
if(!tracker[person_id].history.length){
delete tracker[person_id]
}
}
@ -126,7 +125,7 @@
const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos)
current_pos = position;
tracker[person_counter].addToHistory(current_pos);
// tracker[person_counter].addToHistory(current_pos);
// trajectory_socket.send(JSON.stringify(tracker))
});
@ -135,8 +134,8 @@
const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos)
current_pos = position;
tracker[person_counter].addToHistory(current_pos);
trajectory_socket.send(JSON.stringify(tracker))
// tracker[person_counter].addToHistory(current_pos);
// trajectory_socket.send(JSON.stringify(tracker))
});
document.addEventListener('mouseup', (e) => {
person_counter++;
@ -171,13 +170,12 @@
ctx.stroke();
}
if (person.hasOwnProperty('predictions') && person.predictions.length > 0) {
if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
// multiple predictions can be sampled
person.predictions.forEach((prediction, i) => {
ctx.beginPath()
ctx.lineWidth = 0.2;
ctx.lineWidth = i === 1 ? 3 : 0.2;
ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
ctx.strokeStyle = "#ccaaaa";
// start from current position:
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
@ -186,33 +184,6 @@
}
ctx.stroke();
});
// average stroke:
ctx.beginPath()
ctx.lineWidth = 3;
ctx.strokeStyle = "#ff0000";
// start from current position:
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
for (let index = 0; index < person.predictions[0].length; index++) {
sum = person.predictions.reduce(
(accumulator, prediction) => ({
"x": accumulator.x + prediction[index][0],
"y": accumulator.y + prediction[index][1],
}),
{ x: 0, y: 0 },
);
avg = { x: sum.x / person.predictions.length, y: sum.y / person.predictions.length }
// console.log(sum, avg)
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(avg)))
}
// for (const position of ) {
// }
ctx.stroke();
}
}
ctx.restore();