Compare commits

..

No commits in common. "7c05c060c345068b639ac9378db346832373cd95" and "cd8a06a53bc0da129a2fba26a1b56b93d7fa85fb" have entirely different histories.

11 changed files with 363 additions and 1550 deletions

410
.gitignore vendored
View file

@ -1,410 +0,0 @@
.idea/
OUT/
EXPERIMENTS/
## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2
.*.lb
## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf
## Generated if empty string is given at "Please type another file name for output:"
.pdf
## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml
## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync
## Build tool directories for auxiliary files
# latexrun
latex.out/
## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa
# achemso
acs-*.bib
# amsthm
*.thm
# beamer
*.nav
*.pre
*.snm
*.vrb
# changes
*.soc
# comment
*.cut
# cprotect
*.cpt
# elsarticle (documentclass of Elsevier journals)
*.spl
# endnotes
*.ent
# fixme
*.lox
# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm
#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R
# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs
*.lzo
*.lzs
# uncomment this for glossaries-extra (will ignore makeindex's style files!)
# *.ist
# gnuplottex
*-gnuplottex-*
# gregoriotex
*.gaux
*.gtex
# htlatex
*.4ct
*.4tc
*.idv
*.lg
*.trc
*.xref
# hyperref
*.brf
# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary
# listings
*.lol
# luatexja-ruby
*.ltjruby
# makeidx
*.idx
*.ilg
*.ind
# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*
# minted
_minted*
*.pyg
# morewrites
*.mw
# nomencl
*.nlg
*.nlo
*.nls
# pax
*.pax
# pdfpcnotes
*.pdfpc
# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd
# scrwfile
*.wrt
# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/
# pdfcomment
*.upa
*.upb
# pythontex
*.pytxcode
pythontex-files-*/
# tcolorbox
*.listing
# thmtools
*.loe
# TikZ & PGF
*.dpth
*.md5
*.auxlock
# todonotes
*.tdo
# vhistory
*.hst
*.ver
# easy-todo
*.lod
# xcolor
*.xcp
# xmpincl
*.xmpi
# xindy
*.xdy
# xypic precompiled matrices and outlines
*.xyc
*.xyd
# endfloat
*.ttt
*.fff
# Latexian
TSWLatexianTemp*
## Editors:
# WinEdt
*.bak
*.sav
# Texpad
.texpadtmp
# LyX
*.lyx~
# Kile
*.backup
# gummi
.*.swp
# KBibTeX
*~[0-9]*
# TeXnicCenter
*.tps
# auto folder when using emacs and auctex
./auto/*
*.el
# expex forward references with \gathertags
*-tags.tex
# standalone packages
*.sta
# Makeindex log files
*.lpz
logs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

55
poetry.lock generated
View file

@ -1817,9 +1817,9 @@ files = [
[package.dependencies] [package.dependencies]
numpy = [ numpy = [
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
] ]
[[package]] [[package]]
@ -1939,8 +1939,8 @@ files = [
[package.dependencies] [package.dependencies]
numpy = [ numpy = [
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
] ]
python-dateutil = ">=2.8.2" python-dateutil = ">=2.8.2"
pytz = ">=2020.1" pytz = ">=2020.1"
@ -1970,21 +1970,6 @@ sql-other = ["SQLAlchemy (>=1.4.36)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.8.0)"] xml = ["lxml (>=4.8.0)"]
[[package]]
name = "pandas_helper_calc"
version = "0.0.1"
description = ""
optional = false
python-versions = "*"
files = []
develop = false
[package.source]
type = "git"
url = "https://github.com/scls19fr/pandas-helper-calc"
reference = "HEAD"
resolved_reference = "22df480f09c0fa96548833f9dee8f9128512641b"
[[package]] [[package]]
name = "pandocfilters" name = "pandocfilters"
version = "1.5.0" version = "1.5.0"
@ -2976,24 +2961,6 @@ all = ["numpy", "pytest", "pytest-cov"]
test = ["pytest", "pytest-cov"] test = ["pytest", "pytest-cov"]
vectorized = ["numpy"] vectorized = ["numpy"]
[[package]]
name = "simdkalman"
version = "1.0.4"
description = "Kalman filters vectorized as Single Instruction, Multiple Data"
optional = false
python-versions = "*"
files = [
{file = "simdkalman-1.0.4-py2.py3-none-any.whl", hash = "sha256:fc2c6b9e540e0a26b39d087e78623d3c1e8c6677abf5d91111f5d49e328e1668"},
]
[package.dependencies]
numpy = ">=1.9.0"
[package.extras]
dev = ["check-manifest"]
docs = ["sphinx"]
test = ["pylint"]
[[package]] [[package]]
name = "six" name = "six"
version = "1.16.0" version = "1.16.0"
@ -3333,22 +3300,6 @@ tqdm = "^4.65.0"
type = "directory" type = "directory"
url = "../Trajectron-plus-plus" url = "../Trajectron-plus-plus"
[[package]]
name = "tsmoothie"
version = "1.0.5"
description = "A python library for timeseries smoothing and outlier detection in a vectorized way."
optional = false
python-versions = ">=3"
files = [
{file = "tsmoothie-1.0.5-py3-none-any.whl", hash = "sha256:dedf8d8e011562824abe41783bf33e1b9ee1424bc572853bb82408743316a90e"},
{file = "tsmoothie-1.0.5.tar.gz", hash = "sha256:d83fa0ccae32bde7b904d9581ebf137e8eb18629cc3563d7379ca5f92461f6f5"},
]
[package.dependencies]
numpy = "*"
scipy = "*"
simdkalman = "*"
[[package]] [[package]]
name = "types-python-dateutil" name = "types-python-dateutil"
version = "2.8.19.14" version = "2.8.19.14"
@ -3517,4 +3468,4 @@ watchdog = ["watchdog (>=2.3)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10,<3.12," python-versions = "^3.10,<3.12,"
content-hash = "66f062f9db921cfa83e576288d09fd9b959780eb189d95765934ae9a6769f200" content-hash = "c9d4fe6a1d054a835a689cee011753b900b696aa8a06b81aa7a10afc24a8bc70"

View file

@ -29,8 +29,6 @@ ultralytics = "^8.0.200"
ffmpeg-python = "^0.2.0" ffmpeg-python = "^0.2.0"
torchreid = "^0.2.5" torchreid = "^0.2.5"
gdown = "^4.7.1" gdown = "^4.7.1"
pandas-helper-calc = {git = "https://github.com/scls19fr/pandas-helper-calc"}
tsmoothie = "^1.0.5"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,6 +1,5 @@
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import IntFlag
from itertools import cycle from itertools import cycle
import logging import logging
from multiprocessing import Event from multiprocessing import Event
@ -13,25 +12,9 @@ import numpy as np
import cv2 import cv2
import zmq import zmq
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState
logger = logging.getLogger('trap.frame_emitter') logger = logging.getLogger('trap.frame_emitter')
class DetectionState(IntFlag):
Tentative = 1 # state before n_init (see DeepsortTrack)
Confirmed = 2 # after tentative
Lost = 4 # lost when DeepsortTrack.time_since_update > 0 but not Deleted
@classmethod
def from_deepsort_track(cls, track: DeepsortTrack):
if track.state == DeepsortTrackState.Tentative:
return cls.Tentative
if track.state == DeepsortTrackState.Confirmed:
if track.time_since_update > 0:
return cls.Lost
return cls.Confirmed
raise RuntimeError("Should not run into Deleted entries here")
@dataclass @dataclass
class Detection: class Detection:
@ -41,27 +24,13 @@ class Detection:
w: int # width - image space w: int # width - image space
h: int # height - image space h: int # height - image space
conf: float # object detector probablity conf: float # object detector probablity
state: DetectionState
def get_foot_coords(self): def get_foot_coords(self):
return [self.l + 0.5 * self.w, self.t+self.h] return [self.l + 0.5 * self.w, self.t+self.h]
@classmethod @classmethod
def from_deepsort(cls, dstrack: DeepsortTrack): def from_deepsort(cls, dstrack: DeepsortTrack):
return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf, DetectionState.from_deepsort_track(dstrack)) return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
def get_scaled(self, scale: float = 1):
if scale == 1:
return self
return Detection(
self.track_id,
self.l*scale,
self.t*scale,
self.w*scale,
self.h*scale,
self.conf,
self.state)
def to_ltwh(self): def to_ltwh(self):
return (int(self.l), int(self.t), int(self.w), int(self.h)) return (int(self.l), int(self.t), int(self.w), int(self.h))
@ -70,7 +39,6 @@ class Detection:
return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h)) return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))
@dataclass @dataclass
class Track: class Track:
"""A bit of an haphazardous wrapper around the 'real' tracker to provide """A bit of an haphazardous wrapper around the 'real' tracker to provide
@ -95,7 +63,6 @@ class Track:
return [{"x":c[0], "y":c[1]} for c in coords] return [{"x":c[0], "y":c[1]} for c in coords]
@dataclass @dataclass
class Frame: class Frame:
index: int index: int
@ -104,19 +71,6 @@ class Frame:
tracks: Optional[dict[str, Track]] = None tracks: Optional[dict[str, Track]] = None
H: Optional[np.array] = None H: Optional[np.array] = None
def aslist(self) -> [dict]:
return { t.track_id:
{
'id': t.track_id,
'history': t.get_projected_history(self.H).tolist(),
'det_conf': t.history[-1].conf,
# 'det_conf': trajectory_data[node.id]['det_conf'],
# 'bbox': trajectory_data[node.id]['bbox'],
# 'history': history.tolist(),
'predictions': t.predictions
} for t in self.tracks.values()
}
class FrameEmitter: class FrameEmitter:
''' '''
Emit frame in a separate threat so they can be throttled, Emit frame in a separate threat so they can be throttled,
@ -141,29 +95,15 @@ class FrameEmitter:
def emit_video(self): def emit_video(self):
i = 0
for video_path in self.video_srcs: for video_path in self.video_srcs:
logger.info(f"Play from '{str(video_path)}'") logger.info(f"Play from '{str(video_path)}'")
video = cv2.VideoCapture(str(video_path)) video = cv2.VideoCapture(str(video_path))
fps = video.get(cv2.CAP_PROP_FPS) fps = video.get(cv2.CAP_PROP_FPS)
target_frame_duration = 1./fps frame_duration = 1./fps
logger.info(f"Emit frames at {fps} fps") logger.info(f"Emit frames at {fps} fps")
if '-' in video_path.stem:
path_stem = video_path.stem[:video_path.stem.rfind('-')]
else:
path_stem = video_path.stem
path_stem += "-homography"
homography_path = video_path.with_stem(path_stem).with_suffix('.txt')
logger.info(f'check homography file {homography_path}')
if homography_path.exists():
logger.info(f'Found custom homography file! Using {homography_path}')
video_H = np.loadtxt(homography_path, delimiter=',')
else:
video_H = None
prev_time = time.time() prev_time = time.time()
i = 0
while self.is_running.is_set(): while self.is_running.is_set():
ret, img = video.read() ret, img = video.read()
@ -180,19 +120,19 @@ class FrameEmitter:
# hack to mask out area # hack to mask out area
cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1) cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)
frame = Frame(index=i, img=img, H=video_H) frame = Frame(index=i, img=img)
# TODO: this is very dirty, need to find another way. # TODO: this is very dirty, need to find another way.
# perhaps multiprocessing Array? # perhaps multiprocessing Array?
self.frame_sock.send(pickle.dumps(frame)) self.frame_sock.send(pickle.dumps(frame))
# defer next loop # defer next loop
now = time.time() new_frame_time = time.time()
time_diff = (now - prev_time) time_diff = (new_frame_time - prev_time)
if time_diff < target_frame_duration: if time_diff < frame_duration:
time.sleep(target_frame_duration - time_diff) time.sleep(frame_duration - time_diff)
now += target_frame_duration - time_diff new_frame_time += frame_duration - time_diff
else:
prev_time = now prev_time = new_frame_time
i += 1 i += 1

View file

@ -243,7 +243,7 @@ class PredictionServer:
if self.config.predict_training_data: if self.config.predict_training_data:
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state']) input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
else: else:
zmq_ev = self.trajectory_socket.poll(timeout=2000) zmq_ev = self.trajectory_socket.poll(timeout=3)
if not zmq_ev: if not zmq_ev:
# on no data loop so that is_running is checked # on no data loop so that is_running is checked
continue continue
@ -252,7 +252,7 @@ class PredictionServer:
frame: Frame = pickle.loads(data) frame: Frame = pickle.loads(data)
# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()} # trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
# trajectory_data = json.loads(data) # trajectory_data = json.loads(data)
# logger.debug(f"Receive {frame.index}") logger.debug(f"Receive {frame.index}")
# class FakeNode: # class FakeNode:
# def __init__(self, node_type: NodeType): # def __init__(self, node_type: NodeType):
@ -276,12 +276,12 @@ class PredictionServer:
ax = derivative_of(vx, 0.1) ax = derivative_of(vx, 0.1)
ay = derivative_of(vy, 0.1) ay = derivative_of(vy, 0.1)
data_dict = {('position', 'x'): x[:], # [-10:-1] data_dict = {('position', 'x'): x[:],
('position', 'y'): y[:], # [-10:-1] ('position', 'y'): y[:],
('velocity', 'x'): vx[:], # [-10:-1] ('velocity', 'x'): vx[:],
('velocity', 'y'): vy[:], # [-10:-1] ('velocity', 'y'): vy[:],
('acceleration', 'x'): ax[:], # [-10:-1] ('acceleration', 'x'): ax[:],
('acceleration', 'y'): ay[:]} # [-10:-1] ('acceleration', 'y'): ay[:]}
data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']]) data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
node_data = pd.DataFrame(data_dict, columns=data_columns) node_data = pd.DataFrame(data_dict, columns=data_columns)
@ -301,7 +301,7 @@ class PredictionServer:
# TODO: we want to send out empty result... # TODO: we want to send out empty result...
# And want to update the network # And want to update the network
# data = json.dumps({}) data = json.dumps({})
self.prediction_socket.send_pyobj(frame) self.prediction_socket.send_pyobj(frame)
continue continue
@ -325,7 +325,7 @@ class PredictionServer:
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
dists, preds = trajectron.incremental_forward(input_dict, dists, preds = trajectron.incremental_forward(input_dict,
maps, maps,
prediction_horizon=125, # TODO: make variable prediction_horizon=25, # TODO: make variable
num_samples=5, # TODO: make variable num_samples=5, # TODO: make variable
robot_present_and_future=robot_present_and_future, robot_present_and_future=robot_present_and_future,
full_dist=True) full_dist=True)

View file

@ -1,4 +1,3 @@
import time
import ffmpeg import ffmpeg
from argparse import Namespace from argparse import Namespace
import datetime import datetime
@ -9,7 +8,7 @@ import numpy as np
import zmq import zmq
from trap.frame_emitter import DetectionState, Frame from trap.frame_emitter import Frame
logger = logging.getLogger("trap.renderer") logger = logging.getLogger("trap.renderer")
@ -85,7 +84,7 @@ class Renderer:
while self.is_running.is_set(): while self.is_running.is_set():
i+=1 i+=1
zmq_ev = self.frame_sock.poll(timeout=2000) zmq_ev = self.frame_sock.poll(timeout=3)
if not zmq_ev: if not zmq_ev:
# when no data comes in, loop so that is_running is checked # when no data comes in, loop so that is_running is checked
continue continue
@ -96,32 +95,6 @@ class Renderer:
except zmq.ZMQError as e: except zmq.ZMQError as e:
logger.debug(f'reuse prediction') logger.debug(f'reuse prediction')
if first_time is None:
first_time = frame.time
decorate_frame(frame, prediction_frame, first_time)
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
# cv2.imwrite(str(img_path), img)
logger.debug(f"write frame {frame.time - first_time:.3f}s")
if self.out_writer:
self.out_writer.write(img)
if self.streaming_process:
self.streaming_process.stdin.write(img.tobytes())
logger.info('Stopping')
if i>2:
if self.streaming_process:
self.streaming_process.stdin.close()
if self.out_writer:
self.out_writer.release()
if self.streaming_process:
# oddly wrapped, because both close and release() take time.
self.streaming_process.wait()
def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.array:
img = frame.img img = frame.img
# all not working: # all not working:
@ -135,31 +108,27 @@ def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.arra
if not prediction_frame: if not prediction_frame:
cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1) cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
# continue continue
else: else:
inv_H = np.linalg.pinv(prediction_frame.H)
for track_id, track in prediction_frame.tracks.items(): for track_id, track in prediction_frame.tracks.items():
if not len(track.history): if not len(track.history):
continue continue
# coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0] # coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
coords = [d.get_foot_coords() for d in track.history] coords = [d.get_foot_coords() for d in track.history]
confirmations = [d.state == DetectionState.Confirmed for d in track.history]
# logger.warning(f"{coords=}") # logger.warning(f"{coords=}")
for ci in range(1, len(coords)): for ci in range(1, len(coords)):
start = [int(p) for p in coords[ci-1]] start = [int(p) for p in coords[ci-1]]
end = [int(p) for p in coords[ci]] end = [int(p) for p in coords[ci]]
color = (255,255,255) if confirmations[ci] else (100,100,100) cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
cv2.line(img, start, end, color, 2, lineType=cv2.LINE_AA)
if not track.predictions or not len(track.predictions): if not track.predictions or not len(track.predictions):
continue continue
for pred_i, pred in enumerate(track.predictions): for pred_i, pred in enumerate(track.predictions):
pred_coords = cv2.perspectiveTransform(np.array([pred]), inv_H)[0] pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
color = (0,0,255) if pred_i else (100,100,100) color = (0,0,255) if pred_i == 1 else (100,100,100)
for ci in range(1, len(pred_coords)): for ci in range(1, len(pred_coords)):
start = [int(p) for p in pred_coords[ci-1]] start = [int(p) for p in pred_coords[ci-1]]
end = [int(p) for p in pred_coords[ci]] end = [int(p) for p in pred_coords[ci]]
@ -179,20 +148,36 @@ def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.arra
cv2.rectangle(img, p1, p2, (255,0,0), 1) cv2.rectangle(img, p1, p2, (255,0,0), 1)
cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA) cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
if first_time is None:
first_time = frame.time
cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1) cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1) cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
if prediction_frame: if prediction_frame:
# render Δt and Δ frames
cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1) cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"{prediction_frame.time - time.time():.2f}s", (200,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"{len(prediction_frame.tracks)} tracks", (500,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()])}", (580, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()])}", (660, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()])}", (740, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
return img
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
# cv2.imwrite(str(img_path), img)
logger.info(f"write frame {frame.time - first_time:.3f}s")
if self.out_writer:
self.out_writer.write(img)
if self.streaming_process:
self.streaming_process.stdin.write(img.tobytes())
logger.info('Stopping')
if i>2:
if self.streaming_process:
self.streaming_process.stdin.close()
if self.out_writer:
self.out_writer.release()
if self.streaming_process:
# oddly wrapped, because both close and release() take time.
self.streaming_process.wait()
def run_renderer(config: Namespace, is_running: Event): def run_renderer(config: Namespace, is_running: Event):

View file

@ -1,9 +1,7 @@
from argparse import Namespace from argparse import Namespace
import asyncio import asyncio
import dataclasses
import errno import errno
import json
import logging import logging
from multiprocessing import Event from multiprocessing import Event
import subprocess import subprocess
@ -17,8 +15,6 @@ import tornado.websocket
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.forwarder") logger = logging.getLogger("trap.forwarder")
@ -28,7 +24,7 @@ class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
self.zmq_socket = zmq_socket self.zmq_socket = zmq_socket
async def on_message(self, message): async def on_message(self, message):
logger.debug(f"receive msg") logger.debug(f"recieve msg")
try: try:
await self.zmq_socket.send_string(message) await self.zmq_socket.send_string(message)
@ -116,13 +112,11 @@ class WsRouter:
context = zmq.asyncio.Context() context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB) self.trajectory_socket = context.socket(zmq.PUB)
logger.info(f'Publish trajectories on {config.zmq_trajectory_addr}') self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
self.trajectory_socket.bind(config.zmq_trajectory_addr)
self.prediction_socket = context.socket(zmq.SUB) self.prediction_socket = context.socket(zmq.SUB)
self.prediction_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!! self.prediction_socket.connect(config.zmq_prediction_addr)
self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'') self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_socket.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
self.application = tornado.web.Application( self.application = tornado.web.Application(
[ [
@ -172,16 +166,11 @@ class WsRouter:
logger.info("Starting prediction forwarder") logger.info("Starting prediction forwarder")
while self.is_running.is_set(): while self.is_running.is_set():
# timeout so that if no events occur, loop can still stop on is_running # timeout so that if no events occur, loop can still stop on is_running
has_event = await self.prediction_socket.poll(timeout=1000) has_event = await self.prediction_socket.poll(timeout=1)
if has_event: if has_event:
try: msg = await self.prediction_socket.recv_string()
frame: Frame = await self.prediction_socket.recv_pyobj()
# tacks = [dataclasses.asdict(h) for t in frame.tracks.values() for t.history in t]
msg = json.dumps(frame.aslist())
logger.debug(f"Forward prediction message of {len(msg)} chars") logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg) WebSocketPredictionHandler.write_to_clients(msg)
except Exception as e:
logger.exception(e)
# die together: # die together:
self.evt_loop.stop() self.evt_loop.stop()

View file

@ -22,7 +22,7 @@ from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult from ultralytics.engine.results import Results as YOLOResult
from trap.frame_emitter import DetectionState, Frame, Detection, Track from trap.frame_emitter import Frame, Detection, Track
# Detection = [int, int, int, int, float, int] # Detection = [int, int, int, int, float, int]
# Detections = [Detection] # Detections = [Detection]
@ -66,9 +66,6 @@ class Tracker:
# TODO: support removal # TODO: support removal
self.tracks = defaultdict(lambda: Track()) self.tracks = defaultdict(lambda: Track())
logger.debug(f"Load tracker: {self.config.detector}")
if self.config.detector == DETECTOR_RETINANET: if self.config.detector == DETECTOR_RETINANET:
# weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
# self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2) # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
@ -79,7 +76,7 @@ class Tracker:
self.model.eval() self.model.eval()
# Get the transforms for the model's weights # Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device) self.preprocess = weights.transforms().to(self.device)
self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9, self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth" # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
) )
elif self.config.detector == DETECTOR_MASKRCNN: elif self.config.detector == DETECTOR_MASKRCNN:
@ -90,7 +87,7 @@ class Tracker:
self.model.eval() self.model.eval()
# Get the transforms for the model's weights # Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device) self.preprocess = weights.transforms().to(self.device)
self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9, self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth" # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
) )
elif self.config.detector == DETECTOR_YOLOv8: elif self.config.detector == DETECTOR_YOLOv8:
@ -123,7 +120,7 @@ class Tracker:
logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.") logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
training_fp = open(self.config.save_for_training / 'all.txt', 'w') training_fp = open(self.config.save_for_training / 'all.txt', 'w')
# following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE) training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
prev_frame_i = -1 prev_frame_i = -1
@ -136,12 +133,6 @@ class Tracker:
# time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT)) # time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
# prev_run_time = time.time() # prev_run_time = time.time()
zmq_ev = self.frame_sock.poll(timeout=2000)
if not zmq_ev:
logger.warn('skip poll after 2000ms')
# when there's no data after timeout, loop so that is_running is checked
continue
start_time = time.time() start_time = time.time()
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
@ -151,9 +142,6 @@ class Tracker:
prev_frame_i = frame.index prev_frame_i = frame.index
# load homography into frame (TODO: should this be done in emitter?) # load homography into frame (TODO: should this be done in emitter?)
if frame.H is None:
# logger.warning('Falling back to default H')
# fallback: load configured H
frame.H = self.H frame.H = self.H
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s") # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
@ -162,7 +150,7 @@ class Tracker:
if self.config.detector == DETECTOR_YOLOv8: if self.config.detector == DETECTOR_YOLOv8:
detections: [Detection] = self._yolov8_track(frame.img) detections: [Detection] = self._yolov8_track(frame.img)
else : else :
detections: [Detection] = self._resnet_track(frame.img, scale = 1) detections: [Detection] = self._resnet_track(frame.img)
# Store detections into tracklets # Store detections into tracklets
@ -213,18 +201,10 @@ class Tracker:
if training_csv: if training_csv:
training_csv.writerows([{ training_csv.writerows([{
'frame_id': round(frame.index * 10., 1), # not really time 'frame_id': round(frame.index * 10., 1), # not really time
'track_id': t.track_id, 'track_id': t['id'],
'l': t.history[-1].l, 'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
't': t.history[-1].t, 'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
'w': t.history[-1].w, } for t in active_tracks.values()])
'h': t.history[-1].h,
'x': t.get_projected_history(frame.H)[-1][0],
'y': t.get_projected_history(frame.H)[-1][1],
'state': t.history[-1].state.value
# only keep _actual_detections, no lost entries
} for t in active_tracks.values()
# if t.history[-1].state != DetectionState.Lost
])
training_frames += len(active_tracks) training_frames += len(active_tracks)
# print(time.time() - start_time) # print(time.time() - start_time)
@ -256,13 +236,10 @@ class Tracker:
return [] return []
return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())] return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
def _resnet_track(self, img, scale: float = 1) -> [Detection]: def _resnet_track(self, img) -> [Detection]:
if scale != 1:
dsize = (int(img.shape[1] * scale), int(img.shape[0] * scale))
img = cv2.resize(img, dsize)
detections = self._resnet_detect_persons(img) detections = self._resnet_detect_persons(img)
tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img) tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
return [Detection.from_deepsort(t).get_scaled(1/scale) for t in tracks] return [Detection.from_deepsort(t) for t in tracks]
def _resnet_detect_persons(self, frame) -> [Detection]: def _resnet_detect_persons(self, frame) -> [Detection]:
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

View file

@ -25,13 +25,12 @@
<script> <script>
// minified https://github.com/joewalnes/reconnecting-websocket // minified https://github.com/joewalnes/reconnecting-websocket
!function (a, b) { "function" == typeof define && define.amd ? define([], b) : "undefined" != typeof module && module.exports ? module.exports = b() : a.ReconnectingWebSocket = b() }(this, function () { function a(b, c, d) { function l(a, b) { var c = document.createEvent("CustomEvent"); return c.initCustomEvent(a, !1, !1, b), c } var e = { debug: !1, automaticOpen: !0, reconnectInterval: 1e3, maxReconnectInterval: 3e4, reconnectDecay: 1.5, timeoutInterval: 2e3 }; d || (d = {}); for (var f in e) this[f] = "undefined" != typeof d[f] ? d[f] : e[f]; this.url = b, this.reconnectAttempts = 0, this.readyState = WebSocket.CONNECTING, this.protocol = null; var h, g = this, i = !1, j = !1, k = document.createElement("div"); k.addEventListener("open", function (a) { g.onopen(a) }), k.addEventListener("close", function (a) { g.onclose(a) }), k.addEventListener("connecting", function (a) { g.onconnecting(a) }), k.addEventListener("message", function (a) { g.onmessage(a) }), k.addEventListener("error", function (a) { g.onerror(a) }), this.addEventListener = k.addEventListener.bind(k), this.removeEventListener = k.removeEventListener.bind(k), this.dispatchEvent = k.dispatchEvent.bind(k), this.open = function (b) { h = new WebSocket(g.url, c || []), b || k.dispatchEvent(l("connecting")), (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "attempt-connect", g.url); var d = h, e = setTimeout(function () { (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "connection-timeout", g.url), j = !0, d.close(), j = !1 }, g.timeoutInterval); h.onopen = function () { clearTimeout(e), (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onopen", g.url), g.protocol = h.protocol, g.readyState = WebSocket.OPEN, g.reconnectAttempts = 0; var d = l("open"); d.isReconnect = b, b = !1, k.dispatchEvent(d) }, h.onclose = function (c) { if (clearTimeout(e), h = null, i) g.readyState = WebSocket.CLOSED, k.dispatchEvent(l("close")); else { g.readyState = WebSocket.CONNECTING; var d = l("connecting"); d.code = c.code, d.reason = c.reason, d.wasClean = c.wasClean, k.dispatchEvent(d), b || j || ((g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onclose", g.url), k.dispatchEvent(l("close"))); var e = g.reconnectInterval * Math.pow(g.reconnectDecay, g.reconnectAttempts); setTimeout(function () { g.reconnectAttempts++, g.open(!0) }, e > g.maxReconnectInterval ? g.maxReconnectInterval : e) } }, h.onmessage = function (b) { (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onmessage", g.url, b.data); var c = l("message"); c.data = b.data, k.dispatchEvent(c) }, h.onerror = function (b) { (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "onerror", g.url, b), k.dispatchEvent(l("error")) } }, 1 == this.automaticOpen && this.open(!1), this.send = function (b) { if (h) return (g.debug || a.debugAll) && console.debug("ReconnectingWebSocket", "send", g.url, b), h.send(b); throw "INVALID_STATE_ERR : Pausing to reconnect websocket" }, this.close = function (a, b) { "undefined" == typeof a && (a = 1e3), i = !0, h && h.close(a, b) }, this.refresh = function () { h && h.close() } } return a.prototype.onopen = function () { }, a.prototype.onclose = function () { }, a.prototype.onconnecting = function () { }, a.prototype.onmessage = function () { }, a.prototype.onerror = function () { }, a.debugAll = !1, a.CONNECTING = WebSocket.CONNECTING, a.OPEN = WebSocket.OPEN, a.CLOSING = WebSocket.CLOSING, a.CLOSED = WebSocket.CLOSED, a }); !function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
</script> </script>
<script> <script>
// map the field to coordinates of our dummy tracker // map the field to coordinates of our dummy tracker
// see test_homography.ipynb for the logic behind these values const field_range = { x: [-30, 10], y: [-10, 10] }
const field_range = { x: [-13.092, 15.37], y: [-4.66, 10.624] }
// Create WebSocket connection. // Create WebSocket connection.
const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`); const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
@ -97,17 +96,17 @@
let current_pos = null; let current_pos = null;
function appendAndSendPositions() { function appendAndSendPositions(){
if (is_moving && current_pos !== null) { if(is_moving && current_pos!==null){
// throttled update of tracker on movement // throttled update of tracker on movement
tracker[person_counter].addToHistory(current_pos); tracker[person_counter].addToHistory(current_pos);
} }
for (const person_id in tracker) { for(const person_id in tracker){
if (person_id != person_counter) { // compare int/str if(person_id != person_counter){ // compare int/str
// fade out old tracks // fade out old tracks
tracker[person_id].history.shift() tracker[person_id].history.shift()
if (!tracker[person_id].history.length) { if(!tracker[person_id].history.length){
delete tracker[person_id] delete tracker[person_id]
} }
} }
@ -126,7 +125,7 @@
const mousePos = getMousePos(fieldEl, event); const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos) const position = mouse_coordinates_to_position(mousePos)
current_pos = position; current_pos = position;
tracker[person_counter].addToHistory(current_pos); // tracker[person_counter].addToHistory(current_pos);
// trajectory_socket.send(JSON.stringify(tracker)) // trajectory_socket.send(JSON.stringify(tracker))
}); });
@ -135,8 +134,8 @@
const mousePos = getMousePos(fieldEl, event); const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos) const position = mouse_coordinates_to_position(mousePos)
current_pos = position; current_pos = position;
tracker[person_counter].addToHistory(current_pos); // tracker[person_counter].addToHistory(current_pos);
trajectory_socket.send(JSON.stringify(tracker)) // trajectory_socket.send(JSON.stringify(tracker))
}); });
document.addEventListener('mouseup', (e) => { document.addEventListener('mouseup', (e) => {
person_counter++; person_counter++;
@ -171,13 +170,12 @@
ctx.stroke(); ctx.stroke();
} }
if (person.hasOwnProperty('predictions') && person.predictions.length > 0) { if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
// multiple predictions can be sampled // multiple predictions can be sampled
person.predictions.forEach((prediction, i) => { person.predictions.forEach((prediction, i) => {
ctx.beginPath() ctx.beginPath()
ctx.lineWidth = 0.2; ctx.lineWidth = i === 1 ? 3 : 0.2;
ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa"; ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
ctx.strokeStyle = "#ccaaaa";
// start from current position: // start from current position:
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1]))); ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
@ -186,33 +184,6 @@
} }
ctx.stroke(); ctx.stroke();
}); });
// average stroke:
ctx.beginPath()
ctx.lineWidth = 3;
ctx.strokeStyle = "#ff0000";
// start from current position:
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
for (let index = 0; index < person.predictions[0].length; index++) {
sum = person.predictions.reduce(
(accumulator, prediction) => ({
"x": accumulator.x + prediction[index][0],
"y": accumulator.y + prediction[index][1],
}),
{ x: 0, y: 0 },
);
avg = { x: sum.x / person.predictions.length, y: sum.y / person.predictions.length }
// console.log(sum, avg)
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(avg)))
}
// for (const position of ) {
// }
ctx.stroke();
} }
} }
ctx.restore(); ctx.restore();