283 lines
No EOL
9.9 KiB
Python
283 lines
No EOL
9.9 KiB
Python
"""
|
|
This project produces synthetic "detections" for a DeepSORT-style tracker
in an attempt to trick the algorithm into moonwalking
over a crowd.

The frame rates of rendering and of detection can differ.
Also, all parameters (incl. the frame rates) can change along the way,
so object positions must not depend on them: positions are advanced by
the elapsed time between snapshots instead of by a frame count.
|
|
"""
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
import dataclasses
|
|
from typing import Optional
|
|
import time
|
|
import pyglet
|
|
import logging
|
|
import numpy as np
|
|
# from deep_sort_realtime.deepsort_tracker import DeepSort
|
|
from sort import Sort
|
|
from collections import defaultdict
|
|
|
|
|
|
# Type alias used in signatures below: a duration expressed in seconds.
Interval = float

# Logger shared by every class in this module.
logger = logging.getLogger('moonwalk')
|
|
|
|
|
|
@dataclass
class Params:
    """
    Tunable simulation parameters.

    The field order matters: Canvas lays out one on-screen label per field,
    in declaration order.
    """

    # Not read anywhere in this file yet (see the variability TODO in DetectedObject).
    visual_variability: float = 0

    # Frame rates of rendering and of the tracker, in frames per second.
    video_fps: float = 25
    tracker_fps: float = 25

    # iou = None

    # Emission rate in objects per second.
    emitter_speed: float = 1

    # Horizontal object speed in pixels/second.
    object_velocity: float = 40

    # Make the system a bit springy (not applied yet, see the TODO in DetectedObject.update).
    velocity_decay: float = 1
|
|
|
|
|
|
|
|
class DetectedObject:
    """
    A single moving rectangle that is later reported to the tracker.

    The geometry lives in ``l`` (left), ``t`` (top), ``w`` (width) and
    ``h`` (height); ``shape`` is the pyglet rectangle registered in the
    canvas' figure batch.
    """

    def __init__(self, canvas: Canvas):
        self.canvas = canvas

        # TODO handle variability
        self.v = canvas.params.object_velocity

        # left, top, width, height
        self.l, self.t, self.w, self.h = 0, 40, 10, 20

        # TODO renderer
        self.shape = pyglet.shapes.Rectangle(
            self.l,
            self.t,
            self.w,
            self.h,
            color=(255, 22, 20),
            batch=canvas.batch_figures,
        )

    def update(self, dt: Interval):
        """
        Move the object to the right by ``dt`` seconds worth of travel.

        The velocity is read from the canvas parameters on every call, so a
        changed ``object_velocity`` applies to existing objects as well.
        """
        velocity = self.canvas.params.object_velocity
        self.l = self.l + dt * velocity
        # TODO exponential decay with self.params.velocity_decay
        self.shape.x = self.l
|
|
|
|
class ObjectEmitter:
    """
    Emit detectable objects

    Time is accumulated across calls to ``emit``; once at least
    ``1 / params.emitter_speed`` seconds have passed, one new
    DetectedObject is created and the accumulator is reset.
    """

    def __init__(self, params: Params, canvas: Canvas):
        # Seconds accumulated since the last emitted object.
        self.lastEmit = 0
        self.params = params
        self.canvas = canvas

    def emit(self, dt: Interval) -> list[DetectedObject]:
        """
        Advance the emitter by ``dt`` seconds.

        Returns a list with one new object when the emission period has
        elapsed, otherwise an empty list. At most one object is produced
        per call.
        """
        self.lastEmit += dt

        rate = self.params.emitter_speed
        # The rate can be changed at runtime and may reach zero or go negative.
        # A zero rate would divide by zero, and a negative rate would yield a
        # negative period and therefore an object on every call. Both are
        # treated as "emitter paused".
        if rate <= 0:
            return []

        if self.lastEmit < 1 / rate:
            return []

        logger.info('emit!')
        obj = DetectedObject(self.canvas)
        self.lastEmit = 0
        return [obj]
|
|
|
|
|
|
class Canvas:
    """
    A canvas with moving objects

    Owns the pyglet window, the emitted objects, the tracker and the
    on-screen labels. Object positions are advanced in ``snapshot`` based on
    elapsed monotonic time, which is called both when the window refreshes
    and when the tracker runs.
    """

    def __init__(self, params: Params):
        self.width = 1280
        self.height = 720
        self.objects: list[DetectedObject] = []
        # time.monotonic() value of the previous snapshot, None before the first one.
        self.lastSnapshot: Optional[float] = None
        self.params = params
        self.emitter = ObjectEmitter(self.params, self)

        self.hide_stats = False

        config = pyglet.gl.Config(sample_buffers=1, samples=4, double_buffer=True)
        self.window = pyglet.window.Window(width=self.width, height=self.height, config=config, fullscreen=False)
        self.window.set_handler('on_draw', self.on_draw)
        self.window.set_handler('on_key_press', self.on_key_press)
        self.window.set_handler('on_mouse_scroll', self.on_mouse_scroll)
        self.window.set_handler('on_refresh', self.on_refresh)

        self.fps_display = pyglet.window.FPSDisplay(window=self.window, color=(255,255,255,255))
        self.fps_display.label.x = self.window.width - 150
        self.fps_display.label.y = self.window.height - 17
        self.fps_display.label.bold = False
        self.fps_display.label.font_size = 10

        self.label_time = pyglet.text.Label("t", x=20, y=self.height - 17, color=(255,255,255,255))

        self.batch_figures = pyglet.graphics.Batch()
        self.batch_bounding_boxes = pyglet.graphics.Batch()
        self.batch_info = pyglet.graphics.Batch()

        self.tracks = []

        self.labels = {
            'objects': pyglet.text.Label("", x=20, y=30, color=(255,255,255,255), batch=self.batch_info),
            'tracks': pyglet.text.Label("", x=120, y=30, color=(255,255,255,255), batch=self.batch_info),
        }
        # One label per parameter, stacked upwards in field order. The text
        # shows the value actually held by `params`, not the field default.
        for i, field in enumerate(dataclasses.fields(self.params)):
            self.labels[field.name] = pyglet.text.Label(
                f"{field.name}: {getattr(self.params, field.name)}",
                x=20,
                y=30 + 15*(i+1),
                color=(255,255,255,255),
                batch=self.batch_info,
            )

        # Bounding boxes are created lazily, keyed by track id.
        self.track_shapes = defaultdict(lambda: pyglet.shapes.Box(0,0,0,0,color=(0,255,0),thickness=2, batch=self.batch_bounding_boxes))

        self.tracker = Sort(max_age=5, min_hits=2, iou_threshold=0)  # DeepSort(max_age=5)

        pyglet.clock.schedule_interval(self.on_track, 1/self.params.tracker_fps)

        # Keep a handle on the scheduled item(s) so set_tracker_fps can change
        # the interval later. NOTE: this relies on private pyglet clock attributes.
        self.interval_items: list[pyglet.clock._ScheduledIntervalItem] = [i for i in pyglet.clock._default._schedule_interval_items if i.func == self.on_track]

    def run(self):
        """Create a pyglet event loop, store it on the canvas and run it."""
        self.event_loop = pyglet.app.EventLoop()
        self.event_loop.run()

    def on_draw(self):
        """Window 'on_draw' handler: clear the window and draw all batches and labels."""
        self.label_time.text = f"{time.monotonic()}"
        self.window.clear()
        self.batch_figures.draw()
        self.batch_bounding_boxes.draw()
        self.batch_info.draw()
        self.fps_display.draw()
        self.label_time.draw()

    def on_close(self):
        """Log that the window is closing (not registered as a handler in __init__)."""
        logger.info('closing')

    def on_key_press(self, symbol, modifiers):
        """
        Window 'on_key_press' handler.

        Q closes the window and exits, V toggles the log level between DEBUG
        and INFO, UP/DOWN change the object velocity by 1 (by 10 with shift).
        """
        if symbol == pyglet.window.key.Q:
            self.window.close()
            # The `exit()` builtin is injected by the `site` module for
            # interactive use; raising SystemExit directly always works.
            raise SystemExit

        if symbol == pyglet.window.key.V:
            level = logging.INFO if logger.getEffectiveLevel() == logging.DEBUG else logging.DEBUG
            logger.setLevel(level)
            logger.info(f"set log level: {level}")

        step = 10 if pyglet.window.key.MOD_SHIFT & modifiers else 1
        if symbol == pyglet.window.key.UP:
            logger.debug('up')
            self.params.object_velocity += step
        if symbol == pyglet.window.key.DOWN:
            logger.debug('down')
            self.params.object_velocity -= step

    def on_mouse_scroll(self, x, y, scroll_x, scroll_y):
        """
        Window 'on_mouse_scroll' handler.

        The parameter whose label is under the cursor is changed by
        ``scroll_y``. ``tracker_fps`` goes through ``set_tracker_fps`` so the
        scheduled tracker interval follows the displayed value; a change that
        would make it non-positive is ignored.
        """
        for param_name, param_value in dataclasses.asdict(self.params).items():
            label = self.labels[param_name]
            inside = (
                label.x <= x <= label.x + label.content_width
                and label.y <= y <= label.y + label.content_height
            )
            if not inside:
                continue

            new_value = param_value + scroll_y
            if param_name == 'tracker_fps':
                if new_value > 0:
                    self.set_tracker_fps(new_value)
            else:
                setattr(self.params, param_name, new_value)

    def on_track(self, dt):
        """
        Scheduled tracker step: feed the current object boxes to the tracker
        and drop bounding-box shapes of tracks that are no longer reported.
        """
        objects = self.snapshot()
        # TODO chipper & embedder
        # DeepSORT would instead expect a list of
        # ([left, top, w, h], confidence, detection_class) tuples.

        if objects:
            # Rows are [x1, y1, x2, y2, score, ...].
            bbs = np.array([[o.l, o.t, o.l + o.w, o.t + o.h, 1, 1] for o in objects])
        else:
            # SORT must be called for every frame, with an empty (0, 5) array
            # when there are no detections; np.array([]) would be 1-D.
            bbs = np.empty((0, 5))

        # Each returned row is a bounding box with the track id in column 4.
        self.tracks = self.tracker.update(bbs)

        # remove old shapes
        live_ids = {track[4] for track in self.tracks}
        for k in list(self.track_shapes):
            if k not in live_ids:
                self.track_shapes.pop(k)
                logger.debug(f"shape removed {k}")

    def set_tracker_fps(self, fps: float):
        """
        Change the tracker rate and update the scheduled interval to match.

        Raises ValueError for a non-positive ``fps``, before any state is
        modified.
        """
        if fps <= 0:
            raise ValueError(f"tracker fps must be positive, got {fps}")

        self.params.tracker_fps = fps
        for interval in self.interval_items:
            interval.interval = 1/fps

    def prune(self):
        """
        Loop over objects remove those out of the frame

        An object is dropped once its left edge is beyond the canvas width.
        The list is updated in place so references to ``self.objects`` stay valid.
        """
        kept = []
        for i, obj in enumerate(self.objects):
            if obj.l > self.width:
                logger.info(f'Delete {i}')
            else:
                kept.append(obj)
        # Rebuild rather than pop by index: popping while walking a copy
        # shifts the remaining indices and removes the wrong objects.
        self.objects[:] = kept

    def snapshot(self) -> list[DetectedObject]:
        """
        Update all object positions base on dt = now - lastSnapshot

        Also emits new objects and prunes those that left the frame.
        Returns the canvas' own object list.
        """
        now = time.monotonic()
        if self.lastSnapshot is None:
            # First snapshot: no time has passed yet.
            self.lastSnapshot = now

        dt = now - self.lastSnapshot

        self.objects.extend(self.emitter.emit(dt))
        for obj in self.objects:
            obj.update(dt)
        self.prune()

        self.lastSnapshot = now
        return self.objects

    def on_refresh(self, dt: float):
        """
        Window 'on_refresh' handler: advance the simulation, refresh the
        label texts and move the bounding boxes onto the current tracks.
        """
        objects = self.snapshot()

        self.labels['objects'].text = f"Objects: {len(objects)}"
        self.labels['tracks'].text = f"Tracks: {len(self.tracks)}"

        for name, value in dataclasses.asdict(self.params).items():
            self.labels[name].text = f"{name}: {value}"

        for track in self.tracks:
            # track is [x1, y1, x2, y2, track_id]
            box = self.track_shapes[track[4]]
            box.x = track[0]
            box.y = track[1]
            box.width = track[2] - track[0]
            box.height = track[3] - track[1]
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: INFO-level logging, default parameters, then hand
    # control to the canvas' event loop.
    logging.basicConfig(level=logging.INFO)

    params = Params()
    canvas = Canvas(params)
    canvas.run()
|
|
|
|
|