"""
This project produces "detections" for DeepSORT in an attempt
to trick the algorithm into moonwalking over a crowd.

The framerate of rendering and detection can be distinct.
Also, all parameters (incl. framerate) can change along the way,
thus positions cannot depend on that.
"""
from __future__ import annotations

from dataclasses import dataclass
import dataclasses
from typing import Optional
import time
import pyglet
import logging
import numpy as np
# from deep_sort_realtime.deepsort_tracker import DeepSort
from sort import Sort
from collections import defaultdict

Interval = float  # seconds

logger = logging.getLogger('moonwalk')


@dataclass
class Params:
    """Runtime-tunable simulation parameters.

    Every field gets an on-screen label in ``Canvas`` and can be changed
    while the program runs (mouse scroll over the label).
    """

    visual_variability: float = 0
    video_fps: float = 25
    tracker_fps: float = 25
    # iou = None
    emitter_speed: float = 1  # objects per second
    object_velocity: float = 40  # pixels/second
    velocity_decay: float = 1  # make system a bit springy


class DetectedObject:
    """A moving rectangle that is reported to the tracker as a detection."""

    def __init__(self, canvas: Canvas):
        self.canvas = canvas
        # TODO handle variability
        self.v = self.canvas.params.object_velocity
        self.t = 40  # top
        self.l = 0  # left
        self.w = 10  # width
        self.h = 20  # height
        self.shape = pyglet.shapes.Rectangle(
            self.l, self.t, self.w, self.h,
            color=(255, 22, 20),
            batch=self.canvas.batch_figures,
        )
        # TODO renderer

    def update(self, dt: Interval):
        """Move the object horizontally by ``dt`` seconds of travel.

        The current ``object_velocity`` parameter is read on every call, so
        velocity changes made at runtime take effect immediately.
        """
        self.l += dt * self.canvas.params.object_velocity
        self.shape.x = self.l
        # TODO exponential decay with self.params.velocity_decay


class ObjectEmitter:
    """Emit detectable objects at ``params.emitter_speed`` objects/second."""

    def __init__(self, params: Params, canvas: Canvas):
        self.lastEmit = 0  # seconds accumulated since the last emitted object
        self.params = params
        self.canvas = canvas

    def emit(self, dt: Interval) -> list[DetectedObject]:
        """Advance the emitter clock by ``dt`` and return newly emitted objects.

        Returns a list with one new ``DetectedObject`` when at least
        ``1 / emitter_speed`` seconds have accumulated, otherwise an empty
        list. A non-positive ``emitter_speed`` pauses emission instead of
        dividing by zero (the parameter can be scrolled down to 0 at runtime).
        """
        self.lastEmit += dt

        speed = self.params.emitter_speed
        if speed <= 0:
            return []

        if self.lastEmit >= 1 / speed:
            logger.info('emit!')
            obj = DetectedObject(self.canvas)
            self.lastEmit = 0
            return [obj]
        return []


class Canvas:
    """A canvas with moving objects.

    Owns the pyglet window, the list of simulated objects, the SORT tracker
    and the on-screen labels/bounding boxes.
    """

    def __init__(self, params: Params):
        self.width = 1280
        self.height = 720
        self.objects: list[DetectedObject] = []
        self.lastSnapshot: Optional[float] = None
        self.params = params
        self.emitter = ObjectEmitter(self.params, self)

        self.hide_stats = False

        config = pyglet.gl.Config(sample_buffers=1, samples=4, double_buffer=True)
        # , fullscreen=self.config.render_window
        self.window = pyglet.window.Window(
            width=self.width, height=self.height, config=config, fullscreen=False)
        self.window.set_handler('on_draw', self.on_draw)
        self.window.set_handler('on_key_press', self.on_key_press)
        self.window.set_handler('on_mouse_scroll', self.on_mouse_scroll)
        self.window.set_handler('on_refresh', self.on_refresh)
        # self.window.set_handler('on_close', self.on_close)

        # Purple background color:
        # pyglet.gl.glClearColor(*AnimConfig.clear_color)
        self.fps_display = pyglet.window.FPSDisplay(
            window=self.window, color=(255, 255, 255, 255))
        self.fps_display.label.x = self.window.width - 150
        self.fps_display.label.y = self.window.height - 17
        self.fps_display.label.bold = False
        self.fps_display.label.font_size = 10

        self.label_time = pyglet.text.Label(
            "t", x=20, y=self.height - 17, color=(255, 255, 255, 255))

        self.batch_figures = pyglet.graphics.Batch()
        self.batch_bounding_boxes = pyglet.graphics.Batch()
        self.batch_info = pyglet.graphics.Batch()

        self.tracks = []

        self.labels = {
            'objects': pyglet.text.Label(
                "", x=20, y=30, color=(255, 255, 255, 255), batch=self.batch_info),
            'tracks': pyglet.text.Label(
                "", x=120, y=30, color=(255, 255, 255, 255), batch=self.batch_info),
        }

        # One label per parameter, stacked above the counters. Show the value
        # actually passed in rather than the dataclass default.
        for i, field in enumerate(dataclasses.fields(self.params)):
            self.labels[field.name] = pyglet.text.Label(
                f"{field.name}: {getattr(self.params, field.name)}",
                x=20, y=30 + 15 * (i + 1),
                color=(255, 255, 255, 255), batch=self.batch_info)

        # Bounding boxes are created lazily, keyed by track id.
        self.track_shapes = defaultdict(
            lambda: pyglet.shapes.Box(
                0, 0, 0, 0, color=(0, 255, 0), thickness=2,
                batch=self.batch_bounding_boxes))

        self.tracker = Sort(max_age=5, min_hits=2, iou_threshold=0)  # DeepSort(max_age=5)

        pyglet.clock.schedule_interval(self.on_track, 1 / self.params.tracker_fps)
        # NOTE(review): relies on pyglet's private clock internals to be able
        # to change the tracker interval later on; verify on pyglet upgrades.
        self.interval_items: list[pyglet.clock._ScheduledIntervalItem] = [
            i for i in pyglet.clock._default._schedule_interval_items
            if i.func == self.on_track
        ]

    def run(self):
        """Create a pyglet event loop and block until it exits."""
        self.event_loop = pyglet.app.EventLoop()
        # pyglet.clock.schedule_interval(self.check_running, 0.1)
        # pyglet.clock.schedule(self.check_frames)
        # pyglet.clock.schedule(self.track)
        self.event_loop.run()

    def on_draw(self):
        """Redraw the window: figures, bounding boxes, info labels, FPS, time."""
        self.label_time.text = f"{time.monotonic()}"

        self.window.clear()
        self.batch_figures.draw()
        self.batch_bounding_boxes.draw()
        self.batch_info.draw()
        self.fps_display.draw()
        self.label_time.draw()

    def on_close(self):
        """Log that the window is closing (handler currently not registered)."""
        logger.info('closing')

    def on_key_press(self, symbol, modifiers):
        """Handle keys: Q quits, V toggles log level, UP/DOWN change velocity.

        With SHIFT held, UP/DOWN change the velocity by 10 instead of 1.
        """
        if symbol == pyglet.window.key.Q:
            self.window.close()
            # `exit()` is a helper injected by the `site` module for
            # interactive use; raising SystemExit is the equivalent that is
            # always available.
            raise SystemExit
        if symbol == pyglet.window.key.V:
            level = logging.INFO if logger.getEffectiveLevel() == logging.DEBUG else logging.DEBUG
            logger.setLevel(level)
            logger.info(f"set log level: {level}")

        step = 10 if pyglet.window.key.MOD_SHIFT & modifiers else 1
        if symbol == pyglet.window.key.UP:
            logger.debug('up')
            self.params.object_velocity += step
        if symbol == pyglet.window.key.DOWN:
            logger.debug('down')
            self.params.object_velocity -= step

    def on_mouse_scroll(self, x, y, scroll_x, scroll_y):
        """Change the parameter whose label is under the cursor by ``scroll_y``.

        ``tracker_fps`` is routed through ``set_tracker_fps`` so that the
        scheduled tracking interval is updated together with the parameter.
        """
        for param_name, param_value in dataclasses.asdict(self.params).items():
            label = self.labels[param_name]
            inside = (
                label.x <= x <= label.x + label.content_width
                and label.y <= y <= label.y + label.content_height
            )
            if not inside:
                continue

            new_value = param_value + scroll_y
            if param_name == 'tracker_fps':
                self.set_tracker_fps(new_value)
            else:
                setattr(self.params, param_name, new_value)

    def on_track(self, dt):
        """Run one tracker step on the current object positions.

        Scheduled via ``pyglet.clock.schedule_interval``; ``dt`` is not used
        here because ``snapshot`` measures elapsed time itself.
        """
        # bbs = object_detector.detect(frame)
        objects = self.snapshot()
        # TODO chipper & embedder
        # DEEP SORT expects a list of detections, each a tuple of
        # ( [left,top,w,h], confidence, detection_class ):
        # bbs = [([o.l, o.t, o.w, o.h], 1, 1) for o in objects]
        # self.tracks = self.tracker.update_tracks(bbs, frame=np.zeros([1280,720]))

        # SORT expects rows of [x1, y1, x2, y2, score]; frames without
        # detections must still be passed as an empty (0, 5) array.
        if objects:
            bbs = np.array([[o.l, o.t, o.l + o.w, o.t + o.h, 1] for o in objects])
        else:
            bbs = np.empty((0, 5))

        # self.tracks is a np array where each row contains a valid bounding
        # box and track_id (last column)
        self.tracks = self.tracker.update(bbs)

        # remove shapes of tracks that are no longer reported
        ids = {track[4] for track in self.tracks}
        for k in list(self.track_shapes.keys()):
            if k not in ids:
                self.track_shapes.pop(k)
                logger.debug(f"shape removed {k}")

    def set_tracker_fps(self, fps: float):
        """Set the tracker rate and update the scheduled ``on_track`` interval.

        Non-positive values are ignored (with a warning) because they cannot
        be converted to a valid interval.
        """
        if fps <= 0:
            logger.warning(f"ignoring non-positive tracker fps: {fps}")
            return

        self.params.tracker_fps = fps
        for item in self.interval_items:
            item.interval = 1 / fps

    def prune(self):
        """Remove objects that have moved past the right edge of the frame.

        The list is rebuilt and assigned in place, so references to
        ``self.objects`` held elsewhere stay valid and no index goes stale
        while removing several objects in one pass.
        """
        kept = []
        for index, obj in enumerate(self.objects):
            if obj.l > self.width:
                logger.info(f'Delete {index}')
            else:
                kept.append(obj)
        self.objects[:] = kept

    def snapshot(self) -> list[DetectedObject]:
        """Advance the simulation to "now" and return the live objects.

        Positions are updated based on dt = now - lastSnapshot, so the result
        does not depend on how often (or from which callback) this is called.
        """
        now = time.monotonic()
        if self.lastSnapshot is None:
            self.lastSnapshot = now

        dt = now - self.lastSnapshot

        self.objects.extend(self.emitter.emit(dt))

        for obj in self.objects:
            obj.update(dt)

        self.prune()

        self.lastSnapshot = now
        return self.objects

    def on_refresh(self, dt: float):
        """Update labels and track bounding boxes before the frame is drawn."""
        objects = self.snapshot()

        self.labels['objects'].text = f"Objects: {len(objects)}"
        self.labels['tracks'].text = f"Tracks: {len(self.tracks)}"
        for name, value in dataclasses.asdict(self.params).items():
            self.labels[name].text = f"{name}: {value}"

        # Each track row is [x1, y1, x2, y2, track_id].
        for track in self.tracks:
            box = self.track_shapes[track[4]]
            box.x = track[0]
            box.y = track[1]
            box.width = track[2] - track[0]
            box.height = track[3] - track[1]

        # TODO: shape in DetectedObject


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    params = Params()
    canvas = Canvas(params)
    canvas.run()