186 lines
No EOL
6.1 KiB
Python
186 lines
No EOL
6.1 KiB
Python
from dataclasses import dataclass
|
|
import logging
|
|
from pathlib import Path
|
|
import pickle
|
|
from threading import Lock
|
|
import time
|
|
from typing import Dict, Iterable, List, Optional, Set
|
|
|
|
import numpy as np
|
|
from trap.base import Camera, Track
|
|
from trap.lines import Coordinate
|
|
from trap.tracker import FinalDisplacementFilter, Smoother, TrackReader
|
|
|
|
from scipy.spatial import KDTree
|
|
|
|
# Module-level logger. It uses the explicit name 'history' rather than __name__;
# NOTE(review): presumably logging config refers to 'history' - verify before renaming.
logger: logging.Logger = logging.getLogger(name='history')
|
|
|
|
@dataclass
class TrackHistoryState:
    """Pickle-able container for all data a TrackHistory queries.

    TrackHistory owns a threading.Lock, which cannot be pickled, so the data
    that is cached to disk is kept in this separate state object instead.
    """

    # Interpolated and smoothed tracks keyed by track_id
    # (TrackHistory.build_tree builds this as a dict, and callers index it by id).
    tracks: Dict[str, Track]
    # Projected history of every track, keyed by track_id.
    track_histories: Dict[str, np.ndarray]
    # indexed_track_ids[i] is the track_id that owns point i of `tree`.
    indexed_track_ids: List[str]
    # KD-tree over the downsampled history points of all tracks.
    tree: KDTree
|
|
|
|
|
|
|
|
class TrackHistory:
    """Spatial index over recorded tracks for nearest-track lookups.

    The index is built from the tracks at ``path`` (see ``build_tree``) and can
    be cached to ``cache_path`` with pickle so later start-ups skip the rebuild.
    ``self.state`` holds the current TrackHistoryState; it is swapped under
    ``self.lock``.
    """

    def __init__(self, path: Path, camera: Camera, cache_path: Optional[Path] = None):
        """Load the state from the cache, or build it from ``path``.

        Args:
            path: location passed to TrackReader to read the recorded tracks.
            camera: camera whose fps is used for reading and which is used to
                project the tracks.
            cache_path: pickle file to load the state from / store it to.
                ``None`` disables caching.
        """
        self.path = path
        self.camera = camera
        self.cache_path = cache_path
        # Guards replacement of self.state.
        self.lock = Lock()

        if not self.load_from_cache():
            self.reload()

    def load_from_cache(self) -> bool:
        """Try to restore ``self.state`` from ``self.cache_path``.

        Returns:
            True if a TrackHistoryState was loaded. False if caching is
            disabled, the file does not exist, or it cannot be read or does
            not contain a TrackHistoryState.

        NOTE: this unpickles the cache file, which can execute arbitrary code.
        Only point ``cache_path`` at files written by this class.
        NOTE(review): the cache is not invalidated when ``path`` or ``camera``
        change; delete the file to force a rebuild.
        """
        if self.cache_path is None or not self.cache_path.exists():
            return False

        logger.debug("Load history state from cache")
        try:
            # open() is inside the try too: an unreadable file means "rebuild",
            # not "crash".
            with self.cache_path.open('rb') as fp:
                state = pickle.load(fp)
        except Exception as e:  # best effort: any broken cache falls back to a rebuild
            logger.warning(f"Cannot read cache {self.cache_path}: {e}")
            return False

        if not isinstance(state, TrackHistoryState):
            logger.warning(f"Cannot read cache {self.cache_path}: Pickled data is not a trackhistorystate")
            return False

        with self.lock:
            self.state = state
        return True

    def build_tree(self) -> TrackHistoryState:
        """Read, clean, project and index all tracks found at ``self.path``.

        Pipeline: TrackReader -> FinalDisplacementFilter(2) -> interpolated
        history -> convolution smoothing -> projection with ``self.camera`` ->
        ``downsample_history`` -> one KDTree over all downsampled points.
        """
        reader = TrackReader(self.path, self.camera.fps)
        logger.debug(f'loaded {len(reader)} tracks')

        track_filter = FinalDisplacementFilter(2)
        filtered = track_filter.apply(reader, self.camera)
        logger.debug(f'after filtering left with {len(filtered)} tracks')

        interpolated: List[Track] = [t.get_with_interpolated_history() for t in filtered]
        logger.debug(f'interpolated {len(interpolated)} tracks')

        # use convolution here, because precision does not matter and it is _way_ faster
        smoother = Smoother(convolution=True)
        smoothed = [smoother.smooth_track(t) for t in interpolated]
        logger.debug('smoothed')

        tracks: Dict[str, Track] = {track.track_id: track for track in smoothed}

        track_histories = {
            track_id: track.get_projected_history(camera=self.camera)
            for track_id, track in tracks.items()
        }
        logger.debug('projected to world space')

        # Stack the downsampled points of all tracks; indexed_track_ids[i]
        # records which track point i came from.
        point_chunks: List[np.ndarray] = []
        indexed_track_ids: List[str] = []
        for track_id, history in track_histories.items():
            downsampled = np.asarray(self.downsample_history(history), dtype=float)
            if len(downsampled) == 0:
                continue
            point_chunks.append(downsampled[:, :2])  # only the first two coordinates are indexed
            indexed_track_ids.extend([track_id] * len(downsampled))

        # KDTree requires 2-D data, so use an explicit (0, 2) array when there are no points.
        points = np.concatenate(point_chunks) if point_chunks else np.empty((0, 2))
        tree = KDTree(points)
        logger.debug('built tree')

        return TrackHistoryState(tracks, track_histories, indexed_track_ids, tree)

    def reload(self) -> None:
        """Rebuild the state from ``self.path``, swap it in, and refresh the cache."""
        state = self.build_tree()

        # acquire lock as brief as possible
        with self.lock:
            self.state = state

        if self.cache_path is not None:
            self._write_cache(state)

    def _write_cache(self, state: TrackHistoryState) -> None:
        """Pickle ``state`` to ``self.cache_path`` via a temp file (best effort)."""
        tmp_path = self.cache_path.with_name(self.cache_path.name + '.tmp')
        logger.debug("Writing history to cache")
        try:
            with tmp_path.open('wb') as fp:
                pickle.dump(state, fp)
            # Rename into place so an interrupted write never leaves a truncated cache.
            tmp_path.replace(self.cache_path)
        except OSError as e:
            # The fresh state is already in use; a failed cache write is not fatal.
            logger.warning(f"Cannot write cache {self.cache_path}: {e}")

    def get_nearest_tracks(self, point: Coordinate, k: int, max_r: Optional[float] = np.inf) -> Set[str]:
        """Return the ids of the tracks owning the ``k`` indexed points nearest to ``point``.

        Args:
            point: query coordinate, in the same space as the indexed points.
            k: number of nearest indexed points to consider. Several points may
                belong to the same track, so fewer than ``k`` ids can be returned.
            max_r: ignore points farther away than this. ``None`` means no limit.

        Returns:
            Set of track ids.
        """
        if max_r is None:
            max_r = np.inf

        with self.lock:
            state = self.state

        distances, indexes = state.tree.query(point, k, distance_upper_bound=max_r)
        # For k == 1 scipy squeezes the result to scalars; make them indexable.
        distances = np.atleast_1d(distances)
        indexes = np.atleast_1d(indexes)
        # Neighbours that were not found (beyond max_r, or fewer than k points
        # in the tree) are reported with an infinite distance.
        found = indexes[np.isfinite(distances)]
        return {state.indexed_track_ids[idx] for idx in found}

    def ids_as_trajectory(self, track_ids: Iterable[str]):
        """Yield the projected history of each track in ``track_ids``, in order.

        All ids are resolved against the state that is current when iteration
        starts. An unknown id raises KeyError.
        """
        with self.lock:
            state = self.state
        for track_id in track_ids:
            yield state.tracks[track_id].get_projected_history(camera=self.camera)

    @classmethod
    def downsample_history(cls, history, cell_size=.3):
        """Snap points to a grid with spacing ``cell_size`` and drop duplicates.

        Args:
            history: array-like of points, one point per row.
            cell_size: grid spacing.

        Returns:
            np.ndarray of the unique snapped points. np.unique sorts the rows,
            so the original order is not preserved. Empty input gives an empty
            array (keeping the column count when the input had one) rather
            than a list.
        """
        points = np.asarray(history, dtype=float)
        if len(points) == 0:
            return points.reshape(0, points.shape[-1] if points.ndim > 1 else 0)

        return np.unique(np.round(points / cell_size), axis=0) * cell_size
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual demo: build (or load from cache) the hof3 history, pick one track
    # and plot the tracks nearest to a point on it.
    import matplotlib.pyplot as plt  # Visualization; imported first so a missing install fails before the slow build

    logging.basicConfig(level=logging.DEBUG)

    path = Path("EXPERIMENTS/raw/hof3/")
    calibration_path = Path("../DATASETS/hof3/calibration.json")
    homography_path = Path("../DATASETS/hof3/homography.json")
    camera = Camera.from_paths(calibration_path, homography_path, 12)

    start = time.perf_counter()
    history = TrackHistory(path, camera, Path("/tmp/historystate_hof3.pcl"))
    dt = time.perf_counter() - start
    print(f'loaded {len(history.state.tracks)} tracks in {dt}s')

    tracks = list(history.state.tracks.values())
    if not tracks:
        raise SystemExit("no tracks loaded, nothing to plot")
    # Arbitrary example track; clamped so the demo also runs on small datasets.
    track = tracks[min(25, len(tracks) - 1)]

    trajectory_crop = np.asarray(
        TrackHistory.downsample_history(history.state.track_histories[track.track_id])
    )
    trajectory_org = track.get_projected_history(camera=camera)
    if len(trajectory_org) == 0:
        raise SystemExit(f"track {track.track_id} has an empty history")
    # Query point a bit past the middle of the track, clamped to its last point.
    target_point = trajectory_org[min(len(trajectory_org) // 2 + 90, len(trajectory_org) - 1)]

    track_set = history.get_nearest_tracks(target_point, 10, max_r=np.inf)

    plt.gca().set_aspect('equal')
    if len(trajectory_crop):
        plt.scatter(trajectory_crop[:, 0], trajectory_crop[:, 1], c='orange')
    plt.plot(trajectory_org[:, 0], trajectory_org[:, 1], c='blue', alpha=1)
    plt.scatter(target_point[0], target_point[1], c='red', alpha=1)
    for track_id in track_set:
        closeby = history.state.tracks[track_id].get_projected_history(camera=camera)
        plt.plot(closeby[:, 0], closeby[:, 1], c='green', alpha=.1)

    plt.show()