trap/trap/track_history.py

186 lines
No EOL
6.1 KiB
Python

from dataclasses import dataclass
import logging
from pathlib import Path
import pickle
from threading import Lock
import time
from typing import Dict, Iterable, List, Optional, Set
import numpy as np
from trap.base import Camera, Track
from trap.lines import Coordinate
from trap.tracker import FinalDisplacementFilter, Smoother, TrackReader
from scipy.spatial import KDTree
# Module-level logger shared by everything in this file; it is registered under
# the fixed name 'history' (not the module's dotted path).
logger = logging.getLogger('history')
@dataclass
class TrackHistoryState():
    """
    Picklable snapshot of the data that TrackHistory derives from the track files.

    The lock of TrackHistory is not pickle-able, so everything that has to be
    written to / restored from the cache lives in this separate state object.
    """
    # Tracks keyed by their track_id (build_tree stores a dict, not a list).
    tracks: Dict[str, "Track"]
    # Per track_id the projected history as returned by Track.get_projected_history().
    track_histories: Dict[str, np.ndarray]
    # Parallel to the points in `tree`: entry i is the track_id of tree point i.
    indexed_track_ids: List[str]
    # KD-tree over the downsampled points of all tracks.
    tree: KDTree
class TrackHistory():
    """
    Nearest-neighbour lookup over previously recorded tracks.

    All tracks found at `path` are filtered, interpolated, smoothed and
    projected using `camera`; their downsampled points are put in a KD-tree so
    that the tracks passing close to a given point can be found quickly.
    The derived data lives in `self.state` (a TrackHistoryState), which can be
    cached on disk and is swapped as a whole on reload().
    """

    def __init__(self, path: Path, camera: "Camera", cache_path: Optional[Path]):
        """
        Args:
            path: location handed to TrackReader to read the tracks from.
            camera: camera used for filtering and projecting the tracks.
            cache_path: optional pickle file for the derived state; None
                disables caching.
        """
        self.path = path
        self.camera = camera
        self.cache_path = cache_path
        self.lock = Lock()

        # Only do the expensive rebuild when no usable cache is available.
        if not self.load_from_cache():
            self.reload()

    def load_from_cache(self) -> bool:
        """
        Try to restore `self.state` from `self.cache_path`.

        Returns:
            True when a valid TrackHistoryState was loaded, False when caching
            is disabled, the file does not exist, or it could not be read.
        """
        if self.cache_path is None or not self.cache_path.exists():
            return False

        logger.debug("Load history state from cache")
        # NOTE: pickle executes arbitrary code on load; the cache file must
        # only ever be one written by reload() in a trusted location.
        try:
            with self.cache_path.open('rb') as fp:
                state = pickle.load(fp)
            if not isinstance(state, TrackHistoryState):
                raise RuntimeError("Pickled data is not a trackhistorystate")
        except Exception as e:
            # Reading the cache is best-effort: any failure (including the file
            # disappearing between exists() and open()) falls back to a rebuild.
            logger.warning(f"Cannot read cache {self.cache_path}: {e}")
            return False

        with self.lock:
            self.state = state
        return True

    def build_tree(self) -> "TrackHistoryState":
        """
        Read all tracks and build a fresh state, without touching `self.state`.

        Returns:
            A new TrackHistoryState holding the smoothed tracks, their
            projected histories and the KD-tree over the downsampled points.
        """
        reader = TrackReader(self.path, self.camera.fps)
        logger.debug(f'loaded {len(reader)} tracks')

        track_filter = FinalDisplacementFilter(2)
        filtered = track_filter.apply(reader, self.camera)
        logger.debug(f'after filtering left with {len(filtered)} tracks')

        interpolated = [t.get_with_interpolated_history() for t in filtered]
        logger.debug(f'interpolated {len(interpolated)} tracks')

        # use convolution here, because precision does not matter and it is _way_ faster
        smoother = Smoother(convolution=True)
        smoothed = [smoother.smooth_track(t) for t in interpolated]
        logger.debug('smoothed')

        tracks = {track.track_id: track for track in smoothed}
        track_histories = {
            t.track_id: t.get_projected_history(camera=self.camera)
            for t in tracks.values()
        }
        downsampled_histories = {
            t_id: self.downsample_history(h) for t_id, h in track_histories.items()
        }
        logger.debug('projected to world space')

        # Flatten all downsampled points into one array for the tree, and keep
        # a parallel list that maps every tree point back to its track id.
        point_chunks = []
        indexed_track_ids: List[str] = []
        for track_id, positions in downsampled_histories.items():
            n_points = len(positions)
            if not n_points:
                continue
            # Only the first two columns are indexed by the tree.
            point_chunks.append(np.asarray(positions)[:, :2])
            indexed_track_ids.extend([track_id] * n_points)

        # Without any points, still hand KDTree a well-formed (0, 2) array
        # instead of an empty list.
        all_points = np.concatenate(point_chunks) if point_chunks else np.empty((0, 2))

        # Create the KD-Tree
        tree = KDTree(all_points)
        logger.debug('built tree')

        return TrackHistoryState(
            tracks, track_histories, indexed_track_ids, tree
        )

    def reload(self):
        """
        Rebuild the state from the track files, swap it in, and (when a cache
        path is configured) write it to the cache.
        """
        state = self.build_tree()

        # acquire the lock as briefly as possible: only for swapping the reference
        with self.lock:
            self.state = state

        if self.cache_path:
            with self.cache_path.open('wb') as fp:
                logger.debug("Writing history to cache")
                # Dump the state built above, not whatever self.state points
                # to by now (another reload may have replaced it).
                pickle.dump(state, fp)

    def get_nearest_tracks(self, point: "Coordinate", k: int, max_r: Optional[float] = np.inf) -> Set[str]:
        """
        Find the tracks that own the `k` indexed points nearest to `point`.

        Args:
            point: query position, passed to KDTree.query().
            k: number of nearest points to look up. Several of them can belong
                to the same track, so fewer than `k` ids may be returned.
            max_r: only consider points within this distance; None or np.inf
                means no limit.

        Returns:
            The set of track ids of the points that were found.
        """
        upper_bound = np.inf if max_r is None else max_r

        with self.lock:
            # Snapshot the state, so the indexes returned by the tree are
            # resolved against the id list that belongs to that same tree, even
            # if reload() swaps self.state right after the lock is released.
            state = self.state
            distances, indexes = state.tree.query(point, k, distance_upper_bound=upper_bound)

        # query() returns scalars for k=1; normalise to arrays
        distances = np.atleast_1d(distances)
        indexes = np.atleast_1d(indexes)

        # Missing neighbours (fewer than k points within reach) are reported
        # with an infinite distance and an out-of-range index: drop those.
        found = indexes[np.isfinite(distances)]
        return {state.indexed_track_ids[idx] for idx in found}

    def ids_as_trajectory(self, track_ids: Iterable[str]):
        """
        Lazily yield the projected history of every track in `track_ids`.

        Raises:
            KeyError: (while iterating) when a track id is not in the state.
        """
        # Resolve all ids against one state, also when a reload happens while
        # the caller is still iterating.
        tracks = self.state.tracks
        for track_id in track_ids:
            yield tracks[track_id].get_projected_history(camera=self.camera)

    @classmethod
    def downsample_history(cls, history, cell_size=.3):
        """
        Reduce a history to at most one position per grid cell.

        Every position is snapped to a grid with spacing `cell_size`, after
        which duplicate rows are removed (np.unique also sorts the rows).

        Args:
            history: array-like of positions, one row per position.
            cell_size: grid spacing, in the same unit as `history`.

        Returns:
            ndarray with the unique snapped positions; an empty (0, 2) array
            when `history` is empty.
        """
        points = np.asarray(history)
        if points.size == 0:
            # Always return an ndarray so callers can slice the result.
            # Assumes 2 columns, as the tree only uses the first two.
            return np.empty((0, 2))
        return np.unique(np.round(points / cell_size), axis=0) * cell_size
if __name__ == "__main__":
    # Manual visual check: load (or build) the history for the hof3 recording,
    # pick one track and a point on it, and plot the tracks nearest to that point.
    import matplotlib.pyplot as plt  # Visualization; only needed for this demo

    logging.basicConfig(level=logging.DEBUG)

    path = Path("EXPERIMENTS/raw/hof3/")
    calibration_path = Path("../DATASETS/hof3/calibration.json")
    homography_path = Path("../DATASETS/hof3/homography.json")
    # NOTE(review): the third argument is presumably the camera fps -- confirm
    camera = Camera.from_paths(calibration_path, homography_path, 12)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments
    s = time.perf_counter()
    history = TrackHistory(path, camera, Path("/tmp/historystate_hof3.pcl"))
    dt = time.perf_counter() - s
    print(f'loaded {len(history.state.tracks)} tracks in {dt}s')

    all_tracks = list(history.state.tracks.values())
    if not all_tracks:
        raise SystemExit(f"No tracks found in {path}")

    # Arbitrary example track; clamp the index so small datasets still work
    track = all_tracks[min(25, len(all_tracks) - 1)]
    trajectory_crop = TrackHistory.downsample_history(history.state.track_histories[track.track_id])
    trajectory_org = track.get_projected_history(camera=camera)

    # A point somewhat past the middle of the track, clamped to its last point
    target_idx = min(len(trajectory_org) // 2 + 90, len(trajectory_org) - 1)
    target_point = trajectory_org[target_idx]

    track_set = history.get_nearest_tracks(target_point, 10, max_r=np.inf)

    plt.gca().set_aspect('equal')
    # orange: downsampled points, blue: full trajectory, red: query point
    plt.scatter(trajectory_crop[:, 0], trajectory_crop[:, 1], c='orange')
    plt.plot(trajectory_org[:, 0], trajectory_org[:, 1], c='blue', alpha=1)
    plt.scatter(target_point[0], target_point[1], c='red', alpha=1)

    # green: every track that has a point among the nearest neighbours
    for track_id in track_set:
        closeby = history.state.tracks[track_id].get_projected_history(camera=camera)
        plt.plot(closeby[:, 0], closeby[:, 1], c='green', alpha=.1)

    plt.show()