3624 lines
294 KiB
Text
3624 lines
294 KiB
Text
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Goal of this notebook: implement some basic RNN/LSTM/GRU to _forecast_ trajectories based on VIRAT and/or the custom _hof_ dataset."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/home/ruben/suspicion/trap/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import torch\n",
|
||
"import matplotlib.pyplot as plt # Visualization \n",
|
||
"import torch.nn as nn\n",
|
||
"import pandas_helper_calc # noqa # provides df.calc.derivative()\n",
|
||
"import pandas as pd\n",
|
||
"import cv2\n",
|
||
"import pathlib\n",
|
||
"from tqdm.autonotebook import tqdm"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"FPS = 12\n",
|
||
"# SRC_CSV = \"EXPERIMENTS/hofext-maskrcnn/all.txt\"\n",
|
||
"# SRC_CSV = \"EXPERIMENTS/raw/generated/train/tracks.txt\"\n",
|
||
"SRC_CSV = \"EXPERIMENTS/raw/hof-meter-maskrcnn2/train/tracks.txt\"\n",
|
||
"SRC_CSV = \"EXPERIMENTS/20240426-hof-yolo/train/tracked.txt\"\n",
|
||
"SRC_CSV = \"EXPERIMENTS/raw/hof2/train/tracked.txt\"\n",
|
||
"# SRC_H = \"../DATASETS/hof/webcam20231103-2-homography.txt\"\n",
|
||
"SRC_H = None\n",
|
||
"CACHE_DIR = \"EXPERIMENTS/cache/hof2/\"\n",
|
||
"SMOOTHING = True # hof-yolo is already smoothed, hof2 isn't\n",
|
||
"SMOOTHING_WINDOW=3 #2"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"in_fields = ['proj_x', 'proj_y', 'vx', 'vy', 'ax', 'ay']\n",
|
||
"# out_fields = ['v', 'heading']\n",
|
||
"# speed (v) cannot be negative and heading is circular (modulo), which makes them harder to optimise than a linear space, so predict the velocity components (vx, vy) instead\n",
|
||
"# and we can use simple MSE loss (I guess?)\n",
|
||
"out_fields = ['vx', 'vy']\n",
|
||
"window = int(FPS*1.5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"cuda\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Set device\n",
|
||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||
"print(device)\n",
|
||
"\n",
|
||
"# Hyperparameters\n",
|
||
"input_size = len(in_fields)\n",
|
||
"hidden_size = 256\n",
|
||
"num_layers = 3\n",
|
||
"output_size = len(out_fields)\n",
|
||
"learning_rate = 0.005 #0.01 #0.005\n",
|
||
"batch_size = 256\n",
|
||
"num_epochs = 1000"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"cache_path = pathlib.Path(CACHE_DIR)\n",
|
||
"cache_path.mkdir(parents=True, exist_ok=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Samping 1/5, of 412098 items\n",
|
||
"Done sampling kept 83726 items\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from pathlib import Path\n",
|
||
"from trap.tools import load_tracks_from_csv\n",
|
||
"\n",
|
||
"data = load_tracks_from_csv(Path(SRC_CSV), FPS, 2, 5 )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" <th>diff</th>\n",
|
||
" <th>...</th>\n",
|
||
" <th>dx</th>\n",
|
||
" <th>dy</th>\n",
|
||
" <th>vx</th>\n",
|
||
" <th>vy</th>\n",
|
||
" <th>ax</th>\n",
|
||
" <th>ay</th>\n",
|
||
" <th>v</th>\n",
|
||
" <th>a</th>\n",
|
||
" <th>heading</th>\n",
|
||
" <th>d_heading</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>194</th>\n",
|
||
" <td>606.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1593.885864</td>\n",
|
||
" <td>782.814819</td>\n",
|
||
" <td>145.704346</td>\n",
|
||
" <td>195.380432</td>\n",
|
||
" <td>12.897830</td>\n",
|
||
" <td>10.750061</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.201965</td>\n",
|
||
" <td>-0.291350</td>\n",
|
||
" <td>0.484716</td>\n",
|
||
" <td>-0.699240</td>\n",
|
||
" <td>-1.622919</td>\n",
|
||
" <td>-1.732144</td>\n",
|
||
" <td>0.850815</td>\n",
|
||
" <td>1.399195</td>\n",
|
||
" <td>304.729842</td>\n",
|
||
" <td>-101.772559</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199</th>\n",
|
||
" <td>611.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1563.890015</td>\n",
|
||
" <td>700.710510</td>\n",
|
||
" <td>137.461304</td>\n",
|
||
" <td>190.194855</td>\n",
|
||
" <td>13.099794</td>\n",
|
||
" <td>10.458712</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.201965</td>\n",
|
||
" <td>-0.291350</td>\n",
|
||
" <td>0.484716</td>\n",
|
||
" <td>-0.699240</td>\n",
|
||
" <td>-1.622919</td>\n",
|
||
" <td>-1.732144</td>\n",
|
||
" <td>0.850815</td>\n",
|
||
" <td>1.399195</td>\n",
|
||
" <td>304.729842</td>\n",
|
||
" <td>-101.772559</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>204</th>\n",
|
||
" <td>616.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1529.469727</td>\n",
|
||
" <td>635.622498</td>\n",
|
||
" <td>129.342651</td>\n",
|
||
" <td>194.191528</td>\n",
|
||
" <td>13.020002</td>\n",
|
||
" <td>9.866642</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.079792</td>\n",
|
||
" <td>-0.592069</td>\n",
|
||
" <td>-0.191501</td>\n",
|
||
" <td>-1.420966</td>\n",
|
||
" <td>-1.622919</td>\n",
|
||
" <td>-1.732144</td>\n",
|
||
" <td>1.433812</td>\n",
|
||
" <td>1.399195</td>\n",
|
||
" <td>262.324609</td>\n",
|
||
" <td>-101.772559</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>209</th>\n",
|
||
" <td>621.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1474.449341</td>\n",
|
||
" <td>569.387634</td>\n",
|
||
" <td>128.099854</td>\n",
|
||
" <td>199.766357</td>\n",
|
||
" <td>12.965776</td>\n",
|
||
" <td>9.301442</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.054226</td>\n",
|
||
" <td>-0.565200</td>\n",
|
||
" <td>-0.130143</td>\n",
|
||
" <td>-1.356479</td>\n",
|
||
" <td>0.147259</td>\n",
|
||
" <td>0.154769</td>\n",
|
||
" <td>1.362708</td>\n",
|
||
" <td>-0.170650</td>\n",
|
||
" <td>264.519715</td>\n",
|
||
" <td>5.268254</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>214</th>\n",
|
||
" <td>626.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1443.123535</td>\n",
|
||
" <td>518.907043</td>\n",
|
||
" <td>120.022461</td>\n",
|
||
" <td>202.566772</td>\n",
|
||
" <td>12.642992</td>\n",
|
||
" <td>8.976624</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.322784</td>\n",
|
||
" <td>-0.324818</td>\n",
|
||
" <td>-0.774681</td>\n",
|
||
" <td>-0.779564</td>\n",
|
||
" <td>-1.546892</td>\n",
|
||
" <td>1.384597</td>\n",
|
||
" <td>1.099023</td>\n",
|
||
" <td>-0.632844</td>\n",
|
||
" <td>225.179993</td>\n",
|
||
" <td>-94.415332</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>219</th>\n",
|
||
" <td>631.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1398.944946</td>\n",
|
||
" <td>461.813049</td>\n",
|
||
" <td>106.391357</td>\n",
|
||
" <td>193.476410</td>\n",
|
||
" <td>12.465588</td>\n",
|
||
" <td>8.557788</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.177404</td>\n",
|
||
" <td>-0.418836</td>\n",
|
||
" <td>-0.425771</td>\n",
|
||
" <td>-1.005205</td>\n",
|
||
" <td>0.837386</td>\n",
|
||
" <td>-0.541539</td>\n",
|
||
" <td>1.091659</td>\n",
|
||
" <td>-0.017675</td>\n",
|
||
" <td>247.044148</td>\n",
|
||
" <td>52.473972</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>224</th>\n",
|
||
" <td>636.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1353.237793</td>\n",
|
||
" <td>438.118896</td>\n",
|
||
" <td>91.444336</td>\n",
|
||
" <td>170.930664</td>\n",
|
||
" <td>12.128433</td>\n",
|
||
" <td>8.052323</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.337155</td>\n",
|
||
" <td>-0.505465</td>\n",
|
||
" <td>-0.809172</td>\n",
|
||
" <td>-1.213117</td>\n",
|
||
" <td>-0.920163</td>\n",
|
||
" <td>-0.498987</td>\n",
|
||
" <td>1.458222</td>\n",
|
||
" <td>0.879752</td>\n",
|
||
" <td>236.295957</td>\n",
|
||
" <td>-25.795658</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>229</th>\n",
|
||
" <td>641.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1272.791992</td>\n",
|
||
" <td>408.827759</td>\n",
|
||
" <td>104.274536</td>\n",
|
||
" <td>180.414551</td>\n",
|
||
" <td>11.689648</td>\n",
|
||
" <td>7.684636</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.438785</td>\n",
|
||
" <td>-0.367687</td>\n",
|
||
" <td>-1.053084</td>\n",
|
||
" <td>-0.882448</td>\n",
|
||
" <td>-0.585388</td>\n",
|
||
" <td>0.793604</td>\n",
|
||
" <td>1.373936</td>\n",
|
||
" <td>-0.202286</td>\n",
|
||
" <td>219.961870</td>\n",
|
||
" <td>-39.201809</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>234</th>\n",
|
||
" <td>646.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1198.965820</td>\n",
|
||
" <td>407.952759</td>\n",
|
||
" <td>103.282104</td>\n",
|
||
" <td>167.306580</td>\n",
|
||
" <td>11.207276</td>\n",
|
||
" <td>7.476216</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.482372</td>\n",
|
||
" <td>-0.208420</td>\n",
|
||
" <td>-1.157693</td>\n",
|
||
" <td>-0.500209</td>\n",
|
||
" <td>-0.251064</td>\n",
|
||
" <td>0.917374</td>\n",
|
||
" <td>1.261136</td>\n",
|
||
" <td>-0.270721</td>\n",
|
||
" <td>203.367915</td>\n",
|
||
" <td>-39.825493</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>239</th>\n",
|
||
" <td>651.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1156.309570</td>\n",
|
||
" <td>415.743408</td>\n",
|
||
" <td>97.628784</td>\n",
|
||
" <td>158.774811</td>\n",
|
||
" <td>10.884154</td>\n",
|
||
" <td>7.514692</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.323122</td>\n",
|
||
" <td>0.038476</td>\n",
|
||
" <td>-0.775493</td>\n",
|
||
" <td>0.092343</td>\n",
|
||
" <td>0.917282</td>\n",
|
||
" <td>1.422125</td>\n",
|
||
" <td>0.780971</td>\n",
|
||
" <td>-1.152395</td>\n",
|
||
" <td>173.209381</td>\n",
|
||
" <td>-72.380481</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>244</th>\n",
|
||
" <td>656.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1094.440430</td>\n",
|
||
" <td>443.849915</td>\n",
|
||
" <td>107.938110</td>\n",
|
||
" <td>177.703979</td>\n",
|
||
" <td>10.544492</td>\n",
|
||
" <td>7.870090</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.339661</td>\n",
|
||
" <td>0.355398</td>\n",
|
||
" <td>-0.815187</td>\n",
|
||
" <td>0.852955</td>\n",
|
||
" <td>-0.095267</td>\n",
|
||
" <td>1.825468</td>\n",
|
||
" <td>1.179857</td>\n",
|
||
" <td>0.957326</td>\n",
|
||
" <td>133.703018</td>\n",
|
||
" <td>-94.815270</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>249</th>\n",
|
||
" <td>661.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1072.595093</td>\n",
|
||
" <td>481.461945</td>\n",
|
||
" <td>118.452148</td>\n",
|
||
" <td>205.365173</td>\n",
|
||
" <td>10.486504</td>\n",
|
||
" <td>8.287758</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.057989</td>\n",
|
||
" <td>0.417668</td>\n",
|
||
" <td>-0.139173</td>\n",
|
||
" <td>1.002404</td>\n",
|
||
" <td>1.622435</td>\n",
|
||
" <td>0.358678</td>\n",
|
||
" <td>1.012019</td>\n",
|
||
" <td>-0.402811</td>\n",
|
||
" <td>97.904355</td>\n",
|
||
" <td>-85.916792</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>254</th>\n",
|
||
" <td>666.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1086.627930</td>\n",
|
||
" <td>526.733154</td>\n",
|
||
" <td>105.444458</td>\n",
|
||
" <td>189.750610</td>\n",
|
||
" <td>10.498393</td>\n",
|
||
" <td>8.684043</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.011889</td>\n",
|
||
" <td>0.396285</td>\n",
|
||
" <td>0.028534</td>\n",
|
||
" <td>0.951083</td>\n",
|
||
" <td>0.402496</td>\n",
|
||
" <td>-0.123170</td>\n",
|
||
" <td>0.951511</td>\n",
|
||
" <td>-0.145220</td>\n",
|
||
" <td>88.281546</td>\n",
|
||
" <td>-23.094741</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>259</th>\n",
|
||
" <td>671.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1099.592285</td>\n",
|
||
" <td>584.216675</td>\n",
|
||
" <td>114.395874</td>\n",
|
||
" <td>218.003479</td>\n",
|
||
" <td>10.492767</td>\n",
|
||
" <td>9.267106</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.005626</td>\n",
|
||
" <td>0.583063</td>\n",
|
||
" <td>-0.013502</td>\n",
|
||
" <td>1.399352</td>\n",
|
||
" <td>-0.100887</td>\n",
|
||
" <td>1.075845</td>\n",
|
||
" <td>1.399417</td>\n",
|
||
" <td>1.074975</td>\n",
|
||
" <td>90.552815</td>\n",
|
||
" <td>5.451045</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>264</th>\n",
|
||
" <td>676.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1144.484782</td>\n",
|
||
" <td>642.779582</td>\n",
|
||
" <td>96.750326</td>\n",
|
||
" <td>180.744690</td>\n",
|
||
" <td>10.484691</td>\n",
|
||
" <td>9.582745</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>-0.008077</td>\n",
|
||
" <td>0.315639</td>\n",
|
||
" <td>-0.019384</td>\n",
|
||
" <td>0.757534</td>\n",
|
||
" <td>-0.014116</td>\n",
|
||
" <td>-1.540364</td>\n",
|
||
" <td>0.757782</td>\n",
|
||
" <td>-1.539925</td>\n",
|
||
" <td>91.465753</td>\n",
|
||
" <td>2.191052</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>269</th>\n",
|
||
" <td>681.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1179.532959</td>\n",
|
||
" <td>682.365540</td>\n",
|
||
" <td>107.764282</td>\n",
|
||
" <td>200.651733</td>\n",
|
||
" <td>10.698373</td>\n",
|
||
" <td>9.950516</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>5.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.213682</td>\n",
|
||
" <td>0.367771</td>\n",
|
||
" <td>0.512837</td>\n",
|
||
" <td>0.882650</td>\n",
|
||
" <td>1.277331</td>\n",
|
||
" <td>0.300278</td>\n",
|
||
" <td>1.020820</td>\n",
|
||
" <td>0.631291</td>\n",
|
||
" <td>59.842534</td>\n",
|
||
" <td>-75.895726</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>16 rows × 24 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" frame_id track_id l t w h \\\n",
|
||
"194 606.0 4 1593.885864 782.814819 145.704346 195.380432 \n",
|
||
"199 611.0 4 1563.890015 700.710510 137.461304 190.194855 \n",
|
||
"204 616.0 4 1529.469727 635.622498 129.342651 194.191528 \n",
|
||
"209 621.0 4 1474.449341 569.387634 128.099854 199.766357 \n",
|
||
"214 626.0 4 1443.123535 518.907043 120.022461 202.566772 \n",
|
||
"219 631.0 4 1398.944946 461.813049 106.391357 193.476410 \n",
|
||
"224 636.0 4 1353.237793 438.118896 91.444336 170.930664 \n",
|
||
"229 641.0 4 1272.791992 408.827759 104.274536 180.414551 \n",
|
||
"234 646.0 4 1198.965820 407.952759 103.282104 167.306580 \n",
|
||
"239 651.0 4 1156.309570 415.743408 97.628784 158.774811 \n",
|
||
"244 656.0 4 1094.440430 443.849915 107.938110 177.703979 \n",
|
||
"249 661.0 4 1072.595093 481.461945 118.452148 205.365173 \n",
|
||
"254 666.0 4 1086.627930 526.733154 105.444458 189.750610 \n",
|
||
"259 671.0 4 1099.592285 584.216675 114.395874 218.003479 \n",
|
||
"264 676.0 4 1144.484782 642.779582 96.750326 180.744690 \n",
|
||
"269 681.0 4 1179.532959 682.365540 107.764282 200.651733 \n",
|
||
"\n",
|
||
" x y state diff ... dx dy vx \\\n",
|
||
"194 12.897830 10.750061 2.0 NaN ... 0.201965 -0.291350 0.484716 \n",
|
||
"199 13.099794 10.458712 1.0 5.0 ... 0.201965 -0.291350 0.484716 \n",
|
||
"204 13.020002 9.866642 2.0 5.0 ... -0.079792 -0.592069 -0.191501 \n",
|
||
"209 12.965776 9.301442 2.0 5.0 ... -0.054226 -0.565200 -0.130143 \n",
|
||
"214 12.642992 8.976624 2.0 5.0 ... -0.322784 -0.324818 -0.774681 \n",
|
||
"219 12.465588 8.557788 2.0 5.0 ... -0.177404 -0.418836 -0.425771 \n",
|
||
"224 12.128433 8.052323 2.0 5.0 ... -0.337155 -0.505465 -0.809172 \n",
|
||
"229 11.689648 7.684636 2.0 5.0 ... -0.438785 -0.367687 -1.053084 \n",
|
||
"234 11.207276 7.476216 2.0 5.0 ... -0.482372 -0.208420 -1.157693 \n",
|
||
"239 10.884154 7.514692 2.0 5.0 ... -0.323122 0.038476 -0.775493 \n",
|
||
"244 10.544492 7.870090 2.0 5.0 ... -0.339661 0.355398 -0.815187 \n",
|
||
"249 10.486504 8.287758 2.0 5.0 ... -0.057989 0.417668 -0.139173 \n",
|
||
"254 10.498393 8.684043 2.0 5.0 ... 0.011889 0.396285 0.028534 \n",
|
||
"259 10.492767 9.267106 2.0 5.0 ... -0.005626 0.583063 -0.013502 \n",
|
||
"264 10.484691 9.582745 1.0 5.0 ... -0.008077 0.315639 -0.019384 \n",
|
||
"269 10.698373 9.950516 2.0 5.0 ... 0.213682 0.367771 0.512837 \n",
|
||
"\n",
|
||
" vy ax ay v a heading d_heading \n",
|
||
"194 -0.699240 -1.622919 -1.732144 0.850815 1.399195 304.729842 -101.772559 \n",
|
||
"199 -0.699240 -1.622919 -1.732144 0.850815 1.399195 304.729842 -101.772559 \n",
|
||
"204 -1.420966 -1.622919 -1.732144 1.433812 1.399195 262.324609 -101.772559 \n",
|
||
"209 -1.356479 0.147259 0.154769 1.362708 -0.170650 264.519715 5.268254 \n",
|
||
"214 -0.779564 -1.546892 1.384597 1.099023 -0.632844 225.179993 -94.415332 \n",
|
||
"219 -1.005205 0.837386 -0.541539 1.091659 -0.017675 247.044148 52.473972 \n",
|
||
"224 -1.213117 -0.920163 -0.498987 1.458222 0.879752 236.295957 -25.795658 \n",
|
||
"229 -0.882448 -0.585388 0.793604 1.373936 -0.202286 219.961870 -39.201809 \n",
|
||
"234 -0.500209 -0.251064 0.917374 1.261136 -0.270721 203.367915 -39.825493 \n",
|
||
"239 0.092343 0.917282 1.422125 0.780971 -1.152395 173.209381 -72.380481 \n",
|
||
"244 0.852955 -0.095267 1.825468 1.179857 0.957326 133.703018 -94.815270 \n",
|
||
"249 1.002404 1.622435 0.358678 1.012019 -0.402811 97.904355 -85.916792 \n",
|
||
"254 0.951083 0.402496 -0.123170 0.951511 -0.145220 88.281546 -23.094741 \n",
|
||
"259 1.399352 -0.100887 1.075845 1.399417 1.074975 90.552815 5.451045 \n",
|
||
"264 0.757534 -0.014116 -1.540364 0.757782 -1.539925 91.465753 2.191052 \n",
|
||
"269 0.882650 1.277331 0.300278 1.020820 0.631291 59.842534 -75.895726 \n",
|
||
"\n",
|
||
"[16 rows x 24 columns]"
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th rowspan=\"5\" valign=\"top\">1</th>\n",
|
||
" <th>342</th>\n",
|
||
" <td>1393.736572</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>67.613647</td>\n",
|
||
" <td>121.391151</td>\n",
|
||
" <td>1363.3164</td>\n",
|
||
" <td>232.92647</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>343</th>\n",
|
||
" <td>1391.775879</td>\n",
|
||
" <td>0.852371</td>\n",
|
||
" <td>78.562622</td>\n",
|
||
" <td>141.050934</td>\n",
|
||
" <td>1359.1885</td>\n",
|
||
" <td>266.06586</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>346</th>\n",
|
||
" <td>1392.164551</td>\n",
|
||
" <td>7.758987</td>\n",
|
||
" <td>85.757324</td>\n",
|
||
" <td>154.357971</td>\n",
|
||
" <td>1355.7444</td>\n",
|
||
" <td>297.67404</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>347</th>\n",
|
||
" <td>1393.844849</td>\n",
|
||
" <td>12.691238</td>\n",
|
||
" <td>86.482910</td>\n",
|
||
" <td>156.264786</td>\n",
|
||
" <td>1355.2312</td>\n",
|
||
" <td>308.20670</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>348</th>\n",
|
||
" <td>1394.839111</td>\n",
|
||
" <td>15.621338</td>\n",
|
||
" <td>84.763428</td>\n",
|
||
" <td>154.584396</td>\n",
|
||
" <td>1354.9246</td>\n",
|
||
" <td>310.09225</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th rowspan=\"5\" valign=\"top\">5030</th>\n",
|
||
" <th>32691</th>\n",
|
||
" <td>1708.213379</td>\n",
|
||
" <td>749.260376</td>\n",
|
||
" <td>133.839966</td>\n",
|
||
" <td>182.405396</td>\n",
|
||
" <td>1402.5426</td>\n",
|
||
" <td>1075.20870</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>32692</th>\n",
|
||
" <td>1707.651855</td>\n",
|
||
" <td>748.997437</td>\n",
|
||
" <td>134.013672</td>\n",
|
||
" <td>182.391296</td>\n",
|
||
" <td>1402.2948</td>\n",
|
||
" <td>1074.97230</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>32720</th>\n",
|
||
" <td>1700.379639</td>\n",
|
||
" <td>750.314697</td>\n",
|
||
" <td>128.792603</td>\n",
|
||
" <td>181.589783</td>\n",
|
||
" <td>1395.7992</td>\n",
|
||
" <td>1074.27320</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>32721</th>\n",
|
||
" <td>1701.722412</td>\n",
|
||
" <td>751.000488</td>\n",
|
||
" <td>125.286865</td>\n",
|
||
" <td>180.867615</td>\n",
|
||
" <td>1395.5424</td>\n",
|
||
" <td>1074.20560</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>32722</th>\n",
|
||
" <td>1702.384766</td>\n",
|
||
" <td>750.754517</td>\n",
|
||
" <td>123.435425</td>\n",
|
||
" <td>180.945618</td>\n",
|
||
" <td>1395.4082</td>\n",
|
||
" <td>1074.06500</td>\n",
|
||
" <td>2</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>326960 rows × 7 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" l t w h x \\\n",
|
||
"track_id frame_id \n",
|
||
"1 342 1393.736572 0.000000 67.613647 121.391151 1363.3164 \n",
|
||
" 343 1391.775879 0.852371 78.562622 141.050934 1359.1885 \n",
|
||
" 346 1392.164551 7.758987 85.757324 154.357971 1355.7444 \n",
|
||
" 347 1393.844849 12.691238 86.482910 156.264786 1355.2312 \n",
|
||
" 348 1394.839111 15.621338 84.763428 154.584396 1354.9246 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"5030 32691 1708.213379 749.260376 133.839966 182.405396 1402.5426 \n",
|
||
" 32692 1707.651855 748.997437 134.013672 182.391296 1402.2948 \n",
|
||
" 32720 1700.379639 750.314697 128.792603 181.589783 1395.7992 \n",
|
||
" 32721 1701.722412 751.000488 125.286865 180.867615 1395.5424 \n",
|
||
" 32722 1702.384766 750.754517 123.435425 180.945618 1395.4082 \n",
|
||
"\n",
|
||
" y state \n",
|
||
"track_id frame_id \n",
|
||
"1 342 232.92647 2 \n",
|
||
" 343 266.06586 2 \n",
|
||
" 346 297.67404 2 \n",
|
||
" 347 308.20670 2 \n",
|
||
" 348 310.09225 2 \n",
|
||
"... ... ... \n",
|
||
"5030 32691 1075.20870 2 \n",
|
||
" 32692 1074.97230 2 \n",
|
||
" 32720 1074.27320 2 \n",
|
||
" 32721 1074.20560 2 \n",
|
||
" 32722 1074.06500 2 \n",
|
||
"\n",
|
||
"[326960 rows x 7 columns]"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"data = pd.read_csv(SRC_CSV, delimiter=\"\\t\", index_col=False, header=None)\n",
|
||
"# data.columns = ['frame_id', 'track_id', 'pos_x', 'pos_y', 'width', 'height']#, '_x', '_y,']\n",
|
||
"data.columns = ['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state']#, '_x', '_y,']\n",
|
||
"data['frame_id'] = pd.to_numeric(data['frame_id'], downcast='integer')\n",
|
||
"data['frame_id'] = data['frame_id'] // 10 # compatibility with Trajectron++\n",
|
||
"\n",
|
||
"data.sort_values(by=['track_id', 'frame_id'],inplace=True)\n",
|
||
"\n",
|
||
"data.set_index(['track_id', 'frame_id'])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# convert x/y from centimetres to metres\n",
|
||
"data['x'] = data['x']/100\n",
|
||
"data['y'] = data['y']/100"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"data['diff'] = data.groupby(['track_id'])['frame_id'].diff() #.fillna(0)\n",
|
||
"data['diff'] = pd.to_numeric(data['diff'], downcast='integer')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"326960it [06:37, 821.55it/s] "
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"was: 326960 added: 85138 new length: 412098\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"missing=0\n",
|
||
"old_size=len(data)\n",
|
||
"# slow way to append missing steps to the dataset\n",
|
||
"for ind, row in tqdm(data.iterrows()):\n",
|
||
" if row['diff'] > 1:\n",
|
||
" for s in range(1, int(row['diff'])):\n",
|
||
" # add as many entries as missing\n",
|
||
" missing += 1\n",
|
||
" data.loc[len(data)] = [row['frame_id']-s, row['track_id'], np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 1, 1]\n",
|
||
" # new_frame = [data.loc[ind-1]['frame_id']+s, row['track_id'], np.nan, np.nan, np.nan, np.nan, np.nan]\n",
|
||
" # data.loc[len(data)] = new_frame\n",
|
||
"\n",
|
||
"print('was:', old_size, 'added:', missing, 'new length:', len(data))\n",
|
||
"# now sort, so that the added data is in the right place\n",
|
||
"data.sort_values(by=['track_id', 'frame_id'], inplace=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# interpolate missing data\n",
|
||
"df=data.copy()\n",
|
||
"df = df.groupby('track_id').apply(lambda group: group.interpolate(method='linear'))\n",
|
||
"df.reset_index(drop=True, inplace=True)\n",
|
||
"data = df\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Running smoother\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from trap.tracker import Smoother\n",
|
||
"\n",
|
||
"if SMOOTHING:\n",
|
||
" df=data.copy()\n",
|
||
" if 'x_raw' not in df:\n",
|
||
" df['x_raw'] = df['x']\n",
|
||
" if 'y_raw' not in df:\n",
|
||
" df['y_raw'] = df['y']\n",
|
||
"\n",
|
||
" print(\"Running smoother\")\n",
|
||
" # print(df)\n",
|
||
" # from tsmoothie.smoother import KalmanSmoother, ConvolutionSmoother\n",
|
||
" smoother = Smoother(convolution=False)\n",
|
||
" def smoothing(data):\n",
|
||
" # smoother = ConvolutionSmoother(window_len=SMOOTHING_WINDOW, window_type='ones', copy=None)\n",
|
||
" return smoother.smooth(data).tolist()\n",
|
||
" # df=df.assign(smooth_data=smoother.smooth_data[0])\n",
|
||
" # return smoother.smooth_data[0].tolist()\n",
|
||
"\n",
|
||
" # operate smoothing per axis\n",
|
||
" df['x'] = df.groupby('track_id')['x_raw'].transform(smoothing)\n",
|
||
" df['y'] = df.groupby('track_id')['y_raw'].transform(smoothing)\n",
|
||
" \n",
|
||
"\n",
|
||
" data = df\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" <th>diff</th>\n",
|
||
" <th>x_raw</th>\n",
|
||
" <th>y_raw</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>9.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.566162</td>\n",
|
||
" <td>88.795326</td>\n",
|
||
" <td>173.917542</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>0.881595</td>\n",
|
||
" <td>7.341152</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>10.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.116699</td>\n",
|
||
" <td>88.801704</td>\n",
|
||
" <td>171.334290</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.870703</td>\n",
|
||
" <td>7.309168</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>11.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874573</td>\n",
|
||
" <td>90.596596</td>\n",
|
||
" <td>177.199951</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.901374</td>\n",
|
||
" <td>7.370044</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874268</td>\n",
|
||
" <td>90.928131</td>\n",
|
||
" <td>183.125732</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.924360</td>\n",
|
||
" <td>7.432365</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>569.931213</td>\n",
|
||
" <td>86.213280</td>\n",
|
||
" <td>180.774292</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.906583</td>\n",
|
||
" <td>7.456334</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320183</th>\n",
|
||
" <td>60159.0</td>\n",
|
||
" <td>3632.0</td>\n",
|
||
" <td>1830.709717</td>\n",
|
||
" <td>651.257446</td>\n",
|
||
" <td>150.202515</td>\n",
|
||
" <td>157.239746</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>15.214551</td>\n",
|
||
" <td>10.027093</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320184</th>\n",
|
||
" <td>60160.0</td>\n",
|
||
" <td>3632.0</td>\n",
|
||
" <td>1834.013672</td>\n",
|
||
" <td>649.612122</td>\n",
|
||
" <td>153.686646</td>\n",
|
||
" <td>160.874023</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.244872</td>\n",
|
||
" <td>10.047117</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320185</th>\n",
|
||
" <td>60161.0</td>\n",
|
||
" <td>3632.0</td>\n",
|
||
" <td>1845.373047</td>\n",
|
||
" <td>651.249756</td>\n",
|
||
" <td>147.178589</td>\n",
|
||
" <td>153.729248</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.318496</td>\n",
|
||
" <td>10.015218</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320186</th>\n",
|
||
" <td>60162.0</td>\n",
|
||
" <td>3632.0</td>\n",
|
||
" <td>1857.388916</td>\n",
|
||
" <td>650.908203</td>\n",
|
||
" <td>136.407349</td>\n",
|
||
" <td>142.354614</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.400203</td>\n",
|
||
" <td>9.935355</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320187</th>\n",
|
||
" <td>60163.0</td>\n",
|
||
" <td>3632.0</td>\n",
|
||
" <td>1862.792725</td>\n",
|
||
" <td>658.719971</td>\n",
|
||
" <td>141.984253</td>\n",
|
||
" <td>149.052307</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.416893</td>\n",
|
||
" <td>10.051785</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>320188 rows × 12 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" frame_id track_id l t w h \\\n",
|
||
"0 9.0 1.0 0.000000 565.566162 88.795326 173.917542 \n",
|
||
"1 10.0 1.0 0.000000 565.116699 88.801704 171.334290 \n",
|
||
"2 11.0 1.0 0.000000 564.874573 90.596596 177.199951 \n",
|
||
"3 12.0 1.0 0.000000 564.874268 90.928131 183.125732 \n",
|
||
"4 13.0 1.0 0.000000 569.931213 86.213280 180.774292 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 60159.0 3632.0 1830.709717 651.257446 150.202515 157.239746 \n",
|
||
"320184 60160.0 3632.0 1834.013672 649.612122 153.686646 160.874023 \n",
|
||
"320185 60161.0 3632.0 1845.373047 651.249756 147.178589 153.729248 \n",
|
||
"320186 60162.0 3632.0 1857.388916 650.908203 136.407349 142.354614 \n",
|
||
"320187 60163.0 3632.0 1862.792725 658.719971 141.984253 149.052307 \n",
|
||
"\n",
|
||
" x y state diff x_raw y_raw \n",
|
||
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 \n",
|
||
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 \n",
|
||
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 \n",
|
||
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 \n",
|
||
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 \n",
|
||
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 \n",
|
||
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 \n",
|
||
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 \n",
|
||
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 \n",
|
||
"\n",
|
||
"[320188 rows x 12 columns]"
|
||
]
|
||
},
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# del data['diff']\n",
|
||
"# recalculate diff\n",
|
||
"data['diff'] = data.groupby(['track_id'])['frame_id'].diff()\n",
|
||
"data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" <th>diff</th>\n",
|
||
" <th>x_raw</th>\n",
|
||
" <th>y_raw</th>\n",
|
||
" <th>dt</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>9.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.566162</td>\n",
|
||
" <td>88.795326</td>\n",
|
||
" <td>173.917542</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>0.881595</td>\n",
|
||
" <td>7.341152</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>10.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.116699</td>\n",
|
||
" <td>88.801704</td>\n",
|
||
" <td>171.334290</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.870703</td>\n",
|
||
" <td>7.309168</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>11.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874573</td>\n",
|
||
" <td>90.596596</td>\n",
|
||
" <td>177.199951</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.901374</td>\n",
|
||
" <td>7.370044</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874268</td>\n",
|
||
" <td>90.928131</td>\n",
|
||
" <td>183.125732</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.924360</td>\n",
|
||
" <td>7.432365</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>569.931213</td>\n",
|
||
" <td>86.213280</td>\n",
|
||
" <td>180.774292</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.906583</td>\n",
|
||
" <td>7.456334</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320183</th>\n",
|
||
" <td>60159.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1830.709717</td>\n",
|
||
" <td>651.257446</td>\n",
|
||
" <td>150.202515</td>\n",
|
||
" <td>157.239746</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>15.214551</td>\n",
|
||
" <td>10.027093</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320184</th>\n",
|
||
" <td>60160.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1834.013672</td>\n",
|
||
" <td>649.612122</td>\n",
|
||
" <td>153.686646</td>\n",
|
||
" <td>160.874023</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.244872</td>\n",
|
||
" <td>10.047117</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320185</th>\n",
|
||
" <td>60161.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1845.373047</td>\n",
|
||
" <td>651.249756</td>\n",
|
||
" <td>147.178589</td>\n",
|
||
" <td>153.729248</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.318496</td>\n",
|
||
" <td>10.015218</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320186</th>\n",
|
||
" <td>60162.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1857.388916</td>\n",
|
||
" <td>650.908203</td>\n",
|
||
" <td>136.407349</td>\n",
|
||
" <td>142.354614</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.400203</td>\n",
|
||
" <td>9.935355</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320187</th>\n",
|
||
" <td>60163.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1862.792725</td>\n",
|
||
" <td>658.719971</td>\n",
|
||
" <td>141.984253</td>\n",
|
||
" <td>149.052307</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.416893</td>\n",
|
||
" <td>10.051785</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>320188 rows × 13 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" frame_id track_id l t w h \\\n",
|
||
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
||
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
||
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
||
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
||
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
||
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
||
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
||
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
||
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
||
"\n",
|
||
" x y state diff x_raw y_raw dt \n",
|
||
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 NaN \n",
|
||
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 0.083333 \n",
|
||
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 0.083333 \n",
|
||
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 0.083333 \n",
|
||
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 0.083333 \n",
|
||
"... ... ... ... ... ... ... ... \n",
|
||
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 NaN \n",
|
||
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 0.083333 \n",
|
||
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 0.083333 \n",
|
||
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 0.083333 \n",
|
||
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 0.083333 \n",
|
||
"\n",
|
||
"[320188 rows x 13 columns]"
|
||
]
|
||
},
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"\n",
|
||
"# data['node_type'] = 'PEDESTRIAN' # compatibility with Trajectron++\n",
|
||
"# data['node_id'] = data['track_id'].astype(str)\n",
|
||
"data['track_id'] = pd.to_numeric(data['track_id'], downcast='integer')\n",
|
||
"\n",
|
||
"\n",
|
||
"data['dt'] = data['diff'] * (1/FPS)\n",
|
||
"data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Shift positions into a mean-centred coordinate system. (DO THESE NEED TO BE STORED?)\n",
|
||
"# Don't do this, messes up\n",
|
||
"# data['pos_x'] = data['pos_x'] - data['pos_x'].mean()\n",
|
||
"# data['pos_y'] = data['pos_y'] - data['pos_y'].mean()\n",
|
||
" \n",
|
||
"# data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<AxesSubplot:>"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"data['diff'].hist()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"The dataset is a bit crappy because its frame step is not constant: the step between consecutive frames of a track is predominantly 1 or 2, but is sometimes 3 or 4 as well. This inevitably leads to differences in the speed calculations."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"if SRC_H is not None:\n",
|
||
" H = np.loadtxt(SRC_H, delimiter=',')\n",
|
||
"else:\n",
|
||
" H = None"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"No H given, probably already projected data?\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" <th>diff</th>\n",
|
||
" <th>x_raw</th>\n",
|
||
" <th>y_raw</th>\n",
|
||
" <th>dt</th>\n",
|
||
" <th>proj_x</th>\n",
|
||
" <th>proj_y</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>9.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.566162</td>\n",
|
||
" <td>88.795326</td>\n",
|
||
" <td>173.917542</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>0.881595</td>\n",
|
||
" <td>7.341152</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>10.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.116699</td>\n",
|
||
" <td>88.801704</td>\n",
|
||
" <td>171.334290</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.870703</td>\n",
|
||
" <td>7.309168</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>11.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874573</td>\n",
|
||
" <td>90.596596</td>\n",
|
||
" <td>177.199951</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.901374</td>\n",
|
||
" <td>7.370044</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874268</td>\n",
|
||
" <td>90.928131</td>\n",
|
||
" <td>183.125732</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.924360</td>\n",
|
||
" <td>7.432365</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>569.931213</td>\n",
|
||
" <td>86.213280</td>\n",
|
||
" <td>180.774292</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.906583</td>\n",
|
||
" <td>7.456334</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320183</th>\n",
|
||
" <td>60159.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1830.709717</td>\n",
|
||
" <td>651.257446</td>\n",
|
||
" <td>150.202515</td>\n",
|
||
" <td>157.239746</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>15.214551</td>\n",
|
||
" <td>10.027093</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320184</th>\n",
|
||
" <td>60160.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1834.013672</td>\n",
|
||
" <td>649.612122</td>\n",
|
||
" <td>153.686646</td>\n",
|
||
" <td>160.874023</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.244872</td>\n",
|
||
" <td>10.047117</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320185</th>\n",
|
||
" <td>60161.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1845.373047</td>\n",
|
||
" <td>651.249756</td>\n",
|
||
" <td>147.178589</td>\n",
|
||
" <td>153.729248</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.318496</td>\n",
|
||
" <td>10.015218</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320186</th>\n",
|
||
" <td>60162.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1857.388916</td>\n",
|
||
" <td>650.908203</td>\n",
|
||
" <td>136.407349</td>\n",
|
||
" <td>142.354614</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.400203</td>\n",
|
||
" <td>9.935355</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320187</th>\n",
|
||
" <td>60163.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1862.792725</td>\n",
|
||
" <td>658.719971</td>\n",
|
||
" <td>141.984253</td>\n",
|
||
" <td>149.052307</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>15.416893</td>\n",
|
||
" <td>10.051785</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>320188 rows × 15 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" frame_id track_id l t w h \\\n",
|
||
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
||
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
||
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
||
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
||
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
||
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
||
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
||
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
||
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
||
"\n",
|
||
" x y state diff x_raw y_raw dt \\\n",
|
||
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 NaN \n",
|
||
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 0.083333 \n",
|
||
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 0.083333 \n",
|
||
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 0.083333 \n",
|
||
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 0.083333 \n",
|
||
"... ... ... ... ... ... ... ... \n",
|
||
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 NaN \n",
|
||
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 0.083333 \n",
|
||
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 0.083333 \n",
|
||
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 0.083333 \n",
|
||
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 0.083333 \n",
|
||
"\n",
|
||
" proj_x proj_y \n",
|
||
"0 0.855100 7.136193 \n",
|
||
"1 0.873132 7.235233 \n",
|
||
"2 0.890957 7.328989 \n",
|
||
"3 0.907784 7.418187 \n",
|
||
"4 0.923439 7.505012 \n",
|
||
"... ... ... \n",
|
||
"320183 14.840476 9.786501 \n",
|
||
"320184 15.033432 9.870472 \n",
|
||
"320185 15.211560 9.943236 \n",
|
||
"320186 15.377673 10.008965 \n",
|
||
"320187 15.538255 10.075935 \n",
|
||
"\n",
|
||
"[320188 rows x 15 columns]"
|
||
]
|
||
},
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"if H is not None:\n",
|
||
" print(\"Projecting data\")\n",
|
||
" data['foot_x'] = data['pos_x'] + 0.5 * data['width']\n",
|
||
" data['foot_y'] = data['pos_y'] + 0.5 * data['height']\n",
|
||
" \n",
|
||
" transformed = cv2.perspectiveTransform(np.array([data[['foot_x','foot_y']].to_numpy()]),H)[0]\n",
|
||
" data['proj_x'], data['proj_y'] = transformed[:,0], transformed[:,1]\n",
|
||
" data['proj_x'] = data['proj_x'].div(100) # cm to m\n",
|
||
" data['proj_y'] = data['proj_y'].div(100) # cm to m\n",
|
||
"    # and shift to mean (THESE NEED TO BE STORED AND REUSED IN LIVE SETTING)\n",
|
||
" mean_x = data['proj_x'].mean()\n",
|
||
" mean_y = data['proj_y'].mean()\n",
|
||
" data['proj_x'] = data['proj_x'] - data['proj_x'].mean()\n",
|
||
" data['proj_y'] = data['proj_y'] - data['proj_y'].mean()\n",
|
||
"else:\n",
|
||
" print(\"No H given, probably already projected data?\")\n",
|
||
" mean_x = 0\n",
|
||
" mean_y = 0\n",
|
||
" data['proj_x'] = data['x']\n",
|
||
" data['proj_y'] = data['y']\n",
|
||
"data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Deriving displacement, velocity and accelation from x and y\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" <th>diff</th>\n",
|
||
" <th>...</th>\n",
|
||
" <th>y_raw</th>\n",
|
||
" <th>dt</th>\n",
|
||
" <th>proj_x</th>\n",
|
||
" <th>proj_y</th>\n",
|
||
" <th>dx</th>\n",
|
||
" <th>dy</th>\n",
|
||
" <th>vx</th>\n",
|
||
" <th>vy</th>\n",
|
||
" <th>ax</th>\n",
|
||
" <th>ay</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>9.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.566162</td>\n",
|
||
" <td>88.795326</td>\n",
|
||
" <td>173.917542</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>7.341152</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>10.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.116699</td>\n",
|
||
" <td>88.801704</td>\n",
|
||
" <td>171.334290</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>7.309168</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" <td>0.018032</td>\n",
|
||
" <td>0.099039</td>\n",
|
||
" <td>0.216383</td>\n",
|
||
" <td>1.188473</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>11.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874573</td>\n",
|
||
" <td>90.596596</td>\n",
|
||
" <td>177.199951</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>7.370044</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" <td>0.017825</td>\n",
|
||
" <td>0.093756</td>\n",
|
||
" <td>0.213899</td>\n",
|
||
" <td>1.125077</td>\n",
|
||
" <td>-0.029812</td>\n",
|
||
" <td>-0.760753</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874268</td>\n",
|
||
" <td>90.928131</td>\n",
|
||
" <td>183.125732</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>7.432365</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" <td>0.016827</td>\n",
|
||
" <td>0.089198</td>\n",
|
||
" <td>0.201924</td>\n",
|
||
" <td>1.070371</td>\n",
|
||
" <td>-0.143699</td>\n",
|
||
" <td>-0.656466</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>569.931213</td>\n",
|
||
" <td>86.213280</td>\n",
|
||
" <td>180.774292</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>7.456334</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" <td>0.015655</td>\n",
|
||
" <td>0.086825</td>\n",
|
||
" <td>0.187865</td>\n",
|
||
" <td>1.041902</td>\n",
|
||
" <td>-0.168701</td>\n",
|
||
" <td>-0.341637</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320183</th>\n",
|
||
" <td>60159.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1830.709717</td>\n",
|
||
" <td>651.257446</td>\n",
|
||
" <td>150.202515</td>\n",
|
||
" <td>157.239746</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>10.027093</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320184</th>\n",
|
||
" <td>60160.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1834.013672</td>\n",
|
||
" <td>649.612122</td>\n",
|
||
" <td>153.686646</td>\n",
|
||
" <td>160.874023</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>10.047117</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" <td>0.192955</td>\n",
|
||
" <td>0.083971</td>\n",
|
||
" <td>2.315463</td>\n",
|
||
" <td>1.007656</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320185</th>\n",
|
||
" <td>60161.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1845.373047</td>\n",
|
||
" <td>651.249756</td>\n",
|
||
" <td>147.178589</td>\n",
|
||
" <td>153.729248</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>10.015218</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" <td>0.178128</td>\n",
|
||
" <td>0.072764</td>\n",
|
||
" <td>2.137542</td>\n",
|
||
" <td>0.873173</td>\n",
|
||
" <td>-2.135059</td>\n",
|
||
" <td>-1.613797</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320186</th>\n",
|
||
" <td>60162.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1857.388916</td>\n",
|
||
" <td>650.908203</td>\n",
|
||
" <td>136.407349</td>\n",
|
||
" <td>142.354614</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>9.935355</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" <td>0.166113</td>\n",
|
||
" <td>0.065728</td>\n",
|
||
" <td>1.993352</td>\n",
|
||
" <td>0.788742</td>\n",
|
||
" <td>-1.730279</td>\n",
|
||
" <td>-1.013172</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320187</th>\n",
|
||
" <td>60163.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1862.792725</td>\n",
|
||
" <td>658.719971</td>\n",
|
||
" <td>141.984253</td>\n",
|
||
" <td>149.052307</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>10.051785</td>\n",
|
||
" <td>0.083333</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" <td>0.160582</td>\n",
|
||
" <td>0.066971</td>\n",
|
||
" <td>1.926987</td>\n",
|
||
" <td>0.803649</td>\n",
|
||
" <td>-0.796376</td>\n",
|
||
" <td>0.178886</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>320188 rows × 21 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" frame_id track_id l t w h \\\n",
|
||
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
||
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
||
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
||
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
||
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
||
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
||
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
||
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
||
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
||
"\n",
|
||
" x y state diff ... y_raw dt \\\n",
|
||
"0 0.855100 7.136193 2.0 NaN ... 7.341152 NaN \n",
|
||
"1 0.873132 7.235233 2.0 1.0 ... 7.309168 0.083333 \n",
|
||
"2 0.890957 7.328989 2.0 1.0 ... 7.370044 0.083333 \n",
|
||
"3 0.907784 7.418187 2.0 1.0 ... 7.432365 0.083333 \n",
|
||
"4 0.923439 7.505012 2.0 1.0 ... 7.456334 0.083333 \n",
|
||
"... ... ... ... ... ... ... ... \n",
|
||
"320183 14.840476 9.786501 2.0 NaN ... 10.027093 NaN \n",
|
||
"320184 15.033432 9.870472 2.0 1.0 ... 10.047117 0.083333 \n",
|
||
"320185 15.211560 9.943236 2.0 1.0 ... 10.015218 0.083333 \n",
|
||
"320186 15.377673 10.008965 2.0 1.0 ... 9.935355 0.083333 \n",
|
||
"320187 15.538255 10.075935 2.0 1.0 ... 10.051785 0.083333 \n",
|
||
"\n",
|
||
" proj_x proj_y dx dy vx vy \\\n",
|
||
"0 0.855100 7.136193 NaN NaN NaN NaN \n",
|
||
"1 0.873132 7.235233 0.018032 0.099039 0.216383 1.188473 \n",
|
||
"2 0.890957 7.328989 0.017825 0.093756 0.213899 1.125077 \n",
|
||
"3 0.907784 7.418187 0.016827 0.089198 0.201924 1.070371 \n",
|
||
"4 0.923439 7.505012 0.015655 0.086825 0.187865 1.041902 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 14.840476 9.786501 NaN NaN NaN NaN \n",
|
||
"320184 15.033432 9.870472 0.192955 0.083971 2.315463 1.007656 \n",
|
||
"320185 15.211560 9.943236 0.178128 0.072764 2.137542 0.873173 \n",
|
||
"320186 15.377673 10.008965 0.166113 0.065728 1.993352 0.788742 \n",
|
||
"320187 15.538255 10.075935 0.160582 0.066971 1.926987 0.803649 \n",
|
||
"\n",
|
||
" ax ay \n",
|
||
"0 NaN NaN \n",
|
||
"1 NaN NaN \n",
|
||
"2 -0.029812 -0.760753 \n",
|
||
"3 -0.143699 -0.656466 \n",
|
||
"4 -0.168701 -0.341637 \n",
|
||
"... ... ... \n",
|
||
"320183 NaN NaN \n",
|
||
"320184 NaN NaN \n",
|
||
"320185 -2.135059 -1.613797 \n",
|
||
"320186 -1.730279 -1.013172 \n",
|
||
"320187 -0.796376 0.178886 \n",
|
||
"\n",
|
||
"[320188 rows x 21 columns]"
|
||
]
|
||
},
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"print(\"Deriving displacement, velocity and accelation from x and y\")\n",
|
||
"data['dx'] = data.groupby(['track_id'])['proj_x'].diff()\n",
|
||
"data['dy'] = data.groupby(['track_id'])['proj_y'].diff()\n",
|
||
"data['vx'] = data['dx'].div(data['dt'], axis=0)\n",
|
||
"data['vy'] = data['dy'].div(data['dt'], axis=0)\n",
|
||
"\n",
|
||
"data['ax'] = data.groupby(['track_id'])['vx'].diff().div(data['dt'], axis=0)\n",
|
||
"data['ay'] = data.groupby(['track_id'])['vy'].diff().div(data['dt'], axis=0)\n",
|
||
"\n",
|
||
"data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" <th>diff</th>\n",
|
||
" <th>...</th>\n",
|
||
" <th>dx</th>\n",
|
||
" <th>dy</th>\n",
|
||
" <th>vx</th>\n",
|
||
" <th>vy</th>\n",
|
||
" <th>ax</th>\n",
|
||
" <th>ay</th>\n",
|
||
" <th>v</th>\n",
|
||
" <th>a</th>\n",
|
||
" <th>heading</th>\n",
|
||
" <th>d_heading</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>9.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.566162</td>\n",
|
||
" <td>88.795326</td>\n",
|
||
" <td>173.917542</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>10.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.116699</td>\n",
|
||
" <td>88.801704</td>\n",
|
||
" <td>171.334290</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.018032</td>\n",
|
||
" <td>0.099039</td>\n",
|
||
" <td>0.216383</td>\n",
|
||
" <td>1.188473</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>1.208011</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>79.681298</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>11.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874573</td>\n",
|
||
" <td>90.596596</td>\n",
|
||
" <td>177.199951</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.017825</td>\n",
|
||
" <td>0.093756</td>\n",
|
||
" <td>0.213899</td>\n",
|
||
" <td>1.125077</td>\n",
|
||
" <td>-0.029812</td>\n",
|
||
" <td>-0.760753</td>\n",
|
||
" <td>1.145230</td>\n",
|
||
" <td>-0.753373</td>\n",
|
||
" <td>79.235449</td>\n",
|
||
" <td>-5.350188</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874268</td>\n",
|
||
" <td>90.928131</td>\n",
|
||
" <td>183.125732</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.016827</td>\n",
|
||
" <td>0.089198</td>\n",
|
||
" <td>0.201924</td>\n",
|
||
" <td>1.070371</td>\n",
|
||
" <td>-0.143699</td>\n",
|
||
" <td>-0.656466</td>\n",
|
||
" <td>1.089251</td>\n",
|
||
" <td>-0.671740</td>\n",
|
||
" <td>79.316807</td>\n",
|
||
" <td>0.976297</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>569.931213</td>\n",
|
||
" <td>86.213280</td>\n",
|
||
" <td>180.774292</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.015655</td>\n",
|
||
" <td>0.086825</td>\n",
|
||
" <td>0.187865</td>\n",
|
||
" <td>1.041902</td>\n",
|
||
" <td>-0.168701</td>\n",
|
||
" <td>-0.341637</td>\n",
|
||
" <td>1.058703</td>\n",
|
||
" <td>-0.366576</td>\n",
|
||
" <td>79.778828</td>\n",
|
||
" <td>5.544252</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320183</th>\n",
|
||
" <td>60159.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1830.709717</td>\n",
|
||
" <td>651.257446</td>\n",
|
||
" <td>150.202515</td>\n",
|
||
" <td>157.239746</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320184</th>\n",
|
||
" <td>60160.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1834.013672</td>\n",
|
||
" <td>649.612122</td>\n",
|
||
" <td>153.686646</td>\n",
|
||
" <td>160.874023</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.192955</td>\n",
|
||
" <td>0.083971</td>\n",
|
||
" <td>2.315463</td>\n",
|
||
" <td>1.007656</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>2.525221</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>23.517970</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320185</th>\n",
|
||
" <td>60161.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1845.373047</td>\n",
|
||
" <td>651.249756</td>\n",
|
||
" <td>147.178589</td>\n",
|
||
" <td>153.729248</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.178128</td>\n",
|
||
" <td>0.072764</td>\n",
|
||
" <td>2.137542</td>\n",
|
||
" <td>0.873173</td>\n",
|
||
" <td>-2.135059</td>\n",
|
||
" <td>-1.613797</td>\n",
|
||
" <td>2.309007</td>\n",
|
||
" <td>-2.594562</td>\n",
|
||
" <td>22.219713</td>\n",
|
||
" <td>-15.579091</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320186</th>\n",
|
||
" <td>60162.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1857.388916</td>\n",
|
||
" <td>650.908203</td>\n",
|
||
" <td>136.407349</td>\n",
|
||
" <td>142.354614</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.166113</td>\n",
|
||
" <td>0.065728</td>\n",
|
||
" <td>1.993352</td>\n",
|
||
" <td>0.788742</td>\n",
|
||
" <td>-1.730279</td>\n",
|
||
" <td>-1.013172</td>\n",
|
||
" <td>2.143727</td>\n",
|
||
" <td>-1.983366</td>\n",
|
||
" <td>21.588019</td>\n",
|
||
" <td>-7.580324</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320187</th>\n",
|
||
" <td>60163.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1862.792725</td>\n",
|
||
" <td>658.719971</td>\n",
|
||
" <td>141.984253</td>\n",
|
||
" <td>149.052307</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.160582</td>\n",
|
||
" <td>0.066971</td>\n",
|
||
" <td>1.926987</td>\n",
|
||
" <td>0.803649</td>\n",
|
||
" <td>-0.796376</td>\n",
|
||
" <td>0.178886</td>\n",
|
||
" <td>2.087853</td>\n",
|
||
" <td>-0.670484</td>\n",
|
||
" <td>22.638547</td>\n",
|
||
" <td>12.606340</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>320188 rows × 25 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" frame_id track_id l t w h \\\n",
|
||
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
||
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
||
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
||
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
||
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
||
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
||
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
||
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
||
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
||
"\n",
|
||
" x y state diff ... dx dy vx \\\n",
|
||
"0 0.855100 7.136193 2.0 NaN ... NaN NaN NaN \n",
|
||
"1 0.873132 7.235233 2.0 1.0 ... 0.018032 0.099039 0.216383 \n",
|
||
"2 0.890957 7.328989 2.0 1.0 ... 0.017825 0.093756 0.213899 \n",
|
||
"3 0.907784 7.418187 2.0 1.0 ... 0.016827 0.089198 0.201924 \n",
|
||
"4 0.923439 7.505012 2.0 1.0 ... 0.015655 0.086825 0.187865 \n",
|
||
"... ... ... ... ... ... ... ... ... \n",
|
||
"320183 14.840476 9.786501 2.0 NaN ... NaN NaN NaN \n",
|
||
"320184 15.033432 9.870472 2.0 1.0 ... 0.192955 0.083971 2.315463 \n",
|
||
"320185 15.211560 9.943236 2.0 1.0 ... 0.178128 0.072764 2.137542 \n",
|
||
"320186 15.377673 10.008965 2.0 1.0 ... 0.166113 0.065728 1.993352 \n",
|
||
"320187 15.538255 10.075935 2.0 1.0 ... 0.160582 0.066971 1.926987 \n",
|
||
"\n",
|
||
" vy ax ay v a heading d_heading \n",
|
||
"0 NaN NaN NaN NaN NaN NaN NaN \n",
|
||
"1 1.188473 NaN NaN 1.208011 NaN 79.681298 NaN \n",
|
||
"2 1.125077 -0.029812 -0.760753 1.145230 -0.753373 79.235449 -5.350188 \n",
|
||
"3 1.070371 -0.143699 -0.656466 1.089251 -0.671740 79.316807 0.976297 \n",
|
||
"4 1.041902 -0.168701 -0.341637 1.058703 -0.366576 79.778828 5.544252 \n",
|
||
"... ... ... ... ... ... ... ... \n",
|
||
"320183 NaN NaN NaN NaN NaN NaN NaN \n",
|
||
"320184 1.007656 NaN NaN 2.525221 NaN 23.517970 NaN \n",
|
||
"320185 0.873173 -2.135059 -1.613797 2.309007 -2.594562 22.219713 -15.579091 \n",
|
||
"320186 0.788742 -1.730279 -1.013172 2.143727 -1.983366 21.588019 -7.580324 \n",
|
||
"320187 0.803649 -0.796376 0.178886 2.087853 -0.670484 22.638547 12.606340 \n",
|
||
"\n",
|
||
"[320188 rows x 25 columns]"
|
||
]
|
||
},
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# then we need the velocity itself\n",
|
||
"data['v'] = np.sqrt(data['vx'].pow(2) + data['vy'].pow(2))\n",
|
||
"# and derive acceleration\n",
|
||
"data['a'] = data.groupby(['track_id'])['v'].diff().div(data['dt'], axis=0)\n",
|
||
"\n",
|
||
"# we can calculate heading based on the velocity components\n",
|
||
"data['heading'] = (np.arctan2(data['vy'], data['vx']) * 180 / np.pi) % 360\n",
|
||
"\n",
|
||
"# and derive it to get the rate of change of the heading\n",
|
||
"data['d_heading'] = data.groupby(['track_id'])['heading'].diff().div(data['dt'], axis=0)\n",
|
||
"\n",
|
||
"data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>frame_id</th>\n",
|
||
" <th>track_id</th>\n",
|
||
" <th>l</th>\n",
|
||
" <th>t</th>\n",
|
||
" <th>w</th>\n",
|
||
" <th>h</th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>state</th>\n",
|
||
" <th>diff</th>\n",
|
||
" <th>...</th>\n",
|
||
" <th>dx</th>\n",
|
||
" <th>dy</th>\n",
|
||
" <th>vx</th>\n",
|
||
" <th>vy</th>\n",
|
||
" <th>ax</th>\n",
|
||
" <th>ay</th>\n",
|
||
" <th>v</th>\n",
|
||
" <th>a</th>\n",
|
||
" <th>heading</th>\n",
|
||
" <th>d_heading</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>9.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.566162</td>\n",
|
||
" <td>88.795326</td>\n",
|
||
" <td>173.917542</td>\n",
|
||
" <td>0.855100</td>\n",
|
||
" <td>7.136193</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>1.208011</td>\n",
|
||
" <td>-0.753373</td>\n",
|
||
" <td>79.681298</td>\n",
|
||
" <td>-5.350188</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>10.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>565.116699</td>\n",
|
||
" <td>88.801704</td>\n",
|
||
" <td>171.334290</td>\n",
|
||
" <td>0.873132</td>\n",
|
||
" <td>7.235233</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.018032</td>\n",
|
||
" <td>0.099039</td>\n",
|
||
" <td>0.216383</td>\n",
|
||
" <td>1.188473</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>1.208011</td>\n",
|
||
" <td>-0.753373</td>\n",
|
||
" <td>79.681298</td>\n",
|
||
" <td>-5.350188</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>11.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874573</td>\n",
|
||
" <td>90.596596</td>\n",
|
||
" <td>177.199951</td>\n",
|
||
" <td>0.890957</td>\n",
|
||
" <td>7.328989</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.017825</td>\n",
|
||
" <td>0.093756</td>\n",
|
||
" <td>0.213899</td>\n",
|
||
" <td>1.125077</td>\n",
|
||
" <td>-0.029812</td>\n",
|
||
" <td>-0.760753</td>\n",
|
||
" <td>1.145230</td>\n",
|
||
" <td>-0.753373</td>\n",
|
||
" <td>79.235449</td>\n",
|
||
" <td>-5.350188</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>564.874268</td>\n",
|
||
" <td>90.928131</td>\n",
|
||
" <td>183.125732</td>\n",
|
||
" <td>0.907784</td>\n",
|
||
" <td>7.418187</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.016827</td>\n",
|
||
" <td>0.089198</td>\n",
|
||
" <td>0.201924</td>\n",
|
||
" <td>1.070371</td>\n",
|
||
" <td>-0.143699</td>\n",
|
||
" <td>-0.656466</td>\n",
|
||
" <td>1.089251</td>\n",
|
||
" <td>-0.671740</td>\n",
|
||
" <td>79.316807</td>\n",
|
||
" <td>0.976297</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>569.931213</td>\n",
|
||
" <td>86.213280</td>\n",
|
||
" <td>180.774292</td>\n",
|
||
" <td>0.923439</td>\n",
|
||
" <td>7.505012</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.015655</td>\n",
|
||
" <td>0.086825</td>\n",
|
||
" <td>0.187865</td>\n",
|
||
" <td>1.041902</td>\n",
|
||
" <td>-0.168701</td>\n",
|
||
" <td>-0.341637</td>\n",
|
||
" <td>1.058703</td>\n",
|
||
" <td>-0.366576</td>\n",
|
||
" <td>79.778828</td>\n",
|
||
" <td>5.544252</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320183</th>\n",
|
||
" <td>60159.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1830.709717</td>\n",
|
||
" <td>651.257446</td>\n",
|
||
" <td>150.202515</td>\n",
|
||
" <td>157.239746</td>\n",
|
||
" <td>14.840476</td>\n",
|
||
" <td>9.786501</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>2.525221</td>\n",
|
||
" <td>-2.594562</td>\n",
|
||
" <td>23.517970</td>\n",
|
||
" <td>-15.579091</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320184</th>\n",
|
||
" <td>60160.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1834.013672</td>\n",
|
||
" <td>649.612122</td>\n",
|
||
" <td>153.686646</td>\n",
|
||
" <td>160.874023</td>\n",
|
||
" <td>15.033432</td>\n",
|
||
" <td>9.870472</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.192955</td>\n",
|
||
" <td>0.083971</td>\n",
|
||
" <td>2.315463</td>\n",
|
||
" <td>1.007656</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>2.525221</td>\n",
|
||
" <td>-2.594562</td>\n",
|
||
" <td>23.517970</td>\n",
|
||
" <td>-15.579091</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320185</th>\n",
|
||
" <td>60161.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1845.373047</td>\n",
|
||
" <td>651.249756</td>\n",
|
||
" <td>147.178589</td>\n",
|
||
" <td>153.729248</td>\n",
|
||
" <td>15.211560</td>\n",
|
||
" <td>9.943236</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.178128</td>\n",
|
||
" <td>0.072764</td>\n",
|
||
" <td>2.137542</td>\n",
|
||
" <td>0.873173</td>\n",
|
||
" <td>-2.135059</td>\n",
|
||
" <td>-1.613797</td>\n",
|
||
" <td>2.309007</td>\n",
|
||
" <td>-2.594562</td>\n",
|
||
" <td>22.219713</td>\n",
|
||
" <td>-15.579091</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320186</th>\n",
|
||
" <td>60162.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1857.388916</td>\n",
|
||
" <td>650.908203</td>\n",
|
||
" <td>136.407349</td>\n",
|
||
" <td>142.354614</td>\n",
|
||
" <td>15.377673</td>\n",
|
||
" <td>10.008965</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.166113</td>\n",
|
||
" <td>0.065728</td>\n",
|
||
" <td>1.993352</td>\n",
|
||
" <td>0.788742</td>\n",
|
||
" <td>-1.730279</td>\n",
|
||
" <td>-1.013172</td>\n",
|
||
" <td>2.143727</td>\n",
|
||
" <td>-1.983366</td>\n",
|
||
" <td>21.588019</td>\n",
|
||
" <td>-7.580324</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320187</th>\n",
|
||
" <td>60163.0</td>\n",
|
||
" <td>3632</td>\n",
|
||
" <td>1862.792725</td>\n",
|
||
" <td>658.719971</td>\n",
|
||
" <td>141.984253</td>\n",
|
||
" <td>149.052307</td>\n",
|
||
" <td>15.538255</td>\n",
|
||
" <td>10.075935</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.160582</td>\n",
|
||
" <td>0.066971</td>\n",
|
||
" <td>1.926987</td>\n",
|
||
" <td>0.803649</td>\n",
|
||
" <td>-0.796376</td>\n",
|
||
" <td>0.178886</td>\n",
|
||
" <td>2.087853</td>\n",
|
||
" <td>-0.670484</td>\n",
|
||
" <td>22.638547</td>\n",
|
||
" <td>12.606340</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>320188 rows × 25 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" frame_id track_id l t w h \\\n",
|
||
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
||
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
||
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
||
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
||
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
||
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
||
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
||
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
||
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
||
"\n",
|
||
" x y state diff ... dx dy vx \\\n",
|
||
"0 0.855100 7.136193 2.0 NaN ... NaN NaN NaN \n",
|
||
"1 0.873132 7.235233 2.0 1.0 ... 0.018032 0.099039 0.216383 \n",
|
||
"2 0.890957 7.328989 2.0 1.0 ... 0.017825 0.093756 0.213899 \n",
|
||
"3 0.907784 7.418187 2.0 1.0 ... 0.016827 0.089198 0.201924 \n",
|
||
"4 0.923439 7.505012 2.0 1.0 ... 0.015655 0.086825 0.187865 \n",
|
||
"... ... ... ... ... ... ... ... ... \n",
|
||
"320183 14.840476 9.786501 2.0 NaN ... NaN NaN NaN \n",
|
||
"320184 15.033432 9.870472 2.0 1.0 ... 0.192955 0.083971 2.315463 \n",
|
||
"320185 15.211560 9.943236 2.0 1.0 ... 0.178128 0.072764 2.137542 \n",
|
||
"320186 15.377673 10.008965 2.0 1.0 ... 0.166113 0.065728 1.993352 \n",
|
||
"320187 15.538255 10.075935 2.0 1.0 ... 0.160582 0.066971 1.926987 \n",
|
||
"\n",
|
||
" vy ax ay v a heading d_heading \n",
|
||
"0 NaN NaN NaN 1.208011 -0.753373 79.681298 -5.350188 \n",
|
||
"1 1.188473 NaN NaN 1.208011 -0.753373 79.681298 -5.350188 \n",
|
||
"2 1.125077 -0.029812 -0.760753 1.145230 -0.753373 79.235449 -5.350188 \n",
|
||
"3 1.070371 -0.143699 -0.656466 1.089251 -0.671740 79.316807 0.976297 \n",
|
||
"4 1.041902 -0.168701 -0.341637 1.058703 -0.366576 79.778828 5.544252 \n",
|
||
"... ... ... ... ... ... ... ... \n",
|
||
"320183 NaN NaN NaN 2.525221 -2.594562 23.517970 -15.579091 \n",
|
||
"320184 1.007656 NaN NaN 2.525221 -2.594562 23.517970 -15.579091 \n",
|
||
"320185 0.873173 -2.135059 -1.613797 2.309007 -2.594562 22.219713 -15.579091 \n",
|
||
"320186 0.788742 -1.730279 -1.013172 2.143727 -1.983366 21.588019 -7.580324 \n",
|
||
"320187 0.803649 -0.796376 0.178886 2.087853 -0.670484 22.638547 12.606340 \n",
|
||
"\n",
|
||
"[320188 rows x 25 columns]"
|
||
]
|
||
},
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# we can backfill the v and a, so that our model can make estimations\n",
|
||
"# based on these assumed values\n",
|
||
"data['v'] = data.groupby(['track_id'])['v'].bfill()\n",
|
||
"data['a'] = data.groupby(['track_id'])['a'].bfill()\n",
|
||
"\n",
|
||
"data['heading'] = data.groupby(['track_id'])['heading'].bfill()\n",
|
||
"data['d_heading'] = data.groupby(['track_id'])['d_heading'].bfill()\n",
|
||
"data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"312423 items in filtered set, out of 320188 in total set\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"filtered_data = data.groupby(['track_id']).filter(lambda group: len(group) >= window+1) # a lenght of 3 is neccessary to have all relevant derivatives of position\n",
|
||
"filtered_data = filtered_data.set_index(['track_id', 'frame_id']) # use for quick access\n",
|
||
"print(filtered_data.shape[0], \"items in filtered set, out of\", data.shape[0], \"in total set\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"1263 training tracks, 316 test tracks\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"track_ids = filtered_data.index.unique('track_id').to_numpy()\n",
|
||
"np.random.shuffle(track_ids)\n",
|
||
"test_offset_idx = int(len(track_ids) * .8)\n",
|
||
"training_ids, test_ids = track_ids[:test_offset_idx], track_ids[test_offset_idx:]\n",
|
||
"print(f\"{len(training_ids)} training tracks, {len(test_ids)} test tracks\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"here, draw out a sample track to see if it looks alright. **unfortunately the imate isn't mapped properly**."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"1058\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"0 0\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"import random\n",
|
||
"if H:\n",
|
||
" img_src = \"../DATASETS/hof/webcam20231103-2.png\"\n",
|
||
" # dst = cv2.warpPerspective(img_src,H,(2500,1920))\n",
|
||
" src_img = cv2.imread(img_src)\n",
|
||
" print(src_img.shape)\n",
|
||
" h1,w1 = src_img.shape[:2]\n",
|
||
" corners = np.float32([[0,0], [w1, 0], [0, h1], [w1, h1]])\n",
|
||
"\n",
|
||
" print(corners)\n",
|
||
" corners_projected = cv2.perspectiveTransform(corners.reshape((-1,4,2)), H)[0]\n",
|
||
" print(corners_projected)\n",
|
||
" [xmin, ymin] = np.int32(corners_projected.min(axis=0).ravel() - 0.5)\n",
|
||
" [xmax, ymax] = np.int32(corners_projected.max(axis=0).ravel() + 0.5)\n",
|
||
" print(xmin, xmax, ymin, ymax)\n",
|
||
"\n",
|
||
" dst = cv2.warpPerspective(src_img,H, (xmax, ymax))\n",
|
||
" def plot_track(track_id: int):\n",
|
||
" plt.gca().invert_yaxis()\n",
|
||
"\n",
|
||
" plt.imshow(dst, origin='lower', extent=[xmin/100-mean_x, xmax/100-mean_x, ymin/100-mean_y, ymax/100-mean_y])\n",
|
||
" # plot scatter plot with x and y data \n",
|
||
" \n",
|
||
" ax = plt.scatter(\n",
|
||
" filtered_data.loc[track_id,:]['proj_x'],\n",
|
||
" filtered_data.loc[track_id,:]['proj_y'],\n",
|
||
" marker=\"*\") \n",
|
||
" ax.axes.invert_yaxis()\n",
|
||
" plt.plot(\n",
|
||
" filtered_data.loc[track_id,:]['proj_x'],\n",
|
||
" filtered_data.loc[track_id,:]['proj_y']\n",
|
||
" )\n",
|
||
"else:\n",
|
||
" def plot_track(track_id: int):\n",
|
||
" ax = plt.scatter(\n",
|
||
" filtered_data.loc[track_id,:]['x'],\n",
|
||
" filtered_data.loc[track_id,:]['y'],\n",
|
||
" marker=\"*\") \n",
|
||
" plt.plot(\n",
|
||
" filtered_data.loc[track_id,:]['proj_x'],\n",
|
||
" filtered_data.loc[track_id,:]['proj_y']\n",
|
||
" )\n",
|
||
"\n",
|
||
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
|
||
"# _track_id = 2188\n",
|
||
"_track_id = random.choice(track_ids)\n",
|
||
"print(_track_id)\n",
|
||
"plot_track(_track_id)\n",
|
||
"\n",
|
||
"for track_id in random.choices(track_ids, k=100):\n",
|
||
" plot_track(track_id)\n",
|
||
" \n",
|
||
"print(mean_x, mean_y)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Now make the dataset:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# a=filtered_data.loc[1]\n",
|
||
"# min(a.index.tolist())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 0%| | 0/1263 [00:00<?, ?it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"100%|██████████| 1263/1263 [01:42<00:00, 12.36it/s]\n",
|
||
"100%|██████████| 316/316 [00:26<00:00, 11.75it/s]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"def create_dataset(data, track_ids, window):\n",
|
||
" X, y, = [], []\n",
|
||
" for track_id in tqdm(track_ids):\n",
|
||
" df = data.loc[track_id]\n",
|
||
" start_frame = min(df.index.tolist())\n",
|
||
" for step in range(len(df)-window-1):\n",
|
||
" i = int(start_frame) + step\n",
|
||
" # print(step, int(start_frame), i)\n",
|
||
" feature = df.loc[i:i+window][in_fields]\n",
|
||
" # target = df.loc[i+1:i+window+1][out_fields]\n",
|
||
" target = df.loc[i+window+1][out_fields]\n",
|
||
" X.append(feature.values)\n",
|
||
" y.append(target.values)\n",
|
||
" \n",
|
||
" return torch.tensor(np.array(X), dtype=torch.float), torch.tensor(np.array(y), dtype=torch.float)\n",
|
||
"\n",
|
||
"X_train, y_train = create_dataset(filtered_data, training_ids, window)\n",
|
||
"X_test, y_test = create_dataset(filtered_data, test_ids, window)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"X_train, y_train = X_train.to(device=device), y_train.to(device=device)\n",
|
||
"X_test, y_test = X_test.to(device=device), y_test.to(device=device)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from torch.utils.data import TensorDataset, DataLoader\n",
|
||
"dataset_train = TensorDataset(X_train, y_train)\n",
|
||
"loader_train = DataLoader(dataset_train, shuffle=True, batch_size=8)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Model give output for all timesteps, this should improve training. But we use only the last timestep for the prediction process"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class LSTMModel(nn.Module):\n",
|
||
" # input_size : number of features in input at each time step\n",
|
||
" # hidden_size : Number of LSTM units \n",
|
||
" # num_layers : number of LSTM layers \n",
|
||
" def __init__(self, input_size, hidden_size, num_layers): \n",
|
||
" super(LSTMModel, self).__init__() #initializes the parent class nn.Module\n",
|
||
" self.lin1 = nn.Linear(input_size, hidden_size//2)\n",
|
||
" self.lstm = nn.LSTM(hidden_size//2, hidden_size, num_layers, batch_first=True)\n",
|
||
" self.linear = nn.Linear(hidden_size, output_size)\n",
|
||
" # self.activation_v = nn.LeakyReLU(.01)\n",
|
||
" # self.activation_heading = torch.remainder()\n",
|
||
"\n",
|
||
" def forward(self, x): # defines forward pass of the neural network\n",
|
||
" out = self.lin1(x)\n",
|
||
" out, h0 = self.lstm(out)\n",
|
||
" # extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\n",
|
||
" # print(out.shape)\n",
|
||
" out = out[:, -1,:]\n",
|
||
" # print(out.shape)\n",
|
||
" out = self.linear(out)\n",
|
||
" \n",
|
||
" # torch.remainder(out[1], 360)\n",
|
||
" # print('o',out.shape)\n",
|
||
" return out\n",
|
||
"\n",
|
||
"model = LSTMModel(input_size, hidden_size, num_layers).to(device)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||
"loss_fn = nn.MSELoss()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def evaluate():\n",
|
||
" # toggle evaluation mode\n",
|
||
" model.eval()\n",
|
||
" with torch.no_grad():\n",
|
||
" y_pred = model(X_train.to(device=device))\n",
|
||
" train_rmse = torch.sqrt(loss_fn(y_pred, y_train))\n",
|
||
" y_pred = model(X_test.to(device=device))\n",
|
||
" test_rmse = torch.sqrt(loss_fn(y_pred, y_test))\n",
|
||
" print(\"Epoch %d: train RMSE %.4f, test RMSE %.4f\" % (epoch, train_rmse, test_rmse))\n",
|
||
"\n",
|
||
"def load_most_recent():\n",
|
||
" paths = list(cache_path.glob(\"checkpoint_*.pt\"))\n",
|
||
" if len(paths) < 1:\n",
|
||
" print('Nothing found to load')\n",
|
||
" return None, None\n",
|
||
" paths.sort()\n",
|
||
"\n",
|
||
" print(f\"Loading {paths[-1]}\")\n",
|
||
" return load_cache(path=paths[-1])\n",
|
||
"\n",
|
||
"def load_cache(epoch=None, path=None):\n",
|
||
" if path is None:\n",
|
||
" if epoch is None:\n",
|
||
" raise RuntimeError(\"Either path or epoch must be given\")\n",
|
||
" path = cache_path / f\"checkpoint_{epoch:05d}.pt\"\n",
|
||
" else:\n",
|
||
" print (path.stem)\n",
|
||
" epoch = int(path.stem[-5:])\n",
|
||
"\n",
|
||
" cached = torch.load(path)\n",
|
||
" \n",
|
||
" optimizer.load_state_dict(cached['optimizer_state_dict'])\n",
|
||
" model.load_state_dict(cached['model_state_dict'])\n",
|
||
" return epoch, cached['loss']\n",
|
||
" \n",
|
||
"\n",
|
||
"def cache(epoch, loss):\n",
|
||
" path = cache_path / f\"checkpoint_{epoch:05d}.pt\"\n",
|
||
" print(f\"Cache to {path}\")\n",
|
||
" torch.save({\n",
|
||
" 'epoch': epoch,\n",
|
||
" 'model_state_dict': model.state_dict(),\n",
|
||
" 'optimizer_state_dict': optimizer.state_dict(),\n",
|
||
" 'loss': loss,\n",
|
||
" }, path)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Loading EXPERIMENTS/cache/hof2/checkpoint_00005.pt\n",
|
||
"checkpoint_00005\n",
|
||
"starting from epoch 5 (loss: nan)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 0%| | 0/95 [00:00<?, ?it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 4%|▍ | 4/95 [07:37<2:53:27, 114.37s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Cache to EXPERIMENTS/cache/hof2/checkpoint_00010.pt\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "RuntimeError",
|
||
"evalue": "CUDA out of memory. Tried to allocate 28.40 GiB (GPU 0; 23.59 GiB total capacity; 15.01 GiB already allocated; 7.20 GiB free; 15.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[0;32mIn[31], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m 30\u001b[0m cache(epoch, loss)\n\u001b[0;32m---> 31\u001b[0m \u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m evaluate()\n",
|
||
"Cell \u001b[0;32mIn[30], line 5\u001b[0m, in \u001b[0;36mevaluate\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 5\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m train_rmse \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msqrt(loss_fn(y_pred, y_train))\n\u001b[1;32m 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m model(X_test\u001b[38;5;241m.\u001b[39mto(device\u001b[38;5;241m=\u001b[39mdevice))\n",
|
||
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||
"Cell \u001b[0;32mIn[28], line 15\u001b[0m, in \u001b[0;36mLSTMModel.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x): \u001b[38;5;66;03m# defines forward pass of the neural network\u001b[39;00m\n\u001b[1;32m 14\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlin1(x)\n\u001b[0;32m---> 15\u001b[0m out, h0 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;66;03m# extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\u001b[39;00m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# print(out.shape)\u001b[39;00m\n\u001b[1;32m 18\u001b[0m out \u001b[38;5;241m=\u001b[39m out[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:]\n",
|
||
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/rnn.py:769\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_forward_args(\u001b[38;5;28minput\u001b[39m, hx, batch_sizes)\n\u001b[1;32m 768\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 769\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 771\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 772\u001b[0m result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\u001b[38;5;28minput\u001b[39m, batch_sizes, hx, 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flat_weights, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias,\n\u001b[1;32m 773\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional)\n",
|
||
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 28.40 GiB (GPU 0; 23.59 GiB total capacity; 15.01 GiB already allocated; 7.20 GiB free; 15.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"start_epoch, loss = load_most_recent()\n",
|
||
"if start_epoch is None:\n",
|
||
" start_epoch = 0\n",
|
||
"else:\n",
|
||
" print(f\"starting from epoch {start_epoch} (loss: {loss})\")\n",
|
||
"\n",
|
||
"# Train Network\n",
|
||
"for epoch in tqdm(range(start_epoch+1,num_epochs+1)):\n",
|
||
" # toggle train mode\n",
|
||
" model.train()\n",
|
||
" for batch_idx, (data, targets) in enumerate(loader_train):\n",
|
||
" # Get data to cuda if possible\n",
|
||
" data = data.to(device=device).squeeze(1)\n",
|
||
" targets = targets.to(device=device)\n",
|
||
"\n",
|
||
" # forward\n",
|
||
" scores = model(data)\n",
|
||
" loss = loss_fn(scores, targets)\n",
|
||
"\n",
|
||
" # backward\n",
|
||
" optimizer.zero_grad()\n",
|
||
" loss.backward()\n",
|
||
"\n",
|
||
" # gradient descent update step/adam step\n",
|
||
" optimizer.step()\n",
|
||
"\n",
|
||
" if epoch % 5 != 0:\n",
|
||
" continue\n",
|
||
"\n",
|
||
" cache(epoch, loss)\n",
|
||
" evaluate()\n",
|
||
"\n",
|
||
"evaluate()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"ename": "NameError",
|
||
"evalue": "name 'model' is not defined",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 3\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m model(X_train\u001b[38;5;241m.\u001b[39mto(device\u001b[38;5;241m=\u001b[39mdevice))\n",
|
||
"\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"model.eval()\n",
|
||
"with torch.no_grad():\n",
|
||
" y_pred = model(X_train.to(device=device))\n",
|
||
" print(y_pred.shape, y_train.shape)\n",
|
||
"y_train, y_pred"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(0, 0)"
|
||
]
|
||
},
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"mean_x, mean_y"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import scipy\n",
|
||
"\n",
|
||
"\n",
|
||
"def predict_and_plot(feature, steps = 50):\n",
|
||
" lenght = feature.shape[0]\n",
|
||
" # feature = filtered_data.loc[_track_id,:].iloc[:5][in_fields].values\n",
|
||
" # nxt = filtered_data.loc[_track_id,:].iloc[5][out_fields]\n",
|
||
" with torch.no_grad():\n",
|
||
" for i in range(steps):\n",
|
||
" # predict_f = scipy.ndimage.uniform_filter(feature)\n",
|
||
" # predict_f = scipy.interpolate.splrep(feature[:][0], feature[:][1],)\n",
|
||
" # predict_f = scipy.signal.spline_feature(feature, lmbda=.1)\n",
|
||
" # bathc size of one, so feature as single item in array\n",
|
||
" X = torch.tensor([feature], dtype=torch.float).to(device)\n",
|
||
" # print(X.shape)\n",
|
||
" s = model(X)[0].cpu()\n",
|
||
" \n",
|
||
" # proj_x proj_y v heading a d_heading\n",
|
||
" # next_step = feature\n",
|
||
" dt = 1/ FPS\n",
|
||
" v = np.sqrt(s[0]**2 + s[1]**2)\n",
|
||
" heading = (np.arctan2(s[1], s[0]) * 180 / np.pi) % 360\n",
|
||
" a = (v - feature[-1][2]) / dt\n",
|
||
" d_heading = (heading - feature[-1][5])\n",
|
||
" # print(s)\n",
|
||
" feature = np.append(feature, [[feature[-1][0] + s[0]*dt, feature[-1][1] + s[1]*dt, v, heading, a, d_heading ]], axis=0)\n",
|
||
" \n",
|
||
" # print(next_step, nxt)\n",
|
||
" plt.plot(feature[:lenght,0], feature[:lenght,1], c='orange')\n",
|
||
" plt.plot(feature[lenght-1:,0], feature[lenght-1:,1], c='red')\n",
|
||
" plt.scatter(feature[lenght:,0], feature[lenght:,1], c='red')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2515\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"\n",
|
||
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
|
||
"_track_id = 8701 # random.choice(track_ids)\n",
|
||
"_track_id = 3880 # random.choice(track_ids)\n",
|
||
"\n",
|
||
"# _track_id = 2780\n",
|
||
"\n",
|
||
"for i in range(100):\n",
|
||
" _track_id = random.choice(track_ids)\n",
|
||
" plt.plot(\n",
|
||
" filtered_data.loc[_track_id,:]['proj_x'],\n",
|
||
" filtered_data.loc[_track_id,:]['proj_y'],\n",
|
||
" c='grey', alpha=.2\n",
|
||
" )\n",
|
||
"\n",
|
||
"_track_id = random.choice(track_ids)\n",
|
||
"# _track_id = 801\n",
|
||
"print(_track_id)\n",
|
||
"ax = plt.scatter(\n",
|
||
" filtered_data.loc[_track_id,:]['proj_x'],\n",
|
||
" filtered_data.loc[_track_id,:]['proj_y'],\n",
|
||
" marker=\"*\") \n",
|
||
"plt.plot(\n",
|
||
" filtered_data.loc[_track_id,:]['proj_x'],\n",
|
||
" filtered_data.loc[_track_id,:]['proj_y']\n",
|
||
")\n",
|
||
"\n",
|
||
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:5][in_fields].values)\n",
|
||
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:10][in_fields].values)\n",
|
||
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:50][in_fields].values)\n",
|
||
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n",
|
||
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.4"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|