trap/test_custom_rnn.ipynb

3625 lines
294 KiB
Text
Raw Normal View History

2024-11-17 19:39:32 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Goal of this notebook: implement some basic RNN/LSTM/GRU to _forecast_ trajectories based on VIRAT and/or the custom _hof_ dataset."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ruben/suspicion/trap/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt # Visualization \n",
"import torch.nn as nn\n",
"import pandas_helper_calc # noqa # provides df.calc.derivative()\n",
"import pandas as pd\n",
"import cv2\n",
"import pathlib\n",
"from tqdm.autonotebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"FPS = 12\n",
"# Exactly one SRC_CSV assignment is active. The other datasets are kept as commented\n",
"# alternatives; two of them used to be live assignments that were overwritten right away.\n",
"# SRC_CSV = \"EXPERIMENTS/hofext-maskrcnn/all.txt\"\n",
"# SRC_CSV = \"EXPERIMENTS/raw/generated/train/tracks.txt\"\n",
"# SRC_CSV = \"EXPERIMENTS/raw/hof-meter-maskrcnn2/train/tracks.txt\"\n",
"# SRC_CSV = \"EXPERIMENTS/20240426-hof-yolo/train/tracked.txt\"\n",
"SRC_CSV = \"EXPERIMENTS/raw/hof2/train/tracked.txt\"\n",
"# SRC_H = \"../DATASETS/hof/webcam20231103-2-homography.txt\"\n",
"SRC_H = None\n",
"CACHE_DIR = \"EXPERIMENTS/cache/hof2/\"\n",
"SMOOTHING = True # hof-yolo is already smoothed, hof2 isn't\n",
"SMOOTHING_WINDOW = 3 #2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Input feature columns; input_size is derived from their count.\n",
"in_fields = ['proj_x', 'proj_y', 'vx', 'vy', 'ax', 'ay']\n",
"# Target columns. Velocity cannot be negative and heading is circular (modulo), which\n",
"# makes ['v', 'heading'] harder to optimise than a linear space. The velocity components\n",
"# are used instead, and we can use a simple MSE loss (I guess?).\n",
"# out_fields = ['v', 'heading']\n",
"out_fields = ['vx', 'vy']\n",
"# Window length: 1.5 times FPS, truncated to an int.\n",
"window = int(1.5 * FPS)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# Set device: CUDA when torch reports it as available, the CPU otherwise.\n",
"if torch.cuda.is_available():\n",
"    device = \"cuda\"\n",
"else:\n",
"    device = \"cpu\"\n",
"print(device)\n",
"\n",
"# Hyperparameters\n",
"# Input and output widths follow the number of selected fields.\n",
"input_size, output_size = len(in_fields), len(out_fields)\n",
"hidden_size = 256\n",
"num_layers = 3\n",
"learning_rate = 0.005 #0.01 #0.005\n",
"batch_size = 256\n",
"num_epochs = 1000"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Create the cache directory, including missing parents; an existing directory is fine.\n",
"cache_path = pathlib.Path(CACHE_DIR)\n",
"cache_path.mkdir(exist_ok=True, parents=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Samping 1/5, of 412098 items\n",
"Done sampling kept 83726 items\n"
]
}
],
"source": [
"from pathlib import Path\n",
"from trap.tools import load_tracks_from_csv\n",
"\n",
"# NOTE(review): the positional arguments 2 and 5 are not explained in this notebook. The\n",
"# recorded output (\"Samping 1/5\") suggests 5 is a sampling step; confirm against trap.tools.\n",
"# NOTE(review): a later cell assigns `data` again from pd.read_csv(SRC_CSV, ...).\n",
"tracks_csv = Path(SRC_CSV)\n",
"data = load_tracks_from_csv(tracks_csv, FPS, 2, 5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>track_id</th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>...</th>\n",
" <th>dx</th>\n",
" <th>dy</th>\n",
" <th>vx</th>\n",
" <th>vy</th>\n",
" <th>ax</th>\n",
" <th>ay</th>\n",
" <th>v</th>\n",
" <th>a</th>\n",
" <th>heading</th>\n",
" <th>d_heading</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>194</th>\n",
" <td>606.0</td>\n",
" <td>4</td>\n",
" <td>1593.885864</td>\n",
" <td>782.814819</td>\n",
" <td>145.704346</td>\n",
" <td>195.380432</td>\n",
" <td>12.897830</td>\n",
" <td>10.750061</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>0.201965</td>\n",
" <td>-0.291350</td>\n",
" <td>0.484716</td>\n",
" <td>-0.699240</td>\n",
" <td>-1.622919</td>\n",
" <td>-1.732144</td>\n",
" <td>0.850815</td>\n",
" <td>1.399195</td>\n",
" <td>304.729842</td>\n",
" <td>-101.772559</td>\n",
" </tr>\n",
" <tr>\n",
" <th>199</th>\n",
" <td>611.0</td>\n",
" <td>4</td>\n",
" <td>1563.890015</td>\n",
" <td>700.710510</td>\n",
" <td>137.461304</td>\n",
" <td>190.194855</td>\n",
" <td>13.099794</td>\n",
" <td>10.458712</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>0.201965</td>\n",
" <td>-0.291350</td>\n",
" <td>0.484716</td>\n",
" <td>-0.699240</td>\n",
" <td>-1.622919</td>\n",
" <td>-1.732144</td>\n",
" <td>0.850815</td>\n",
" <td>1.399195</td>\n",
" <td>304.729842</td>\n",
" <td>-101.772559</td>\n",
" </tr>\n",
" <tr>\n",
" <th>204</th>\n",
" <td>616.0</td>\n",
" <td>4</td>\n",
" <td>1529.469727</td>\n",
" <td>635.622498</td>\n",
" <td>129.342651</td>\n",
" <td>194.191528</td>\n",
" <td>13.020002</td>\n",
" <td>9.866642</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.079792</td>\n",
" <td>-0.592069</td>\n",
" <td>-0.191501</td>\n",
" <td>-1.420966</td>\n",
" <td>-1.622919</td>\n",
" <td>-1.732144</td>\n",
" <td>1.433812</td>\n",
" <td>1.399195</td>\n",
" <td>262.324609</td>\n",
" <td>-101.772559</td>\n",
" </tr>\n",
" <tr>\n",
" <th>209</th>\n",
" <td>621.0</td>\n",
" <td>4</td>\n",
" <td>1474.449341</td>\n",
" <td>569.387634</td>\n",
" <td>128.099854</td>\n",
" <td>199.766357</td>\n",
" <td>12.965776</td>\n",
" <td>9.301442</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.054226</td>\n",
" <td>-0.565200</td>\n",
" <td>-0.130143</td>\n",
" <td>-1.356479</td>\n",
" <td>0.147259</td>\n",
" <td>0.154769</td>\n",
" <td>1.362708</td>\n",
" <td>-0.170650</td>\n",
" <td>264.519715</td>\n",
" <td>5.268254</td>\n",
" </tr>\n",
" <tr>\n",
" <th>214</th>\n",
" <td>626.0</td>\n",
" <td>4</td>\n",
" <td>1443.123535</td>\n",
" <td>518.907043</td>\n",
" <td>120.022461</td>\n",
" <td>202.566772</td>\n",
" <td>12.642992</td>\n",
" <td>8.976624</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.322784</td>\n",
" <td>-0.324818</td>\n",
" <td>-0.774681</td>\n",
" <td>-0.779564</td>\n",
" <td>-1.546892</td>\n",
" <td>1.384597</td>\n",
" <td>1.099023</td>\n",
" <td>-0.632844</td>\n",
" <td>225.179993</td>\n",
" <td>-94.415332</td>\n",
" </tr>\n",
" <tr>\n",
" <th>219</th>\n",
" <td>631.0</td>\n",
" <td>4</td>\n",
" <td>1398.944946</td>\n",
" <td>461.813049</td>\n",
" <td>106.391357</td>\n",
" <td>193.476410</td>\n",
" <td>12.465588</td>\n",
" <td>8.557788</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.177404</td>\n",
" <td>-0.418836</td>\n",
" <td>-0.425771</td>\n",
" <td>-1.005205</td>\n",
" <td>0.837386</td>\n",
" <td>-0.541539</td>\n",
" <td>1.091659</td>\n",
" <td>-0.017675</td>\n",
" <td>247.044148</td>\n",
" <td>52.473972</td>\n",
" </tr>\n",
" <tr>\n",
" <th>224</th>\n",
" <td>636.0</td>\n",
" <td>4</td>\n",
" <td>1353.237793</td>\n",
" <td>438.118896</td>\n",
" <td>91.444336</td>\n",
" <td>170.930664</td>\n",
" <td>12.128433</td>\n",
" <td>8.052323</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.337155</td>\n",
" <td>-0.505465</td>\n",
" <td>-0.809172</td>\n",
" <td>-1.213117</td>\n",
" <td>-0.920163</td>\n",
" <td>-0.498987</td>\n",
" <td>1.458222</td>\n",
" <td>0.879752</td>\n",
" <td>236.295957</td>\n",
" <td>-25.795658</td>\n",
" </tr>\n",
" <tr>\n",
" <th>229</th>\n",
" <td>641.0</td>\n",
" <td>4</td>\n",
" <td>1272.791992</td>\n",
" <td>408.827759</td>\n",
" <td>104.274536</td>\n",
" <td>180.414551</td>\n",
" <td>11.689648</td>\n",
" <td>7.684636</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.438785</td>\n",
" <td>-0.367687</td>\n",
" <td>-1.053084</td>\n",
" <td>-0.882448</td>\n",
" <td>-0.585388</td>\n",
" <td>0.793604</td>\n",
" <td>1.373936</td>\n",
" <td>-0.202286</td>\n",
" <td>219.961870</td>\n",
" <td>-39.201809</td>\n",
" </tr>\n",
" <tr>\n",
" <th>234</th>\n",
" <td>646.0</td>\n",
" <td>4</td>\n",
" <td>1198.965820</td>\n",
" <td>407.952759</td>\n",
" <td>103.282104</td>\n",
" <td>167.306580</td>\n",
" <td>11.207276</td>\n",
" <td>7.476216</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.482372</td>\n",
" <td>-0.208420</td>\n",
" <td>-1.157693</td>\n",
" <td>-0.500209</td>\n",
" <td>-0.251064</td>\n",
" <td>0.917374</td>\n",
" <td>1.261136</td>\n",
" <td>-0.270721</td>\n",
" <td>203.367915</td>\n",
" <td>-39.825493</td>\n",
" </tr>\n",
" <tr>\n",
" <th>239</th>\n",
" <td>651.0</td>\n",
" <td>4</td>\n",
" <td>1156.309570</td>\n",
" <td>415.743408</td>\n",
" <td>97.628784</td>\n",
" <td>158.774811</td>\n",
" <td>10.884154</td>\n",
" <td>7.514692</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.323122</td>\n",
" <td>0.038476</td>\n",
" <td>-0.775493</td>\n",
" <td>0.092343</td>\n",
" <td>0.917282</td>\n",
" <td>1.422125</td>\n",
" <td>0.780971</td>\n",
" <td>-1.152395</td>\n",
" <td>173.209381</td>\n",
" <td>-72.380481</td>\n",
" </tr>\n",
" <tr>\n",
" <th>244</th>\n",
" <td>656.0</td>\n",
" <td>4</td>\n",
" <td>1094.440430</td>\n",
" <td>443.849915</td>\n",
" <td>107.938110</td>\n",
" <td>177.703979</td>\n",
" <td>10.544492</td>\n",
" <td>7.870090</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.339661</td>\n",
" <td>0.355398</td>\n",
" <td>-0.815187</td>\n",
" <td>0.852955</td>\n",
" <td>-0.095267</td>\n",
" <td>1.825468</td>\n",
" <td>1.179857</td>\n",
" <td>0.957326</td>\n",
" <td>133.703018</td>\n",
" <td>-94.815270</td>\n",
" </tr>\n",
" <tr>\n",
" <th>249</th>\n",
" <td>661.0</td>\n",
" <td>4</td>\n",
" <td>1072.595093</td>\n",
" <td>481.461945</td>\n",
" <td>118.452148</td>\n",
" <td>205.365173</td>\n",
" <td>10.486504</td>\n",
" <td>8.287758</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.057989</td>\n",
" <td>0.417668</td>\n",
" <td>-0.139173</td>\n",
" <td>1.002404</td>\n",
" <td>1.622435</td>\n",
" <td>0.358678</td>\n",
" <td>1.012019</td>\n",
" <td>-0.402811</td>\n",
" <td>97.904355</td>\n",
" <td>-85.916792</td>\n",
" </tr>\n",
" <tr>\n",
" <th>254</th>\n",
" <td>666.0</td>\n",
" <td>4</td>\n",
" <td>1086.627930</td>\n",
" <td>526.733154</td>\n",
" <td>105.444458</td>\n",
" <td>189.750610</td>\n",
" <td>10.498393</td>\n",
" <td>8.684043</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>0.011889</td>\n",
" <td>0.396285</td>\n",
" <td>0.028534</td>\n",
" <td>0.951083</td>\n",
" <td>0.402496</td>\n",
" <td>-0.123170</td>\n",
" <td>0.951511</td>\n",
" <td>-0.145220</td>\n",
" <td>88.281546</td>\n",
" <td>-23.094741</td>\n",
" </tr>\n",
" <tr>\n",
" <th>259</th>\n",
" <td>671.0</td>\n",
" <td>4</td>\n",
" <td>1099.592285</td>\n",
" <td>584.216675</td>\n",
" <td>114.395874</td>\n",
" <td>218.003479</td>\n",
" <td>10.492767</td>\n",
" <td>9.267106</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.005626</td>\n",
" <td>0.583063</td>\n",
" <td>-0.013502</td>\n",
" <td>1.399352</td>\n",
" <td>-0.100887</td>\n",
" <td>1.075845</td>\n",
" <td>1.399417</td>\n",
" <td>1.074975</td>\n",
" <td>90.552815</td>\n",
" <td>5.451045</td>\n",
" </tr>\n",
" <tr>\n",
" <th>264</th>\n",
" <td>676.0</td>\n",
" <td>4</td>\n",
" <td>1144.484782</td>\n",
" <td>642.779582</td>\n",
" <td>96.750326</td>\n",
" <td>180.744690</td>\n",
" <td>10.484691</td>\n",
" <td>9.582745</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-0.008077</td>\n",
" <td>0.315639</td>\n",
" <td>-0.019384</td>\n",
" <td>0.757534</td>\n",
" <td>-0.014116</td>\n",
" <td>-1.540364</td>\n",
" <td>0.757782</td>\n",
" <td>-1.539925</td>\n",
" <td>91.465753</td>\n",
" <td>2.191052</td>\n",
" </tr>\n",
" <tr>\n",
" <th>269</th>\n",
" <td>681.0</td>\n",
" <td>4</td>\n",
" <td>1179.532959</td>\n",
" <td>682.365540</td>\n",
" <td>107.764282</td>\n",
" <td>200.651733</td>\n",
" <td>10.698373</td>\n",
" <td>9.950516</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>0.213682</td>\n",
" <td>0.367771</td>\n",
" <td>0.512837</td>\n",
" <td>0.882650</td>\n",
" <td>1.277331</td>\n",
" <td>0.300278</td>\n",
" <td>1.020820</td>\n",
" <td>0.631291</td>\n",
" <td>59.842534</td>\n",
" <td>-75.895726</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>16 rows × 24 columns</p>\n",
"</div>"
],
"text/plain": [
" frame_id track_id l t w h \\\n",
"194 606.0 4 1593.885864 782.814819 145.704346 195.380432 \n",
"199 611.0 4 1563.890015 700.710510 137.461304 190.194855 \n",
"204 616.0 4 1529.469727 635.622498 129.342651 194.191528 \n",
"209 621.0 4 1474.449341 569.387634 128.099854 199.766357 \n",
"214 626.0 4 1443.123535 518.907043 120.022461 202.566772 \n",
"219 631.0 4 1398.944946 461.813049 106.391357 193.476410 \n",
"224 636.0 4 1353.237793 438.118896 91.444336 170.930664 \n",
"229 641.0 4 1272.791992 408.827759 104.274536 180.414551 \n",
"234 646.0 4 1198.965820 407.952759 103.282104 167.306580 \n",
"239 651.0 4 1156.309570 415.743408 97.628784 158.774811 \n",
"244 656.0 4 1094.440430 443.849915 107.938110 177.703979 \n",
"249 661.0 4 1072.595093 481.461945 118.452148 205.365173 \n",
"254 666.0 4 1086.627930 526.733154 105.444458 189.750610 \n",
"259 671.0 4 1099.592285 584.216675 114.395874 218.003479 \n",
"264 676.0 4 1144.484782 642.779582 96.750326 180.744690 \n",
"269 681.0 4 1179.532959 682.365540 107.764282 200.651733 \n",
"\n",
" x y state diff ... dx dy vx \\\n",
"194 12.897830 10.750061 2.0 NaN ... 0.201965 -0.291350 0.484716 \n",
"199 13.099794 10.458712 1.0 5.0 ... 0.201965 -0.291350 0.484716 \n",
"204 13.020002 9.866642 2.0 5.0 ... -0.079792 -0.592069 -0.191501 \n",
"209 12.965776 9.301442 2.0 5.0 ... -0.054226 -0.565200 -0.130143 \n",
"214 12.642992 8.976624 2.0 5.0 ... -0.322784 -0.324818 -0.774681 \n",
"219 12.465588 8.557788 2.0 5.0 ... -0.177404 -0.418836 -0.425771 \n",
"224 12.128433 8.052323 2.0 5.0 ... -0.337155 -0.505465 -0.809172 \n",
"229 11.689648 7.684636 2.0 5.0 ... -0.438785 -0.367687 -1.053084 \n",
"234 11.207276 7.476216 2.0 5.0 ... -0.482372 -0.208420 -1.157693 \n",
"239 10.884154 7.514692 2.0 5.0 ... -0.323122 0.038476 -0.775493 \n",
"244 10.544492 7.870090 2.0 5.0 ... -0.339661 0.355398 -0.815187 \n",
"249 10.486504 8.287758 2.0 5.0 ... -0.057989 0.417668 -0.139173 \n",
"254 10.498393 8.684043 2.0 5.0 ... 0.011889 0.396285 0.028534 \n",
"259 10.492767 9.267106 2.0 5.0 ... -0.005626 0.583063 -0.013502 \n",
"264 10.484691 9.582745 1.0 5.0 ... -0.008077 0.315639 -0.019384 \n",
"269 10.698373 9.950516 2.0 5.0 ... 0.213682 0.367771 0.512837 \n",
"\n",
" vy ax ay v a heading d_heading \n",
"194 -0.699240 -1.622919 -1.732144 0.850815 1.399195 304.729842 -101.772559 \n",
"199 -0.699240 -1.622919 -1.732144 0.850815 1.399195 304.729842 -101.772559 \n",
"204 -1.420966 -1.622919 -1.732144 1.433812 1.399195 262.324609 -101.772559 \n",
"209 -1.356479 0.147259 0.154769 1.362708 -0.170650 264.519715 5.268254 \n",
"214 -0.779564 -1.546892 1.384597 1.099023 -0.632844 225.179993 -94.415332 \n",
"219 -1.005205 0.837386 -0.541539 1.091659 -0.017675 247.044148 52.473972 \n",
"224 -1.213117 -0.920163 -0.498987 1.458222 0.879752 236.295957 -25.795658 \n",
"229 -0.882448 -0.585388 0.793604 1.373936 -0.202286 219.961870 -39.201809 \n",
"234 -0.500209 -0.251064 0.917374 1.261136 -0.270721 203.367915 -39.825493 \n",
"239 0.092343 0.917282 1.422125 0.780971 -1.152395 173.209381 -72.380481 \n",
"244 0.852955 -0.095267 1.825468 1.179857 0.957326 133.703018 -94.815270 \n",
"249 1.002404 1.622435 0.358678 1.012019 -0.402811 97.904355 -85.916792 \n",
"254 0.951083 0.402496 -0.123170 0.951511 -0.145220 88.281546 -23.094741 \n",
"259 1.399352 -0.100887 1.075845 1.399417 1.074975 90.552815 5.451045 \n",
"264 0.757534 -0.014116 -1.540364 0.757782 -1.539925 91.465753 2.191052 \n",
"269 0.882650 1.277331 0.300278 1.020820 0.631291 59.842534 -75.895726 \n",
"\n",
"[16 rows x 24 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" </tr>\n",
" <tr>\n",
" <th>track_id</th>\n",
" <th>frame_id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">1</th>\n",
" <th>342</th>\n",
" <td>1393.736572</td>\n",
" <td>0.000000</td>\n",
" <td>67.613647</td>\n",
" <td>121.391151</td>\n",
" <td>1363.3164</td>\n",
" <td>232.92647</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>343</th>\n",
" <td>1391.775879</td>\n",
" <td>0.852371</td>\n",
" <td>78.562622</td>\n",
" <td>141.050934</td>\n",
" <td>1359.1885</td>\n",
" <td>266.06586</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>1392.164551</td>\n",
" <td>7.758987</td>\n",
" <td>85.757324</td>\n",
" <td>154.357971</td>\n",
" <td>1355.7444</td>\n",
" <td>297.67404</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>347</th>\n",
" <td>1393.844849</td>\n",
" <td>12.691238</td>\n",
" <td>86.482910</td>\n",
" <td>156.264786</td>\n",
" <td>1355.2312</td>\n",
" <td>308.20670</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>348</th>\n",
" <td>1394.839111</td>\n",
" <td>15.621338</td>\n",
" <td>84.763428</td>\n",
" <td>154.584396</td>\n",
" <td>1354.9246</td>\n",
" <td>310.09225</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">5030</th>\n",
" <th>32691</th>\n",
" <td>1708.213379</td>\n",
" <td>749.260376</td>\n",
" <td>133.839966</td>\n",
" <td>182.405396</td>\n",
" <td>1402.5426</td>\n",
" <td>1075.20870</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32692</th>\n",
" <td>1707.651855</td>\n",
" <td>748.997437</td>\n",
" <td>134.013672</td>\n",
" <td>182.391296</td>\n",
" <td>1402.2948</td>\n",
" <td>1074.97230</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32720</th>\n",
" <td>1700.379639</td>\n",
" <td>750.314697</td>\n",
" <td>128.792603</td>\n",
" <td>181.589783</td>\n",
" <td>1395.7992</td>\n",
" <td>1074.27320</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32721</th>\n",
" <td>1701.722412</td>\n",
" <td>751.000488</td>\n",
" <td>125.286865</td>\n",
" <td>180.867615</td>\n",
" <td>1395.5424</td>\n",
" <td>1074.20560</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32722</th>\n",
" <td>1702.384766</td>\n",
" <td>750.754517</td>\n",
" <td>123.435425</td>\n",
" <td>180.945618</td>\n",
" <td>1395.4082</td>\n",
" <td>1074.06500</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>326960 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" l t w h x \\\n",
"track_id frame_id \n",
"1 342 1393.736572 0.000000 67.613647 121.391151 1363.3164 \n",
" 343 1391.775879 0.852371 78.562622 141.050934 1359.1885 \n",
" 346 1392.164551 7.758987 85.757324 154.357971 1355.7444 \n",
" 347 1393.844849 12.691238 86.482910 156.264786 1355.2312 \n",
" 348 1394.839111 15.621338 84.763428 154.584396 1354.9246 \n",
"... ... ... ... ... ... \n",
"5030 32691 1708.213379 749.260376 133.839966 182.405396 1402.5426 \n",
" 32692 1707.651855 748.997437 134.013672 182.391296 1402.2948 \n",
" 32720 1700.379639 750.314697 128.792603 181.589783 1395.7992 \n",
" 32721 1701.722412 751.000488 125.286865 180.867615 1395.5424 \n",
" 32722 1702.384766 750.754517 123.435425 180.945618 1395.4082 \n",
"\n",
" y state \n",
"track_id frame_id \n",
"1 342 232.92647 2 \n",
" 343 266.06586 2 \n",
" 346 297.67404 2 \n",
" 347 308.20670 2 \n",
" 348 310.09225 2 \n",
"... ... ... \n",
"5030 32691 1075.20870 2 \n",
" 32692 1074.97230 2 \n",
" 32720 1074.27320 2 \n",
" 32721 1074.20560 2 \n",
" 32722 1074.06500 2 \n",
"\n",
"[326960 rows x 7 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the tracker output: tab-separated, no header row, nine columns.\n",
"track_columns = ['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state']\n",
"data = pd.read_csv(SRC_CSV, delimiter=\"\\t\", index_col=False, header=None)\n",
"data.columns = track_columns\n",
"\n",
"# Downcast the frame ids to the smallest fitting integer dtype, then floor-divide by 10\n",
"# (compatibility with Trajectron++).\n",
"data['frame_id'] = pd.to_numeric(data['frame_id'], downcast='integer') // 10\n",
"\n",
"data.sort_values(by=['track_id', 'frame_id'], inplace=True)\n",
"\n",
"# Display only: set_index returns a new frame, so `data` keeps its flat columns.\n",
"data.set_index(['track_id', 'frame_id'])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# cm to meter\n",
"data['x'] = data['x']/100\n",
"data['y'] = data['y']/100"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Frame-id difference to the previous row of the same track; the first row of every\n",
"# track has no predecessor and stays NaN (no fillna is applied).\n",
"frame_gap = data.groupby(['track_id'])['frame_id'].diff()\n",
"data['diff'] = pd.to_numeric(frame_gap, downcast='integer')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"326960it [06:37, 821.55it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"was: 326960 added: 85138 new length: 412098\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Add one placeholder row for every frame that is missing inside a track.\n",
"# A row whose 'diff' (frame gap to the previous row of the same track) is g > 1 gets g - 1\n",
"# placeholders at frame_id - 1 ... frame_id - (g - 1). Their l/t/w/h/x/y values are NaN,\n",
"# and state and diff are set to 1.\n",
"# The placeholders are collected first and appended with a single concat. Enlarging the\n",
"# frame with data.loc[len(data)] once per missing frame took over six minutes here.\n",
"old_size = len(data)\n",
"gaps = data.loc[data['diff'] > 1, ['frame_id', 'track_id', 'diff']]\n",
"gap_rows = [\n",
"    [frame_id - step, track_id, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 1, 1]\n",
"    for frame_id, track_id, gap in gaps.itertuples(index=False, name=None)\n",
"    for step in range(1, int(gap))\n",
"]\n",
"missing = len(gap_rows)\n",
"\n",
"if gap_rows:\n",
"    # Like the row-wise version, this expects the ten columns frame_id, track_id, l, t, w,\n",
"    # h, x, y, state, diff and index labels 0 .. old_size - 1. float64 placeholders keep\n",
"    # the float upcast of the id/state columns that the later cell output shows.\n",
"    filler = pd.DataFrame(\n",
"        gap_rows,\n",
"        columns=data.columns,\n",
"        index=range(old_size, old_size + missing),\n",
"        dtype='float64',\n",
"    )\n",
"    data = pd.concat([data, filler])\n",
"\n",
"print('was:', old_size, 'added:', missing, 'new length:', len(data))\n",
"# now sort, so that the added data is in the right place\n",
"data.sort_values(by=['track_id', 'frame_id'], inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# interpolate missing data\n",
"df=data.copy()\n",
"df = df.groupby('track_id').apply(lambda group: group.interpolate(method='linear'))\n",
"df.reset_index(drop=True, inplace=True)\n",
"data = df\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running smoother\n"
]
}
],
"source": [
"from trap.tracker import Smoother\n",
"\n",
"if SMOOTHING:\n",
"    df = data.copy()\n",
"    # Keep the unsmoothed coordinates; an existing *_raw column is never overwritten, so\n",
"    # smoothing always starts from the raw values.\n",
"    for axis in ('x', 'y'):\n",
"        if f'{axis}_raw' not in df:\n",
"            df[f'{axis}_raw'] = df[axis]\n",
"\n",
"    print(\"Running smoother\")\n",
"    # NOTE: SMOOTHING_WINDOW is not used in this cell; it belonged to the disabled\n",
"    # tsmoothie ConvolutionSmoother variant.\n",
"    smoother = Smoother(convolution=False)\n",
"\n",
"    def smoothing(data):\n",
"        \"\"\"Return smoother.smooth(data) as a plain list, for groupby().transform().\"\"\"\n",
"        return smoother.smooth(data).tolist()\n",
"\n",
"    # operate smoothing per axis, one track at a time\n",
"    for axis in ('x', 'y'):\n",
"        df[axis] = df.groupby('track_id')[f'{axis}_raw'].transform(smoothing)\n",
"\n",
"    data = df\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>track_id</th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>x_raw</th>\n",
" <th>y_raw</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>565.566162</td>\n",
" <td>88.795326</td>\n",
" <td>173.917542</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>0.881595</td>\n",
" <td>7.341152</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>10.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>565.116699</td>\n",
" <td>88.801704</td>\n",
" <td>171.334290</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.870703</td>\n",
" <td>7.309168</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>11.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>564.874573</td>\n",
" <td>90.596596</td>\n",
" <td>177.199951</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.901374</td>\n",
" <td>7.370044</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>12.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>564.874268</td>\n",
" <td>90.928131</td>\n",
" <td>183.125732</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.924360</td>\n",
" <td>7.432365</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>13.0</td>\n",
" <td>1.0</td>\n",
" <td>0.000000</td>\n",
" <td>569.931213</td>\n",
" <td>86.213280</td>\n",
" <td>180.774292</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.906583</td>\n",
" <td>7.456334</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320183</th>\n",
" <td>60159.0</td>\n",
" <td>3632.0</td>\n",
" <td>1830.709717</td>\n",
" <td>651.257446</td>\n",
" <td>150.202515</td>\n",
" <td>157.239746</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>15.214551</td>\n",
" <td>10.027093</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320184</th>\n",
" <td>60160.0</td>\n",
" <td>3632.0</td>\n",
" <td>1834.013672</td>\n",
" <td>649.612122</td>\n",
" <td>153.686646</td>\n",
" <td>160.874023</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.244872</td>\n",
" <td>10.047117</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320185</th>\n",
" <td>60161.0</td>\n",
" <td>3632.0</td>\n",
" <td>1845.373047</td>\n",
" <td>651.249756</td>\n",
" <td>147.178589</td>\n",
" <td>153.729248</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.318496</td>\n",
" <td>10.015218</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320186</th>\n",
" <td>60162.0</td>\n",
" <td>3632.0</td>\n",
" <td>1857.388916</td>\n",
" <td>650.908203</td>\n",
" <td>136.407349</td>\n",
" <td>142.354614</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.400203</td>\n",
" <td>9.935355</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320187</th>\n",
" <td>60163.0</td>\n",
" <td>3632.0</td>\n",
" <td>1862.792725</td>\n",
" <td>658.719971</td>\n",
" <td>141.984253</td>\n",
" <td>149.052307</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.416893</td>\n",
" <td>10.051785</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>320188 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" frame_id track_id l t w h \\\n",
"0 9.0 1.0 0.000000 565.566162 88.795326 173.917542 \n",
"1 10.0 1.0 0.000000 565.116699 88.801704 171.334290 \n",
"2 11.0 1.0 0.000000 564.874573 90.596596 177.199951 \n",
"3 12.0 1.0 0.000000 564.874268 90.928131 183.125732 \n",
"4 13.0 1.0 0.000000 569.931213 86.213280 180.774292 \n",
"... ... ... ... ... ... ... \n",
"320183 60159.0 3632.0 1830.709717 651.257446 150.202515 157.239746 \n",
"320184 60160.0 3632.0 1834.013672 649.612122 153.686646 160.874023 \n",
"320185 60161.0 3632.0 1845.373047 651.249756 147.178589 153.729248 \n",
"320186 60162.0 3632.0 1857.388916 650.908203 136.407349 142.354614 \n",
"320187 60163.0 3632.0 1862.792725 658.719971 141.984253 149.052307 \n",
"\n",
" x y state diff x_raw y_raw \n",
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 \n",
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 \n",
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 \n",
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 \n",
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 \n",
"... ... ... ... ... ... ... \n",
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 \n",
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 \n",
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 \n",
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 \n",
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 \n",
"\n",
"[320188 rows x 12 columns]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Recalculate the per-track frame gap, replacing the earlier 'diff' column. The first row\n",
"# of every track has no predecessor and is NaN.\n",
"data['diff'] = data.groupby('track_id')['frame_id'].diff()\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>track_id</th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>x_raw</th>\n",
" <th>y_raw</th>\n",
" <th>dt</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.566162</td>\n",
" <td>88.795326</td>\n",
" <td>173.917542</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>0.881595</td>\n",
" <td>7.341152</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>10.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.116699</td>\n",
" <td>88.801704</td>\n",
" <td>171.334290</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.870703</td>\n",
" <td>7.309168</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>11.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874573</td>\n",
" <td>90.596596</td>\n",
" <td>177.199951</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.901374</td>\n",
" <td>7.370044</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>12.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874268</td>\n",
" <td>90.928131</td>\n",
" <td>183.125732</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.924360</td>\n",
" <td>7.432365</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>13.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>569.931213</td>\n",
" <td>86.213280</td>\n",
" <td>180.774292</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.906583</td>\n",
" <td>7.456334</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320183</th>\n",
" <td>60159.0</td>\n",
" <td>3632</td>\n",
" <td>1830.709717</td>\n",
" <td>651.257446</td>\n",
" <td>150.202515</td>\n",
" <td>157.239746</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>15.214551</td>\n",
" <td>10.027093</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320184</th>\n",
" <td>60160.0</td>\n",
" <td>3632</td>\n",
" <td>1834.013672</td>\n",
" <td>649.612122</td>\n",
" <td>153.686646</td>\n",
" <td>160.874023</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.244872</td>\n",
" <td>10.047117</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320185</th>\n",
" <td>60161.0</td>\n",
" <td>3632</td>\n",
" <td>1845.373047</td>\n",
" <td>651.249756</td>\n",
" <td>147.178589</td>\n",
" <td>153.729248</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.318496</td>\n",
" <td>10.015218</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320186</th>\n",
" <td>60162.0</td>\n",
" <td>3632</td>\n",
" <td>1857.388916</td>\n",
" <td>650.908203</td>\n",
" <td>136.407349</td>\n",
" <td>142.354614</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.400203</td>\n",
" <td>9.935355</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320187</th>\n",
" <td>60163.0</td>\n",
" <td>3632</td>\n",
" <td>1862.792725</td>\n",
" <td>658.719971</td>\n",
" <td>141.984253</td>\n",
" <td>149.052307</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.416893</td>\n",
" <td>10.051785</td>\n",
" <td>0.083333</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>320188 rows × 13 columns</p>\n",
"</div>"
],
"text/plain": [
" frame_id track_id l t w h \\\n",
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
"... ... ... ... ... ... ... \n",
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
"\n",
" x y state diff x_raw y_raw dt \n",
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 NaN \n",
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 0.083333 \n",
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 0.083333 \n",
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 0.083333 \n",
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 0.083333 \n",
"... ... ... ... ... ... ... ... \n",
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 NaN \n",
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 0.083333 \n",
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 0.083333 \n",
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 0.083333 \n",
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 0.083333 \n",
"\n",
"[320188 rows x 13 columns]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# data['node_type'] = 'PEDESTRIAN' # compatibility with Trajectron++\n",
"# data['node_id'] = data['track_id'].astype(str)\n",
"data['track_id'] = pd.to_numeric(data['track_id'], downcast='integer')\n",
"\n",
"\n",
"data['dt'] = data['diff'] * (1/FPS)\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# position into an average coordinate system. (DO THESE NEED TO BE STORED?)\n",
"# Don't do this, messes up\n",
"# data['pos_x'] = data['pos_x'] - data['pos_x'].mean()\n",
"# data['pos_y'] = data['pos_y'] - data['pos_y'].mean()\n",
" \n",
"# data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAGdCAYAAAD+JxxnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxGUlEQVR4nO3dfVTUdd7/8RcgDFIO3gXIJRplpZRK4UbTrTfIqJxObm5rN8fIvDm60FnkXFqU4V277uXmXYVx2lLaU5a6p9xSL2TCVStHTZQrtfTqxi53Tw12o6KYwwjz+6Mf35wwZVydCT7Pxzmcs/P9vuc773nzYX31ne8XIvx+v18AAAAGigx3AwAAAOFCEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGKtduBv4JWtsbNSXX36pDh06KCIiItztAACAFvD7/Tp27JiSk5MVGXn2cz4EobP48ssvlZKSEu42AADAefjnP/+p7t27n7WGIHQWHTp0kPTDIO12e5i7CT+fz6eKigplZ2crOjo63O20Wcw5NJhz6DDr0GDOP6qtrVVKSor17/jZEITOounjMLvdThDSDz9kcXFxstvtxv+QXUzMOTSYc+gw69Bgzs215LIWLpYGAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMFa7cDcAAOF03cz18jZEhLuNFvviTznhbgFoUzgjBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADBWUEHo+eefV79+/WS322W32+VwOPTf//3f1v6TJ08qLy9PXbp00aWXXqpRo0appqYm4BgHDx5UTk6O4uLilJCQoKlTp+rUqVMBNRs3btQNN9wgm82mXr16qaysrFkvJSUluvzyyxUbG6vMzExt3749YH9LegEAAGYLKgh1795df/rTn1RVVaUdO3Zo8ODBuuuuu7R3715J0pQpU/T2229r1apV2rRpk7788kvdfffd1vMbGhqUk5Oj+vp6bdmyRS+//LLKyspUXFxs1Rw4cEA5OTkaNGiQqqurVVBQoPHjx2v9+vVWzYoVK1RYWKgZM2Zo586d6t+/v5xOpw4dOmTVnKsXAACAoILQnXfeqREjRuiqq67S1VdfrT/84Q+69NJLtXXrVh09elQvvfSSFixYoMGDBysjI0PLli3Tli1btHXrVklSRUWFPvroI73yyitKT0/X8OHDNWfOHJWUlKi+vl6SVFpaqtTUVM2fP199+vRRfn6+fvOb32jhwoVWHwsWLNCECRM0duxYpaWlqbS0VHFxcVq6dKkktagXAACAduf7xIaGBq1atUp1dXVyOByqqqqSz+dTVlaWVdO7d2/16NFDbrdbN910k9xut/r27avExESrxul0avLkydq7d6+uv/56ud3ugGM01RQUFEiS6uvrVVVVpaKiImt/ZGSksrKy5Ha7JalFvZyJ1+uV1+u1HtfW1kqSfD6ffD7feU6q7WiaAbO4uJhzaDTN1xbpD3MnwWmN64I1HRrM+UfBzCDoILR79245HA6dPHlSl156qd58802lpaWpurpaMTEx6tixY0B9YmKiPB6PJMnj8QSEoKb9TfvOVlNbW6vvv/9ehw8fVkNDwxlr9u3bZx3jXL2cydy5czVr1qxm2ysqKhQXF/ezzzONy+UKdwtGYM6hMWdAY7hbCMq6devC3cJ5Y02HBnOWTpw40eLaoIPQNddco+rqah09elR/+9vflJubq02bNgV7mF+koqIiFRYWWo9ra2uVkpKi7Oxs2e32MHb2y+Dz+eRyuTR06FBFR0eHu502izmHRtOcn9wRKW9jRLjbabE9M53hbiForO
nQYM4/avpEpyWCDkIxMTHq1auXJCkjI0MffPCBFi9erNGjR6u+vl5HjhwJOBNTU1OjpKQkSVJSUlKzu7ua7uQ6veand3fV1NTIbrerffv2ioqKUlRU1BlrTj/GuXo5E5vNJpvN1mx7dHS08YvqdMwjNJhzaHgbI+RtaD1BqDWvCdZ0aDDn4H5O/u3fI9TY2Civ16uMjAxFR0ersrLS2rd//34dPHhQDodDkuRwOLR79+6Au7tcLpfsdrvS0tKsmtOP0VTTdIyYmBhlZGQE1DQ2NqqystKqaUkvAAAAQZ0RKioq0vDhw9WjRw8dO3ZMy5cv18aNG7V+/XrFx8dr3LhxKiwsVOfOnWW32/XII4/I4XBYFydnZ2crLS1NY8aM0bx58+TxeDR9+nTl5eVZZ2ImTZqk5557TtOmTdPDDz+sDRs2aOXKlVq7dq3VR2FhoXJzczVgwADdeOONWrRokerq6jR27FhJalEvAAAAQQWhQ4cO6cEHH9RXX32l+Ph49evXT+vXr9fQoUMlSQsXLlRkZKRGjRolr9crp9OpJUuWWM+PiorSmjVrNHnyZDkcDl1yySXKzc3V7NmzrZrU1FStXbtWU6ZM0eLFi9W9e3e9+OKLcjp//Fx89OjR+vrrr1VcXCyPx6P09HSVl5cHXEB9rl4AAAAi/H5/67p3NIRqa2sVHx+vo0ePcrG0frgQb926dRoxYoTxnz9fTMw5NJrmPG17VKu6RuiLP+WEu4WgsaZDgzn/KJh/v/lbYwAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABgrqCA0d+5c/epXv1KHDh2UkJCgkSNHav/+/QE1AwcOVERERMDXpEmTAmoOHjyonJwcxcXFKSEhQVOnTtWpU6cCajZu3KgbbrhBNptNvXr1UllZWbN+SkpKdPnllys2NlaZmZnavn17wP6TJ08qLy9PXbp00aWXXqpRo0appqYmmLcMAADasKCC0KZNm5SXl6etW7fK5XLJ5/MpOztbdXV1AXUTJkzQV199ZX3NmzfP2tfQ0KCcnBzV19dry5Ytevnll1VWVqbi4mKr5sCBA8rJydGgQYNUXV2tgoICjR8/XuvXr7dqVqxYocLCQs2YMUM7d+5U//795XQ6dejQIatmypQpevvtt7Vq1Spt2rRJX375pe6+++6ghwQAANqmdsEUl5eXBzwuKytTQkKCqqqqdPvtt1vb4+LilJSUdMZjVFRU6KOPPtI777yjxMREpaena86cOXr00Uc1c+ZMxcTEqLS0VKmpqZo/f74kqU+fPnrvvfe0cOFCOZ1OSdKCBQs0YcIEjR07VpJUWlqqtWvXaunSpXrsscd09OhRvfTSS1q+fLkGDx4sSVq2bJn69OmjrVu36qabbgrmrQMAgDYoqCD0U0ePHpUkde7cOWD7q6++qldeeUVJSUm688479eSTTyouLk6S5Ha71bdvXyUmJlr1TqdTkydP1t69e3X99dfL7XYrKysr4JhOp1MFBQWSpPr6elVVVamoqMjaHxkZqaysLLndbklSVVWVfD5fwHF69+6tHj16yO12nzEIeb1eeb1e63Ftba0kyefzyefzBT2ftqZpBszi4mLOodE0X1ukP8ydBKc1rgvWdGgw5x8FM4PzDkKNjY0qKCjQLbfcouuuu87afv/996tnz55KTk7Whx9+qEcffVT79+/XG2
+8IUnyeDwBIUiS9djj8Zy1pra2Vt9//70OHz6shoaGM9bs27fPOkZMTIw6duzYrKbpdX5q7ty5mjVrVrPtFRUVVpC
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data['diff'].hist()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset is a bit crappy because it has different frame step: ranging from predominantly 1 and 2 to sometimes have 3 and 4 as well. This inevitabily leads to difference in speed caluclations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if SRC_H is not None:\n",
" H = np.loadtxt(SRC_H, delimiter=',')\n",
"else:\n",
" H = None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No H given, probably already projected data?\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>track_id</th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>x_raw</th>\n",
" <th>y_raw</th>\n",
" <th>dt</th>\n",
" <th>proj_x</th>\n",
" <th>proj_y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.566162</td>\n",
" <td>88.795326</td>\n",
" <td>173.917542</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>0.881595</td>\n",
" <td>7.341152</td>\n",
" <td>NaN</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>10.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.116699</td>\n",
" <td>88.801704</td>\n",
" <td>171.334290</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.870703</td>\n",
" <td>7.309168</td>\n",
" <td>0.083333</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>11.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874573</td>\n",
" <td>90.596596</td>\n",
" <td>177.199951</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.901374</td>\n",
" <td>7.370044</td>\n",
" <td>0.083333</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>12.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874268</td>\n",
" <td>90.928131</td>\n",
" <td>183.125732</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.924360</td>\n",
" <td>7.432365</td>\n",
" <td>0.083333</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>13.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>569.931213</td>\n",
" <td>86.213280</td>\n",
" <td>180.774292</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.906583</td>\n",
" <td>7.456334</td>\n",
" <td>0.083333</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320183</th>\n",
" <td>60159.0</td>\n",
" <td>3632</td>\n",
" <td>1830.709717</td>\n",
" <td>651.257446</td>\n",
" <td>150.202515</td>\n",
" <td>157.239746</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>15.214551</td>\n",
" <td>10.027093</td>\n",
" <td>NaN</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320184</th>\n",
" <td>60160.0</td>\n",
" <td>3632</td>\n",
" <td>1834.013672</td>\n",
" <td>649.612122</td>\n",
" <td>153.686646</td>\n",
" <td>160.874023</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.244872</td>\n",
" <td>10.047117</td>\n",
" <td>0.083333</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320185</th>\n",
" <td>60161.0</td>\n",
" <td>3632</td>\n",
" <td>1845.373047</td>\n",
" <td>651.249756</td>\n",
" <td>147.178589</td>\n",
" <td>153.729248</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.318496</td>\n",
" <td>10.015218</td>\n",
" <td>0.083333</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320186</th>\n",
" <td>60162.0</td>\n",
" <td>3632</td>\n",
" <td>1857.388916</td>\n",
" <td>650.908203</td>\n",
" <td>136.407349</td>\n",
" <td>142.354614</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.400203</td>\n",
" <td>9.935355</td>\n",
" <td>0.083333</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320187</th>\n",
" <td>60163.0</td>\n",
" <td>3632</td>\n",
" <td>1862.792725</td>\n",
" <td>658.719971</td>\n",
" <td>141.984253</td>\n",
" <td>149.052307</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>15.416893</td>\n",
" <td>10.051785</td>\n",
" <td>0.083333</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>320188 rows × 15 columns</p>\n",
"</div>"
],
"text/plain": [
" frame_id track_id l t w h \\\n",
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
"... ... ... ... ... ... ... \n",
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
"\n",
" x y state diff x_raw y_raw dt \\\n",
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 NaN \n",
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 0.083333 \n",
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 0.083333 \n",
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 0.083333 \n",
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 0.083333 \n",
"... ... ... ... ... ... ... ... \n",
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 NaN \n",
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 0.083333 \n",
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 0.083333 \n",
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 0.083333 \n",
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 0.083333 \n",
"\n",
" proj_x proj_y \n",
"0 0.855100 7.136193 \n",
"1 0.873132 7.235233 \n",
"2 0.890957 7.328989 \n",
"3 0.907784 7.418187 \n",
"4 0.923439 7.505012 \n",
"... ... ... \n",
"320183 14.840476 9.786501 \n",
"320184 15.033432 9.870472 \n",
"320185 15.211560 9.943236 \n",
"320186 15.377673 10.008965 \n",
"320187 15.538255 10.075935 \n",
"\n",
"[320188 rows x 15 columns]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"if H is not None:\n",
" print(\"Projecting data\")\n",
" data['foot_x'] = data['pos_x'] + 0.5 * data['width']\n",
" data['foot_y'] = data['pos_y'] + 0.5 * data['height']\n",
" \n",
" transformed = cv2.perspectiveTransform(np.array([data[['foot_x','foot_y']].to_numpy()]),H)[0]\n",
" data['proj_x'], data['proj_y'] = transformed[:,0], transformed[:,1]\n",
" data['proj_x'] = data['proj_x'].div(100) # cm to m\n",
" data['proj_y'] = data['proj_y'].div(100) # cm to m\n",
" # and shift to mean (THES NEED TO BE STORED AND REUSED IN LIVE SETTING)\n",
" mean_x = data['proj_x'].mean()\n",
" mean_y = data['proj_y'].mean()\n",
" data['proj_x'] = data['proj_x'] - data['proj_x'].mean()\n",
" data['proj_y'] = data['proj_y'] - data['proj_y'].mean()\n",
"else:\n",
" print(\"No H given, probably already projected data?\")\n",
" mean_x = 0\n",
" mean_y = 0\n",
" data['proj_x'] = data['x']\n",
" data['proj_y'] = data['y']\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Deriving displacement, velocity and accelation from x and y\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>track_id</th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>...</th>\n",
" <th>y_raw</th>\n",
" <th>dt</th>\n",
" <th>proj_x</th>\n",
" <th>proj_y</th>\n",
" <th>dx</th>\n",
" <th>dy</th>\n",
" <th>vx</th>\n",
" <th>vy</th>\n",
" <th>ax</th>\n",
" <th>ay</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.566162</td>\n",
" <td>88.795326</td>\n",
" <td>173.917542</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>7.341152</td>\n",
" <td>NaN</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>10.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.116699</td>\n",
" <td>88.801704</td>\n",
" <td>171.334290</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>7.309168</td>\n",
" <td>0.083333</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" <td>0.018032</td>\n",
" <td>0.099039</td>\n",
" <td>0.216383</td>\n",
" <td>1.188473</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>11.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874573</td>\n",
" <td>90.596596</td>\n",
" <td>177.199951</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>7.370044</td>\n",
" <td>0.083333</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" <td>0.017825</td>\n",
" <td>0.093756</td>\n",
" <td>0.213899</td>\n",
" <td>1.125077</td>\n",
" <td>-0.029812</td>\n",
" <td>-0.760753</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>12.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874268</td>\n",
" <td>90.928131</td>\n",
" <td>183.125732</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>7.432365</td>\n",
" <td>0.083333</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" <td>0.016827</td>\n",
" <td>0.089198</td>\n",
" <td>0.201924</td>\n",
" <td>1.070371</td>\n",
" <td>-0.143699</td>\n",
" <td>-0.656466</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>13.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>569.931213</td>\n",
" <td>86.213280</td>\n",
" <td>180.774292</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>7.456334</td>\n",
" <td>0.083333</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" <td>0.015655</td>\n",
" <td>0.086825</td>\n",
" <td>0.187865</td>\n",
" <td>1.041902</td>\n",
" <td>-0.168701</td>\n",
" <td>-0.341637</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320183</th>\n",
" <td>60159.0</td>\n",
" <td>3632</td>\n",
" <td>1830.709717</td>\n",
" <td>651.257446</td>\n",
" <td>150.202515</td>\n",
" <td>157.239746</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>10.027093</td>\n",
" <td>NaN</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320184</th>\n",
" <td>60160.0</td>\n",
" <td>3632</td>\n",
" <td>1834.013672</td>\n",
" <td>649.612122</td>\n",
" <td>153.686646</td>\n",
" <td>160.874023</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>10.047117</td>\n",
" <td>0.083333</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" <td>0.192955</td>\n",
" <td>0.083971</td>\n",
" <td>2.315463</td>\n",
" <td>1.007656</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320185</th>\n",
" <td>60161.0</td>\n",
" <td>3632</td>\n",
" <td>1845.373047</td>\n",
" <td>651.249756</td>\n",
" <td>147.178589</td>\n",
" <td>153.729248</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>10.015218</td>\n",
" <td>0.083333</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" <td>0.178128</td>\n",
" <td>0.072764</td>\n",
" <td>2.137542</td>\n",
" <td>0.873173</td>\n",
" <td>-2.135059</td>\n",
" <td>-1.613797</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320186</th>\n",
" <td>60162.0</td>\n",
" <td>3632</td>\n",
" <td>1857.388916</td>\n",
" <td>650.908203</td>\n",
" <td>136.407349</td>\n",
" <td>142.354614</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>9.935355</td>\n",
" <td>0.083333</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" <td>0.166113</td>\n",
" <td>0.065728</td>\n",
" <td>1.993352</td>\n",
" <td>0.788742</td>\n",
" <td>-1.730279</td>\n",
" <td>-1.013172</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320187</th>\n",
" <td>60163.0</td>\n",
" <td>3632</td>\n",
" <td>1862.792725</td>\n",
" <td>658.719971</td>\n",
" <td>141.984253</td>\n",
" <td>149.052307</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>10.051785</td>\n",
" <td>0.083333</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" <td>0.160582</td>\n",
" <td>0.066971</td>\n",
" <td>1.926987</td>\n",
" <td>0.803649</td>\n",
" <td>-0.796376</td>\n",
" <td>0.178886</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>320188 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" frame_id track_id l t w h \\\n",
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
"... ... ... ... ... ... ... \n",
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
"\n",
" x y state diff ... y_raw dt \\\n",
"0 0.855100 7.136193 2.0 NaN ... 7.341152 NaN \n",
"1 0.873132 7.235233 2.0 1.0 ... 7.309168 0.083333 \n",
"2 0.890957 7.328989 2.0 1.0 ... 7.370044 0.083333 \n",
"3 0.907784 7.418187 2.0 1.0 ... 7.432365 0.083333 \n",
"4 0.923439 7.505012 2.0 1.0 ... 7.456334 0.083333 \n",
"... ... ... ... ... ... ... ... \n",
"320183 14.840476 9.786501 2.0 NaN ... 10.027093 NaN \n",
"320184 15.033432 9.870472 2.0 1.0 ... 10.047117 0.083333 \n",
"320185 15.211560 9.943236 2.0 1.0 ... 10.015218 0.083333 \n",
"320186 15.377673 10.008965 2.0 1.0 ... 9.935355 0.083333 \n",
"320187 15.538255 10.075935 2.0 1.0 ... 10.051785 0.083333 \n",
"\n",
" proj_x proj_y dx dy vx vy \\\n",
"0 0.855100 7.136193 NaN NaN NaN NaN \n",
"1 0.873132 7.235233 0.018032 0.099039 0.216383 1.188473 \n",
"2 0.890957 7.328989 0.017825 0.093756 0.213899 1.125077 \n",
"3 0.907784 7.418187 0.016827 0.089198 0.201924 1.070371 \n",
"4 0.923439 7.505012 0.015655 0.086825 0.187865 1.041902 \n",
"... ... ... ... ... ... ... \n",
"320183 14.840476 9.786501 NaN NaN NaN NaN \n",
"320184 15.033432 9.870472 0.192955 0.083971 2.315463 1.007656 \n",
"320185 15.211560 9.943236 0.178128 0.072764 2.137542 0.873173 \n",
"320186 15.377673 10.008965 0.166113 0.065728 1.993352 0.788742 \n",
"320187 15.538255 10.075935 0.160582 0.066971 1.926987 0.803649 \n",
"\n",
" ax ay \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 -0.029812 -0.760753 \n",
"3 -0.143699 -0.656466 \n",
"4 -0.168701 -0.341637 \n",
"... ... ... \n",
"320183 NaN NaN \n",
"320184 NaN NaN \n",
"320185 -2.135059 -1.613797 \n",
"320186 -1.730279 -1.013172 \n",
"320187 -0.796376 0.178886 \n",
"\n",
"[320188 rows x 21 columns]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(\"Deriving displacement, velocity and accelation from x and y\")\n",
"data['dx'] = data.groupby(['track_id'])['proj_x'].diff()\n",
"data['dy'] = data.groupby(['track_id'])['proj_y'].diff()\n",
"data['vx'] = data['dx'].div(data['dt'], axis=0)\n",
"data['vy'] = data['dy'].div(data['dt'], axis=0)\n",
"\n",
"data['ax'] = data.groupby(['track_id'])['vx'].diff().div(data['dt'], axis=0)\n",
"data['ay'] = data.groupby(['track_id'])['vy'].diff().div(data['dt'], axis=0)\n",
"\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>track_id</th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>...</th>\n",
" <th>dx</th>\n",
" <th>dy</th>\n",
" <th>vx</th>\n",
" <th>vy</th>\n",
" <th>ax</th>\n",
" <th>ay</th>\n",
" <th>v</th>\n",
" <th>a</th>\n",
" <th>heading</th>\n",
" <th>d_heading</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.566162</td>\n",
" <td>88.795326</td>\n",
" <td>173.917542</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>10.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.116699</td>\n",
" <td>88.801704</td>\n",
" <td>171.334290</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.018032</td>\n",
" <td>0.099039</td>\n",
" <td>0.216383</td>\n",
" <td>1.188473</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.208011</td>\n",
" <td>NaN</td>\n",
" <td>79.681298</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>11.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874573</td>\n",
" <td>90.596596</td>\n",
" <td>177.199951</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.017825</td>\n",
" <td>0.093756</td>\n",
" <td>0.213899</td>\n",
" <td>1.125077</td>\n",
" <td>-0.029812</td>\n",
" <td>-0.760753</td>\n",
" <td>1.145230</td>\n",
" <td>-0.753373</td>\n",
" <td>79.235449</td>\n",
" <td>-5.350188</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>12.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874268</td>\n",
" <td>90.928131</td>\n",
" <td>183.125732</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.016827</td>\n",
" <td>0.089198</td>\n",
" <td>0.201924</td>\n",
" <td>1.070371</td>\n",
" <td>-0.143699</td>\n",
" <td>-0.656466</td>\n",
" <td>1.089251</td>\n",
" <td>-0.671740</td>\n",
" <td>79.316807</td>\n",
" <td>0.976297</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>13.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>569.931213</td>\n",
" <td>86.213280</td>\n",
" <td>180.774292</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.015655</td>\n",
" <td>0.086825</td>\n",
" <td>0.187865</td>\n",
" <td>1.041902</td>\n",
" <td>-0.168701</td>\n",
" <td>-0.341637</td>\n",
" <td>1.058703</td>\n",
" <td>-0.366576</td>\n",
" <td>79.778828</td>\n",
" <td>5.544252</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320183</th>\n",
" <td>60159.0</td>\n",
" <td>3632</td>\n",
" <td>1830.709717</td>\n",
" <td>651.257446</td>\n",
" <td>150.202515</td>\n",
" <td>157.239746</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320184</th>\n",
" <td>60160.0</td>\n",
" <td>3632</td>\n",
" <td>1834.013672</td>\n",
" <td>649.612122</td>\n",
" <td>153.686646</td>\n",
" <td>160.874023</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.192955</td>\n",
" <td>0.083971</td>\n",
" <td>2.315463</td>\n",
" <td>1.007656</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>2.525221</td>\n",
" <td>NaN</td>\n",
" <td>23.517970</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320185</th>\n",
" <td>60161.0</td>\n",
" <td>3632</td>\n",
" <td>1845.373047</td>\n",
" <td>651.249756</td>\n",
" <td>147.178589</td>\n",
" <td>153.729248</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.178128</td>\n",
" <td>0.072764</td>\n",
" <td>2.137542</td>\n",
" <td>0.873173</td>\n",
" <td>-2.135059</td>\n",
" <td>-1.613797</td>\n",
" <td>2.309007</td>\n",
" <td>-2.594562</td>\n",
" <td>22.219713</td>\n",
" <td>-15.579091</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320186</th>\n",
" <td>60162.0</td>\n",
" <td>3632</td>\n",
" <td>1857.388916</td>\n",
" <td>650.908203</td>\n",
" <td>136.407349</td>\n",
" <td>142.354614</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.166113</td>\n",
" <td>0.065728</td>\n",
" <td>1.993352</td>\n",
" <td>0.788742</td>\n",
" <td>-1.730279</td>\n",
" <td>-1.013172</td>\n",
" <td>2.143727</td>\n",
" <td>-1.983366</td>\n",
" <td>21.588019</td>\n",
" <td>-7.580324</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320187</th>\n",
" <td>60163.0</td>\n",
" <td>3632</td>\n",
" <td>1862.792725</td>\n",
" <td>658.719971</td>\n",
" <td>141.984253</td>\n",
" <td>149.052307</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.160582</td>\n",
" <td>0.066971</td>\n",
" <td>1.926987</td>\n",
" <td>0.803649</td>\n",
" <td>-0.796376</td>\n",
" <td>0.178886</td>\n",
" <td>2.087853</td>\n",
" <td>-0.670484</td>\n",
" <td>22.638547</td>\n",
" <td>12.606340</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>320188 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" frame_id track_id l t w h \\\n",
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
"... ... ... ... ... ... ... \n",
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
"\n",
" x y state diff ... dx dy vx \\\n",
"0 0.855100 7.136193 2.0 NaN ... NaN NaN NaN \n",
"1 0.873132 7.235233 2.0 1.0 ... 0.018032 0.099039 0.216383 \n",
"2 0.890957 7.328989 2.0 1.0 ... 0.017825 0.093756 0.213899 \n",
"3 0.907784 7.418187 2.0 1.0 ... 0.016827 0.089198 0.201924 \n",
"4 0.923439 7.505012 2.0 1.0 ... 0.015655 0.086825 0.187865 \n",
"... ... ... ... ... ... ... ... ... \n",
"320183 14.840476 9.786501 2.0 NaN ... NaN NaN NaN \n",
"320184 15.033432 9.870472 2.0 1.0 ... 0.192955 0.083971 2.315463 \n",
"320185 15.211560 9.943236 2.0 1.0 ... 0.178128 0.072764 2.137542 \n",
"320186 15.377673 10.008965 2.0 1.0 ... 0.166113 0.065728 1.993352 \n",
"320187 15.538255 10.075935 2.0 1.0 ... 0.160582 0.066971 1.926987 \n",
"\n",
" vy ax ay v a heading d_heading \n",
"0 NaN NaN NaN NaN NaN NaN NaN \n",
"1 1.188473 NaN NaN 1.208011 NaN 79.681298 NaN \n",
"2 1.125077 -0.029812 -0.760753 1.145230 -0.753373 79.235449 -5.350188 \n",
"3 1.070371 -0.143699 -0.656466 1.089251 -0.671740 79.316807 0.976297 \n",
"4 1.041902 -0.168701 -0.341637 1.058703 -0.366576 79.778828 5.544252 \n",
"... ... ... ... ... ... ... ... \n",
"320183 NaN NaN NaN NaN NaN NaN NaN \n",
"320184 1.007656 NaN NaN 2.525221 NaN 23.517970 NaN \n",
"320185 0.873173 -2.135059 -1.613797 2.309007 -2.594562 22.219713 -15.579091 \n",
"320186 0.788742 -1.730279 -1.013172 2.143727 -1.983366 21.588019 -7.580324 \n",
"320187 0.803649 -0.796376 0.178886 2.087853 -0.670484 22.638547 12.606340 \n",
"\n",
"[320188 rows x 25 columns]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# speed: magnitude of the velocity vector\n",
"data['v'] = np.sqrt(data['vx'].pow(2) + data['vy'].pow(2))\n",
"# scalar acceleration: per-track change in speed divided by the time step\n",
"data['a'] = data.groupby(['track_id'])['v'].diff().div(data['dt'], axis=0)\n",
"\n",
"# heading in degrees, derived from the velocity components and mapped to [0, 360)\n",
"data['heading'] = (np.arctan2(data['vy'], data['vx']) * 180 / np.pi) % 360\n",
"\n",
"# rate of change of the heading. The raw per-track difference is wrapped into\n",
"# [-180, 180) first, so that a track crossing the 0/360 boundary does not\n",
"# register as a turn of almost 360 degrees. NaN (first row of a track) stays NaN.\n",
"data['d_heading'] = (\n",
"    data.groupby(['track_id'])['heading'].diff()\n",
"    .add(180).mod(360).sub(180)\n",
"    .div(data['dt'], axis=0)\n",
")\n",
"\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>track_id</th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>...</th>\n",
" <th>dx</th>\n",
" <th>dy</th>\n",
" <th>vx</th>\n",
" <th>vy</th>\n",
" <th>ax</th>\n",
" <th>ay</th>\n",
" <th>v</th>\n",
" <th>a</th>\n",
" <th>heading</th>\n",
" <th>d_heading</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.566162</td>\n",
" <td>88.795326</td>\n",
" <td>173.917542</td>\n",
" <td>0.855100</td>\n",
" <td>7.136193</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.208011</td>\n",
" <td>-0.753373</td>\n",
" <td>79.681298</td>\n",
" <td>-5.350188</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>10.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>565.116699</td>\n",
" <td>88.801704</td>\n",
" <td>171.334290</td>\n",
" <td>0.873132</td>\n",
" <td>7.235233</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.018032</td>\n",
" <td>0.099039</td>\n",
" <td>0.216383</td>\n",
" <td>1.188473</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.208011</td>\n",
" <td>-0.753373</td>\n",
" <td>79.681298</td>\n",
" <td>-5.350188</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>11.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874573</td>\n",
" <td>90.596596</td>\n",
" <td>177.199951</td>\n",
" <td>0.890957</td>\n",
" <td>7.328989</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.017825</td>\n",
" <td>0.093756</td>\n",
" <td>0.213899</td>\n",
" <td>1.125077</td>\n",
" <td>-0.029812</td>\n",
" <td>-0.760753</td>\n",
" <td>1.145230</td>\n",
" <td>-0.753373</td>\n",
" <td>79.235449</td>\n",
" <td>-5.350188</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>12.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>564.874268</td>\n",
" <td>90.928131</td>\n",
" <td>183.125732</td>\n",
" <td>0.907784</td>\n",
" <td>7.418187</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.016827</td>\n",
" <td>0.089198</td>\n",
" <td>0.201924</td>\n",
" <td>1.070371</td>\n",
" <td>-0.143699</td>\n",
" <td>-0.656466</td>\n",
" <td>1.089251</td>\n",
" <td>-0.671740</td>\n",
" <td>79.316807</td>\n",
" <td>0.976297</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>13.0</td>\n",
" <td>1</td>\n",
" <td>0.000000</td>\n",
" <td>569.931213</td>\n",
" <td>86.213280</td>\n",
" <td>180.774292</td>\n",
" <td>0.923439</td>\n",
" <td>7.505012</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.015655</td>\n",
" <td>0.086825</td>\n",
" <td>0.187865</td>\n",
" <td>1.041902</td>\n",
" <td>-0.168701</td>\n",
" <td>-0.341637</td>\n",
" <td>1.058703</td>\n",
" <td>-0.366576</td>\n",
" <td>79.778828</td>\n",
" <td>5.544252</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320183</th>\n",
" <td>60159.0</td>\n",
" <td>3632</td>\n",
" <td>1830.709717</td>\n",
" <td>651.257446</td>\n",
" <td>150.202515</td>\n",
" <td>157.239746</td>\n",
" <td>14.840476</td>\n",
" <td>9.786501</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>2.525221</td>\n",
" <td>-2.594562</td>\n",
" <td>23.517970</td>\n",
" <td>-15.579091</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320184</th>\n",
" <td>60160.0</td>\n",
" <td>3632</td>\n",
" <td>1834.013672</td>\n",
" <td>649.612122</td>\n",
" <td>153.686646</td>\n",
" <td>160.874023</td>\n",
" <td>15.033432</td>\n",
" <td>9.870472</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.192955</td>\n",
" <td>0.083971</td>\n",
" <td>2.315463</td>\n",
" <td>1.007656</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>2.525221</td>\n",
" <td>-2.594562</td>\n",
" <td>23.517970</td>\n",
" <td>-15.579091</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320185</th>\n",
" <td>60161.0</td>\n",
" <td>3632</td>\n",
" <td>1845.373047</td>\n",
" <td>651.249756</td>\n",
" <td>147.178589</td>\n",
" <td>153.729248</td>\n",
" <td>15.211560</td>\n",
" <td>9.943236</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.178128</td>\n",
" <td>0.072764</td>\n",
" <td>2.137542</td>\n",
" <td>0.873173</td>\n",
" <td>-2.135059</td>\n",
" <td>-1.613797</td>\n",
" <td>2.309007</td>\n",
" <td>-2.594562</td>\n",
" <td>22.219713</td>\n",
" <td>-15.579091</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320186</th>\n",
" <td>60162.0</td>\n",
" <td>3632</td>\n",
" <td>1857.388916</td>\n",
" <td>650.908203</td>\n",
" <td>136.407349</td>\n",
" <td>142.354614</td>\n",
" <td>15.377673</td>\n",
" <td>10.008965</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.166113</td>\n",
" <td>0.065728</td>\n",
" <td>1.993352</td>\n",
" <td>0.788742</td>\n",
" <td>-1.730279</td>\n",
" <td>-1.013172</td>\n",
" <td>2.143727</td>\n",
" <td>-1.983366</td>\n",
" <td>21.588019</td>\n",
" <td>-7.580324</td>\n",
" </tr>\n",
" <tr>\n",
" <th>320187</th>\n",
" <td>60163.0</td>\n",
" <td>3632</td>\n",
" <td>1862.792725</td>\n",
" <td>658.719971</td>\n",
" <td>141.984253</td>\n",
" <td>149.052307</td>\n",
" <td>15.538255</td>\n",
" <td>10.075935</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.160582</td>\n",
" <td>0.066971</td>\n",
" <td>1.926987</td>\n",
" <td>0.803649</td>\n",
" <td>-0.796376</td>\n",
" <td>0.178886</td>\n",
" <td>2.087853</td>\n",
" <td>-0.670484</td>\n",
" <td>22.638547</td>\n",
" <td>12.606340</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>320188 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" frame_id track_id l t w h \\\n",
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
"... ... ... ... ... ... ... \n",
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
"\n",
" x y state diff ... dx dy vx \\\n",
"0 0.855100 7.136193 2.0 NaN ... NaN NaN NaN \n",
"1 0.873132 7.235233 2.0 1.0 ... 0.018032 0.099039 0.216383 \n",
"2 0.890957 7.328989 2.0 1.0 ... 0.017825 0.093756 0.213899 \n",
"3 0.907784 7.418187 2.0 1.0 ... 0.016827 0.089198 0.201924 \n",
"4 0.923439 7.505012 2.0 1.0 ... 0.015655 0.086825 0.187865 \n",
"... ... ... ... ... ... ... ... ... \n",
"320183 14.840476 9.786501 2.0 NaN ... NaN NaN NaN \n",
"320184 15.033432 9.870472 2.0 1.0 ... 0.192955 0.083971 2.315463 \n",
"320185 15.211560 9.943236 2.0 1.0 ... 0.178128 0.072764 2.137542 \n",
"320186 15.377673 10.008965 2.0 1.0 ... 0.166113 0.065728 1.993352 \n",
"320187 15.538255 10.075935 2.0 1.0 ... 0.160582 0.066971 1.926987 \n",
"\n",
" vy ax ay v a heading d_heading \n",
"0 NaN NaN NaN 1.208011 -0.753373 79.681298 -5.350188 \n",
"1 1.188473 NaN NaN 1.208011 -0.753373 79.681298 -5.350188 \n",
"2 1.125077 -0.029812 -0.760753 1.145230 -0.753373 79.235449 -5.350188 \n",
"3 1.070371 -0.143699 -0.656466 1.089251 -0.671740 79.316807 0.976297 \n",
"4 1.041902 -0.168701 -0.341637 1.058703 -0.366576 79.778828 5.544252 \n",
"... ... ... ... ... ... ... ... \n",
"320183 NaN NaN NaN 2.525221 -2.594562 23.517970 -15.579091 \n",
"320184 1.007656 NaN NaN 2.525221 -2.594562 23.517970 -15.579091 \n",
"320185 0.873173 -2.135059 -1.613797 2.309007 -2.594562 22.219713 -15.579091 \n",
"320186 0.788742 -1.730279 -1.013172 2.143727 -1.983366 21.588019 -7.580324 \n",
"320187 0.803649 -0.796376 0.178886 2.087853 -0.670484 22.638547 12.606340 \n",
"\n",
"[320188 rows x 25 columns]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The first rows of every track have no derivative yet (NaN). Backfill them per\n",
"# track with the first known value, so that our model can make estimations\n",
"# based on these assumed values.\n",
"for _col in ['v', 'a', 'heading', 'd_heading']:\n",
"    data[_col] = data.groupby(['track_id'])[_col].bfill()\n",
"\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"312423 items in filtered set, out of 320188 in total set\n"
]
}
],
"source": [
"# Keep only the tracks with at least window+1 rows (a length of 3 is necessary\n",
"# to have all relevant derivatives of position), then index the rows by\n",
"# (track_id, frame_id) for quick access.\n",
"filtered_data = (\n",
"    data.groupby(['track_id'])\n",
"    .filter(lambda track: len(track) >= window + 1)\n",
"    .set_index(['track_id', 'frame_id'])\n",
")\n",
"print(filtered_data.shape[0], \"items in filtered set, out of\", data.shape[0], \"in total set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1263 training tracks, 316 test tracks\n"
]
}
],
"source": [
"# Split on track level, so all windows of one track end up in the same set.\n",
"# The shuffle is seeded: training is resumed from cached checkpoints, so the split\n",
"# has to be identical on every run, otherwise tracks that were trained on in an\n",
"# earlier run can end up in the test set of a later one.\n",
"SPLIT_SEED = 42\n",
"TRAIN_FRACTION = .8\n",
"\n",
"track_ids = filtered_data.index.unique('track_id').to_numpy()\n",
"track_ids = np.random.default_rng(SPLIT_SEED).permutation(track_ids)\n",
"test_offset_idx = int(len(track_ids) * TRAIN_FRACTION)\n",
"training_ids, test_ids = track_ids[:test_offset_idx], track_ids[test_offset_idx:]\n",
"print(f\"{len(training_ids)} training tracks, {len(test_ids)} test tracks\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, draw a sample track to see if it looks alright. **Unfortunately, the image isn't mapped properly.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1058\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3wcZ53/3zOzXbvqvcuW5W7LPXGL03F6AZIAoYSSXDg4yh0cEDjg4McFOEghcIQEkpCQRpqdnjhO3Fss925Vq3dp+87M8/tjVCxLtuXEVX7er5e8q5lnZp5Zr2Y+862KEEIgkUgkEolEcppQz/QEJBKJRCKRnF9I8SGRSCQSieS0IsWHRCKRSCSS04oUHxKJRCKRSE4rUnxIJBKJRCI5rUjxIZFIJBKJ5LQixYdEIpFIJJLTihQfEolEIpFITiu2Mz2BIzFNk7q6Onw+H4qinOnpSCQSiUQiGQZCCLq7u8nOzkZVj23bOOvER11dHXl5eWd6GhKJRCKRSD4CNTU15ObmHnPMWSc+fD4fYE0+Pj7+DM9GIpFIJBLJcOjq6iIvL6/vPn4szjrx0etqiY+Pl+JDIpFIJJJzjOGETMiAU4lEIpFIJKeVExYfK1as4NprryU7OxtFUXj55ZePOvauu+5CURTuu+++jzFFiUQikUgkI4kTFh+BQICpU6fy0EMPHXPcSy+9xLp168jOzv7Ik5NIJBKJRDLyOOGYj8WLF7N48eJjjqmtreUb3/gGb731FldfffVHnpxEIpFIJJKRx0kPODVNk9tvv53/+I//YOLEiccdH4lEiEQifb93dXWd7ClJJBKJRCI5izjpAaf33nsvNpuNb37zm8Ma/6tf/YqEhIS+H1njQyKRSCSSkc1JFR8ffvgh999/P4899tiwq5P+4Ac/oLOzs++npqbmZE5JIpFIJBLJWcZJFR8rV66kqamJ/Px8bDYbNpuNqqoqvvvd71JYWDjkNk6ns6+mh6ztIZFIJBLJyOekxnzcfvvtXHbZZQOWXXnlldx+++186UtfOpmHkkgkEsl5jhCCSHknzlEJshfYOcYJiw+/38+BAwf6fq+oqGDLli0kJyeTn59PSkrKgPF2u53MzEzGjh378WcrkUgkEgmW8GhfVknw3UMk3j4O78S0Mz0lyQlwwm6XTZs2MW3aNKZNmwbAd77zHaZNm8ZPfvKTkz45iUQikUiGIry3jeC7hwDofH4/QogzPCPJiXDClo9Fixad0H9yZWXliR5CIpFIJJJBCFMQWF+PGdIJfNjYvzxs0PbcXuxpHlS3jbg5WSiqdMOczZx1jeUkEonkXEMIQd2+DrJLEmXswSlCmIKOD6oIvDV0RmSorJlQz3vH5GQcXtfpm5zkhJGN5SQSieRjUrWjlZd/X0b1ztYzPZURixGOHVV4HElgZR3ClG6YsxkpPiQSieRjsvW9mgGvkpNPtLxz2GMDa+oxI/opnI3k4yLdLhKJRHKCCFOwY0UtkaB1gzu0p73ntYNNr1cC4PTYmLQwR8YefExMw6TztQqilcMXH8jg07MeKT4kEonkBImGdda8eAA9ag5YLkzB+iXlANgcKiWzM3B67Gdiiuc8vcGl4cpOwltbTnj74MZGvPOl+DtbkW4XiUQi+QgcL+tPpn5+PETMoPPtqo8kPNAFne9USdfLWYwUHxKJRHKCKKqCph378qlpqnzq/ggIU+BfW4d/dR1xMzPgo36EUvyd1UjxIZFIJMdACEH4YMcAS4bDZeOWe2bjjh/apWJ3aUxYkMPedQ0fO+tCCEFg/YZBlhQhBLV720echaXX4tH1dhX+lbXwMU4vuLFRZr2cpUjxIZFIJMcgvK+dlr9sJ7KvfcByp8dGqDs25DaxsMGWd6pZ8+IBouGPZ/oPrFxJ9Re+QGDVqgHLq3e2Wem9u9o+1v7PNlSnjYx/m4ajwHfsgc7jmESk6+WsRooPiUQiOQah7VbMQXB7ywALyHBcKqYp2L
W6/mM9fXe9+daAV7CsHtuWW2m9Bzc3feR9n63YEl2kfXUKiuOIW5QGznFJADjHJB1/R0JI68dZisx2kUgkksM4vIS3EIJgmXVzD21tJripEc/0dBx5PuLmZJE/IZnqnW0IBF2KIEEMvFmaumDD0nImzMsadtaLME3an34as7sbYRh0vvYaAF2vv05bzIcRM+mOOaiOTQVF5eCHTcSnuIGRld4bPdSNOCKbCAN8F+USV5qOluikvTOGXtN99J3ogq7lNcTNyURxytvd2YT835BIJJLD6I05ECH9iOXWjTC4uYnQ7jY809OZemkeNXvaKFdM3vREuavLhXpEhKQRM9m5qo5pl+UPz1oSDNH8wIOYnQPrWohwGHXJY2hAd97lMKoUgFjU6EvvdXpsjL0gE4fr3L+0h3Zb7iTXxBQSryqi47VywrvaCO9pJ3FxEQC2ZOcxxYfiseGdl4Ni107LnCXDR7pdJBKJ5DD6Yg7yjx5z4BiVQMfScpI6wuSPS2Kf3SCgwj67MWisELDupYNsfqdqWOZ/1eMm+Qufx5aVNXhHQMSdTEPmbOjpISN6jAPxqS4+/cNZI0J4ALjHJ5N861hSPjceW4qblNsnkHzrWNzjk/vGxM04djaMCOp0L6/Bv6pWul7OMkbGt1QikUhOIrZEFylfmED9f68fcn1oZwsvESOgQaPPyV6HJTre9MTI69KIEwodikmCUFBQEALWv1KOqqqUXpp3TAuIGQzR9vgTgywfvThDbRiae9DyQGcEp2fkXNKdhQkDflcUBU9p+oBl7pJk7GMSie3rOPqOdJPONysRQuBbkDsiXFIjAWn5kEgkkiHQm0NHXRcGHiXMw0aYVzo6ifbcz2IKvBAXQUeQKFQOt4MIE9YvKWfLspqjPoUL06TzlZdJvPkm1JQUgke4CxSg05tLxDU42FJRlPPyxho/P+f4g0xB17vVMvPlLEKKD4lEIhmC3piDoa6SHhQeJY6sIWz+jZpgldNKwdUAcVihCiNmDhIgQgjWHmzFNE3W7jhE/e8foO2vf6MhGqI+cbDrJ+TOGHK+Y+eMjFiPE8Vdkox30TAEyAirh3KuI8WHRCKRHIEQgp1ZlSTdUkLSrWOHHJOGiv8oFbA2ugzKHDrKEOLEiJlsWFreV//j/X3N3PaXdfxh+QE++9RW/p57IeGkNBoS4tBMc9D2TenTBy1TFCgqTT2RUxxRmEdmxUjOeqT4kEgkkh6EEGxs2MjK2pXctfWbbE7bx4aKdZb14ggdYUPhPxkce9E7bqMjRpdiHlWA7FxVhzAFb2yvB+ClzbVM69jKzM4yauwxahO9GKqCAFrjXOxPT6QyJZ6oI2HQ/lAga3Tixzv5cxhnSeLxB5lixFWDPZc5/2x0EolEcgRCCDY1biKkh/j6sq8zN3suAE/ufpI1HWv47byfMnF1+qDtLsLOpcRYxuBYgk4bPB8X4Q6/a5AAEQLWvHyQ5zZW83x3FwDVzV2UhqupTPehCAGKwt7sVNK6Q7TFudiflYJNN7D58nr2YaAoVkxIcmYcdsf5mU5q6iaBdQ3DGAhm1ADPqZ+T5PhIy4dEIjkv6bVyCCFYcWgFd7x1B/duuBeADQ0brNd66/Xhzif4v4zn0QrjBu3nB7gpPvJSqoDbgNG6xl6bMSDuo2+ICdk1UZSeVYZq46Ws63kq51ZCirN3kngjUQ54ShBRB3PK20G1gTAxjWjfvrpaw8Sig9N8RzpG1KDpwc1E97QffzCgqvKWd7YgLR8SieS8ZFXtKu5edje3jL2FLU1bAKjurgZANy1Lhi6s131GOfuSywk4dEqSsrm6YwGqUDEQuFD4JR4+h5/DO72ENNioGbhMg9FdGkPVN9VQKNJVmlUTQ/gJaR7aHUlsSZzKhR0bUQ0FVYAnFia3uxYNFwCKMNA0S6DYHCrTrsg/Ly0fsfYQeuPRs5IOJ+Urk7DFO0/xjCTDRcpAiURy3mEKkz9v/TMAz+97nr3te4e13bvRlTyWtR
THJ/MBK7yjDZMcVG7DMWi8z4AZEdsxn/JuCjj5UpfKZ2qf4zOHnsFuRnEIS8YEFS8AGeEWDmQmo2vWnmx6GK3H8qF
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import random\n",
"\n",
"# Two variants of plot_track(track_id), depending on whether a homography H is set:\n",
"# - with H: draw the track on top of the camera image warped by H\n",
"# - without H: draw only the track\n",
"# H is handed to cv2 as a transformation matrix below. If it is a NumPy array, a bare\n",
"# `if H:` raises ValueError (ambiguous truth value), hence the explicit None check.\n",
"if H is not None:\n",
"    img_src = \"../DATASETS/hof/webcam20231103-2.png\"\n",
"    # dst = cv2.warpPerspective(img_src,H,(2500,1920))\n",
"    src_img = cv2.imread(img_src)\n",
"    if src_img is None:\n",
"        # cv2.imread signals an unreadable file by returning None instead of raising\n",
"        raise FileNotFoundError(img_src)\n",
"    print(src_img.shape)\n",
"    h1, w1 = src_img.shape[:2]\n",
"    corners = np.float32([[0,0], [w1, 0], [0, h1], [w1, h1]])\n",
"\n",
"    print(corners)\n",
"    # project the image corners to find the bounding box of the warped image\n",
"    corners_projected = cv2.perspectiveTransform(corners.reshape((-1,4,2)), H)[0]\n",
"    print(corners_projected)\n",
"    [xmin, ymin] = np.int32(corners_projected.min(axis=0).ravel() - 0.5)\n",
"    [xmax, ymax] = np.int32(corners_projected.max(axis=0).ravel() + 0.5)\n",
"    print(xmin, xmax, ymin, ymax)\n",
"\n",
"    dst = cv2.warpPerspective(src_img,H, (xmax, ymax))\n",
"\n",
"    def plot_track(track_id: int):\n",
"        \"\"\"Draw the warped camera image with the projected positions of one track on top.\"\"\"\n",
"        plt.gca().invert_yaxis()\n",
"\n",
"        # the extent is shifted by mean_x/mean_y, like the projected coordinates\n",
"        # NOTE(review): the division by 100 presumably converts cm to m -- confirm\n",
"        plt.imshow(dst, origin='lower', extent=[xmin/100-mean_x, xmax/100-mean_x, ymin/100-mean_y, ymax/100-mean_y])\n",
"        # plot scatter plot with x and y data\n",
"        track = filtered_data.loc[track_id,:]\n",
"        ax = plt.scatter(\n",
"            track['proj_x'],\n",
"            track['proj_y'],\n",
"            marker=\"*\")\n",
"        ax.axes.invert_yaxis()\n",
"        plt.plot(\n",
"            track['proj_x'],\n",
"            track['proj_y']\n",
"        )\n",
"else:\n",
"    def plot_track(track_id: int):\n",
"        \"\"\"Draw the positions of one track, without background image.\"\"\"\n",
"        track = filtered_data.loc[track_id,:]\n",
"        # NOTE(review): the markers use x/y while the line uses proj_x/proj_y;\n",
"        # confirm that this difference is intended\n",
"        plt.scatter(\n",
"            track['x'],\n",
"            track['y'],\n",
"            marker=\"*\")\n",
"        plt.plot(\n",
"            track['proj_x'],\n",
"            track['proj_y']\n",
"        )\n",
"\n",
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
"# _track_id = 2188\n",
"_track_id = random.choice(track_ids)\n",
"print(_track_id)\n",
"plot_track(_track_id)\n",
"\n",
"# overlay 100 randomly drawn tracks (drawn with replacement)\n",
"for track_id in random.choices(track_ids, k=100):\n",
"    plot_track(track_id)\n",
"\n",
"print(mean_x, mean_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now make the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a=filtered_data.loc[1]\n",
"# min(a.index.tolist())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1263 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1263/1263 [01:42<00:00, 12.36it/s]\n",
"100%|██████████| 316/316 [00:26<00:00, 11.75it/s]\n"
]
}
],
"source": [
"\n",
"\n",
"\n",
"def create_dataset(data, track_ids, window):\n",
"    \"\"\"Cut tracks into sliding input windows, each paired with the row to predict.\n",
"\n",
"    For every start frame i of a track, the feature is the `in_fields` columns of\n",
"    frames i..i+window (label based, both ends inclusive, so window+1 rows) and the\n",
"    target is the `out_fields` columns of frame i+window+1.\n",
"\n",
"    A sample is skipped when\n",
"    - a frame is missing, which would give a window of a different length or no\n",
"      target row at all, or\n",
"    - its feature or target contains NaN. The derivative columns are NaN on the\n",
"      first rows of a track, and a single NaN sample turns the loss into NaN.\n",
"\n",
"    Args:\n",
"        data: DataFrame indexed by (track_id, frame_id)\n",
"        track_ids: ids of the tracks to take samples from\n",
"        window: the feature covers frames i..i+window\n",
"\n",
"    Returns:\n",
"        Tuple (X, y) of float tensors: the stacked features and the stacked targets.\n",
"    \"\"\"\n",
"    X, y = [], []\n",
"    for track_id in tqdm(track_ids):\n",
"        df = data.loc[track_id]\n",
"        start_frame = int(df.index.min())\n",
"        for step in range(len(df) - window - 1):\n",
"            i = start_frame + step\n",
"            target_frame = i + window + 1\n",
"            if target_frame not in df.index:\n",
"                continue\n",
"            feature = df.loc[i:i + window, in_fields].to_numpy(dtype=float)\n",
"            if len(feature) != window + 1:\n",
"                # frames are missing inside the window\n",
"                continue\n",
"            target = df.loc[target_frame, out_fields].to_numpy(dtype=float)\n",
"            if np.isnan(feature).any() or np.isnan(target).any():\n",
"                continue\n",
"            X.append(feature)\n",
"            y.append(target)\n",
"\n",
"    return torch.tensor(np.array(X), dtype=torch.float), torch.tensor(np.array(y), dtype=torch.float)\n",
"\n",
"X_train, y_train = create_dataset(filtered_data, training_ids, window)\n",
"X_test, y_test = create_dataset(filtered_data, test_ids, window)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# move all four tensors to the compute device once, up front\n",
"X_train, y_train, X_test, y_test = (\n",
"    tensor.to(device=device) for tensor in (X_train, y_train, X_test, y_test)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import TensorDataset, DataLoader\n",
"\n",
"# pair every input window with its target and serve them in shuffled mini-batches\n",
"dataset_train = TensorDataset(X_train, y_train)\n",
"loader_train = DataLoader(dataset=dataset_train, batch_size=8, shuffle=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model produces an output for every timestep, which should improve training. For the prediction itself we only use the last timestep."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LSTMModel(nn.Module):\n",
"    \"\"\"Linear embedding, stacked LSTM and a linear head on the last time step.\n",
"\n",
"    input_size  : number of features in input at each time step\n",
"    hidden_size : number of LSTM units; the embedding has half this size\n",
"    num_layers  : number of LSTM layers\n",
"\n",
"    The size of the output is read from the notebook-level `output_size`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size, num_layers):\n",
"        super().__init__()\n",
"        embedding_size = hidden_size // 2\n",
"        self.lin1 = nn.Linear(input_size, embedding_size)\n",
"        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True)\n",
"        self.linear = nn.Linear(hidden_size, output_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map a batch of input sequences (batch first) to one output per sequence.\"\"\"\n",
"        embedded = self.lin1(x)\n",
"        # the second return value (the final state of the LSTM) is not used\n",
"        sequence_out, _ = self.lstm(embedded)\n",
"        # extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\n",
"        last_step = sequence_out[:, -1]\n",
"        return self.linear(last_step)\n",
"\n",
"\n",
"model = LSTMModel(input_size, hidden_size, num_layers).to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# mean squared error as training objective, optimised with Adam\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _batched_rmse(inputs, targets, batch_size):\n",
"    \"\"\"RMSE of the model over (inputs, targets), computed in chunks of batch_size samples.\"\"\"\n",
"    squared_error, n_values = 0.0, 0\n",
"    for offset in range(0, len(inputs), batch_size):\n",
"        batch_x = inputs[offset:offset + batch_size].to(device=device)\n",
"        batch_y = targets[offset:offset + batch_size].to(device=device)\n",
"        squared_error += torch.sum((model(batch_x) - batch_y) ** 2).item()\n",
"        n_values += batch_y.numel()\n",
"    if n_values == 0:\n",
"        return float('nan')\n",
"    return (squared_error / n_values) ** 0.5\n",
"\n",
"\n",
"def evaluate(batch_size=4096):\n",
"    \"\"\"Print the RMSE of the model on the training set and on the test set.\n",
"\n",
"    Both sets are evaluated in chunks of `batch_size` samples: a single forward\n",
"    pass over the complete training set runs out of GPU memory.\n",
"\n",
"    Reads model, X_train, y_train, X_test, y_test, device and epoch from the\n",
"    notebook scope.\n",
"    \"\"\"\n",
"    was_training = model.training\n",
"    # toggle evaluation mode\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        train_rmse = _batched_rmse(X_train, y_train, batch_size)\n",
"        test_rmse = _batched_rmse(X_test, y_test, batch_size)\n",
"    # hand the model back in the mode it was in, so training can continue\n",
"    model.train(was_training)\n",
"    print(\"Epoch %d: train RMSE %.4f, test RMSE %.4f\" % (epoch, train_rmse, test_rmse))\n",
"\n",
"def load_most_recent():\n",
"    \"\"\"Load the checkpoint in cache_path that comes last in sorted order.\n",
"\n",
"    Returns the result of load_cache, or (None, None) when there is no checkpoint.\n",
"    \"\"\"\n",
"    checkpoints = sorted(cache_path.glob(\"checkpoint_*.pt\"))\n",
"    if not checkpoints:\n",
"        print('Nothing found to load')\n",
"        return None, None\n",
"\n",
"    newest = checkpoints[-1]\n",
"    print(f\"Loading {newest}\")\n",
"    return load_cache(path=newest)\n",
"\n",
"def load_cache(epoch=None, path=None):\n",
"    \"\"\"Restore the state of model and optimizer from a checkpoint file.\n",
"\n",
"    Args:\n",
"        epoch: epoch of the checkpoint to load from cache_path; only used when no\n",
"            path is given\n",
"        path: checkpoint file to load; the epoch is taken from the last five\n",
"            characters of its stem\n",
"\n",
"    Returns:\n",
"        Tuple (epoch, loss) with the loss as stored in the checkpoint.\n",
"\n",
"    Raises:\n",
"        RuntimeError: if neither epoch nor path is given.\n",
"    \"\"\"\n",
"    if path is None:\n",
"        if epoch is None:\n",
"            raise RuntimeError(\"Either path or epoch must be given\")\n",
"        path = cache_path / f\"checkpoint_{epoch:05d}.pt\"\n",
"    else:\n",
"        print (path.stem)\n",
"        epoch = int(path.stem[-5:])\n",
"\n",
"    # map_location: tensors are otherwise restored to the device they were saved\n",
"    # from, which fails when that device is not available in this session\n",
"    # NOTE: torch.load unpickles the file, only load checkpoints you wrote yourself\n",
"    cached = torch.load(path, map_location=device)\n",
"\n",
"    optimizer.load_state_dict(cached['optimizer_state_dict'])\n",
"    model.load_state_dict(cached['model_state_dict'])\n",
"    return epoch, cached['loss']\n",
" \n",
"\n",
"def cache(epoch, loss):\n",
"    \"\"\"Write a checkpoint for `epoch` to cache_path.\n",
"\n",
"    The checkpoint holds the epoch, the loss and the state of model and optimizer.\n",
"    \"\"\"\n",
"    path = cache_path / f\"checkpoint_{epoch:05d}.pt\"\n",
"    print(f\"Cache to {path}\")\n",
"    checkpoint = {\n",
"        'epoch': epoch,\n",
"        'model_state_dict': model.state_dict(),\n",
"        'optimizer_state_dict': optimizer.state_dict(),\n",
"        'loss': loss,\n",
"    }\n",
"    torch.save(checkpoint, path)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading EXPERIMENTS/cache/hof2/checkpoint_00005.pt\n",
"checkpoint_00005\n",
"starting from epoch 5 (loss: nan)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|▍ | 4/95 [07:37<2:53:27, 114.37s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cache to EXPERIMENTS/cache/hof2/checkpoint_00010.pt\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"ename": "RuntimeError",
"evalue": "CUDA out of memory. Tried to allocate 28.40 GiB (GPU 0; 23.59 GiB total capacity; 15.01 GiB already allocated; 7.20 GiB free; 15.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[31], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m 30\u001b[0m cache(epoch, loss)\n\u001b[0;32m---> 31\u001b[0m \u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m evaluate()\n",
"Cell \u001b[0;32mIn[30], line 5\u001b[0m, in \u001b[0;36mevaluate\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 5\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m train_rmse \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msqrt(loss_fn(y_pred, y_train))\n\u001b[1;32m 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m model(X_test\u001b[38;5;241m.\u001b[39mto(device\u001b[38;5;241m=\u001b[39mdevice))\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[28], line 15\u001b[0m, in \u001b[0;36mLSTMModel.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x): \u001b[38;5;66;03m# defines forward pass of the neural network\u001b[39;00m\n\u001b[1;32m 14\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlin1(x)\n\u001b[0;32m---> 15\u001b[0m out, h0 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;66;03m# extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\u001b[39;00m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# print(out.shape)\u001b[39;00m\n\u001b[1;32m 18\u001b[0m out \u001b[38;5;241m=\u001b[39m out[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:]\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/rnn.py:769\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_forward_args(\u001b[38;5;28minput\u001b[39m, hx, batch_sizes)\n\u001b[1;32m 768\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 769\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 771\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 772\u001b[0m result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\u001b[38;5;28minput\u001b[39m, batch_sizes, hx, 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flat_weights, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias,\n\u001b[1;32m 773\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional)\n",
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 28.40 GiB (GPU 0; 23.59 GiB total capacity; 15.01 GiB already allocated; 7.20 GiB free; 15.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
]
}
],
"source": [
"start_epoch, loss = load_most_recent()\n",
"if start_epoch is None:\n",
" start_epoch = 0\n",
"else:\n",
" print(f\"starting from epoch {start_epoch} (loss: {loss})\")\n",
"\n",
"# Train Network\n",
"for epoch in tqdm(range(start_epoch+1,num_epochs+1)):\n",
" # toggle train mode\n",
" model.train()\n",
" for batch_idx, (data, targets) in enumerate(loader_train):\n",
" # Get data to cuda if possible\n",
" data = data.to(device=device).squeeze(1)\n",
" targets = targets.to(device=device)\n",
"\n",
" # forward\n",
" scores = model(data)\n",
" loss = loss_fn(scores, targets)\n",
"\n",
" # backward\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" # gradient descent update step/adam step\n",
" optimizer.step()\n",
"\n",
" if epoch % 5 != 0:\n",
" continue\n",
"\n",
" cache(epoch, loss)\n",
" evaluate()\n",
"\n",
"evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'model' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 3\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m model(X_train\u001b[38;5;241m.\u001b[39mto(device\u001b[38;5;241m=\u001b[39mdevice))\n",
"\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
]
}
],
"source": [
"model.eval()\n",
"with torch.no_grad():\n",
" y_pred = model(X_train.to(device=device))\n",
" print(y_pred.shape, y_train.shape)\n",
"y_train, y_pred"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mean_x, mean_y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import scipy\n",
"\n",
"\n",
"def predict_and_plot(feature, steps = 50):\n",
" lenght = feature.shape[0]\n",
" # feature = filtered_data.loc[_track_id,:].iloc[:5][in_fields].values\n",
" # nxt = filtered_data.loc[_track_id,:].iloc[5][out_fields]\n",
" with torch.no_grad():\n",
" for i in range(steps):\n",
" # predict_f = scipy.ndimage.uniform_filter(feature)\n",
" # predict_f = scipy.interpolate.splrep(feature[:][0], feature[:][1],)\n",
" # predict_f = scipy.signal.spline_feature(feature, lmbda=.1)\n",
" # bathc size of one, so feature as single item in array\n",
" X = torch.tensor([feature], dtype=torch.float).to(device)\n",
" # print(X.shape)\n",
" s = model(X)[0].cpu()\n",
" \n",
" # proj_x proj_y v heading a d_heading\n",
" # next_step = feature\n",
" dt = 1/ FPS\n",
" v = np.sqrt(s[0]**2 + s[1]**2)\n",
" heading = (np.arctan2(s[1], s[0]) * 180 / np.pi) % 360\n",
" a = (v - feature[-1][2]) / dt\n",
" d_heading = (heading - feature[-1][5])\n",
" # print(s)\n",
" feature = np.append(feature, [[feature[-1][0] + s[0]*dt, feature[-1][1] + s[1]*dt, v, heading, a, d_heading ]], axis=0)\n",
" \n",
" # print(next_step, nxt)\n",
" plt.plot(feature[:lenght,0], feature[:lenght,1], c='orange')\n",
" plt.plot(feature[lenght-1:,0], feature[lenght-1:,1], c='red')\n",
" plt.scatter(feature[lenght:,0], feature[lenght:,1], c='red')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2515\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAACPu0lEQVR4nO39eZRkd3nfj7/vvtRe1dv0aEYIicUIgYkxmBBiEfS1rAMYLDssJphgx45tYQIiGHMSIHhTMN8QASbEcI4Njg14ScA23xiHsMn+GTACY4ONtYCWkWZ6eqv13rr7/f3R/Txzq7tnk7pr6X5e5/SZ6Vq6PlV1qz7v+yzvR8nzPIcgCIIgCMKYUCe9AEEQBEEQjhYiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCv6pBewkyzLcPr0aVQqFSiKMunlCIIgCIJwCeR5jn6/j+XlZajqhWMbUyc+Tp8+jRMnTkx6GYIgCIIgPAJOnTqFK6644oK3mTrxUalUAGwtvlqtTng1giAIgiBcCr1eDydOnOB9/EJMnfigVEu1WhXxIQiCIAgzxqWUTEjBqSAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY+Wyxccdd9yBF77whVheXoaiKPjEJz5x3tv+zM/8DBRFwe233/4oligIgiAIwmHissWH53l46lOfive9730XvN3HP/5xfOlLX8Ly8vIjXpwgCIIgCIePyx4sd9NNN+Gmm2664G0efvhh/PzP/zz+/M//HM9//vMf8eIEQRAmSZ7niKIIURRBURTous4/giA8cvb9E5RlGV75ylfijW98I6699tr9/vOCIAgHCgmO4XCIIAiQZdmu25AQMQyDxYiqqjBNcwIrFoTZY9/Fxzve8Q7ouo7Xvva1l3T7MAwRhiH/3uv19ntJgiAIFyXPc3ieh8FgMCI4VFWFZVkAgCRJkCQJ8jxHHMeI43jX35FUsyBcnH0VH1/96lfx7ne/G1/72tegKMol3ee2227D29/+9v1chiAIwmXT6XQwHA4BbAkOx3Fg2zZM0xz5PsvzHGmaIo5jFiN0P0EQLo19bbX9i7/4C6yuruLkyZMcinzggQfwhje8AY95zGP2vM+b3/xmdLtd/jl16tR+LkkQBOGSKJVKUFUV9Xodi4uLqNVqsCxr14kUpVwcx0GlUkGj0cDS0hIMw0C9Xp/M4gVhxtjXyMcrX/lK3HDDDSOX3XjjjXjlK1+JV7/61Xvex7IsDmkKgiBMCtM0sbi4eMlR2yKqqmJ+fv4AViUIh5PLFh+DwQD33nsv/37ffffh61//OprNJk6ePIlWqzVye8MwsLS0hCc84QmPfrWCIAgHyCMRHoIgXD6XLT7uvPNOPPe5z+Xfb731VgDAq171KnzoQx/at4UJgiAIgnA4uWzxcf311yPP80u+/f3333+5DyEIgiAIwiFGnHIEQZha8jzHmTNn+HfbttFsNie4IkEQ9gMZLCcIwtSSpunI71KTIQiHA4l8CIIwtei6jmazicFgAMdxpDNOEA4JIj4EQZhqbNuGbduTXoYgCPuIpF0EQRAEQRgrIj4EQRAEQRgrIj4EQRAEQRgrUvMhCMLMEwQBPM+DaZqoVCqTXo4gCBdBIh+CIMw8eZ4jDEMEQTDppQiCcAmI+BAEYeYxTRMAEMcxsiyb8GoEQbgYIj4EQZh5NE2Drm9lkaMomvBqBEG4GCI+BEE4FFD0Q8SHIEw/Ij4EQTgUiPgQhNlBxIcgCIcCsl6P4/iyJm8LgjB+RHwIgnAoUBQFURQhCAKJfgjClCPiQxCEQ0EcxwjDEL7vIwzDSS9HEIQLICZjgiBMnCiKkCQJoihCmqao1+vQNO2S75
/nOTY2NmAYBoIgEPEhCFOORD4EQZgoWZZhfX0dnU6Hoxbdbvey/gbVeBiGAQBot9tS9yEIU4yID0EQJoqiKLsuC4IAaZpe8t9IkgTAlt+HqqpI01TqPgRhihHxIQjCRFEUhVMs5XKZoxeXIx7IZMwwDFQqFZTLZREfgjDFSM2HIAgTZ3Fxkf/f6/W4eNRxnEu6v6ZpWFhYAACUSiV0Oh2EYShD5gRhSpHIhyAIU8WjNQsjv48oimTOiyBMKSI+BGEHSZLA8zz0+32ZkjoBSHwkSXJZdR+EpmmcupGuF0GYTkR8CMI2WZah3W5jdXUV3W4X/X5fNq8DJssyDAaDkc4UVVUftXig6Ie8f4IwnUjNhyBga5PqdDp8pm1ZFnRd57NwYX8ZDodsCAZsCQ7Xdfl627YRxzGGw+HI5ZeKZVkYDAYiPgRhShHxIRx5giDA5uYmAEDXddTrdREdB0y32x2px/A8b0RkOI7DkackSaDrl/dVZZomFEVBmqaP6P6CIBwsknYRjjRRFKHdbgPY2vDm5+dFeEyAnYWhuq6z/we9P5eDoij8Pkr0QxCmDxEfwpElSRJsbm4iz3PYto16vb6n4ZWw/9i2DQBwXRdzc3OYm5vbdZtarQbg3MyWy4XqPqRoWBCmDxEfwpEkTVNsbGwgyzKYpolGoyHCY4yQ+AjDEKZp7jnHxXVdlEolAMBgMHjEjyEtt4IwfYj4ECZOnucYDodje7wsy7C5uYk0TaHrOprNpgiPMWNZFtdkXMjPo1wuA9gSKXEcX9Zj6LoOXdeR57lEPwRhypAqLGFi5HkOz/PgeR7SNIWqqhwq3y+SJIHv+9zKSZNT8zyHpmlotVpQVdHg40ZRFNi2jeFwiCAIzltno2kaHMfBcDiE53mo1+uX/Bh5nnP30uUKF0EQDhYRH8JE2Nnaqmnavk4hHQwGCIIAcRzv+XdVVUWz2bysse3C/lIUH9Vq9by3K5VKGA6HGA6HqFQqu96zKIoQxzG3RxP03iuKwumbo0AcxyzmVVWFpmkc2aP0kwhuYdKI+BDGSpZl6PV67O+gaRoqlQocx9m31Ee/30e/3+ffbdvm7gkaQKbrunwBTxjbtqEoCpIkQbvdhmVZex4HpmnCNE1EUQTf90fmtQwGA/R6Pf792LFjfH+adGua5pFptU3TFGtra7sup2OdxEej0Tjv3Jx2u40oitBqtY7M6yaMHzmyhLExHA5H/B1KpRKq1eq+iY4sy9DpdDi/XyqV4Louu2UK04WiKLAsC0EQcGTD9300m81dwrBUKiGKInieh3K5zMdMUXjshCJeR0lkFgtrVVVFnufI83xXwW273eboCNXckCiPoghpmmJ1dXVEzAnCfiLiQxgLvV6POxYMw0CtVttXP40kSbCxsYE0TaEoCqrV6pEKtc8qpVJppBg0iiIEQbDL1dS2bWiahjRN0ev1uA3Xsixuw52fnx/ZKHee7R8FDMOArutIkgTVahWu6yLLMv7RNA2bm5uI4/i8wq3f78MwDGiahocffhjz8/MsYJIk4Rod27a5cFgQLhcRH8KBkuc52u02bzDlchmVSmVfox0rKyv8u67raDQaEu2YEahOg1Ikqqpyi2wRRVFQq9WwubnJBcqu63KnEtV2FDmK4gPYEmok9l3X5doPol6vY319HXmec0pLURT4vo80TVEul7G+vs633ytyRIXcpmmiXq+zKJQ6KuFSEfEhHBhpmvJZlqIoqNfr580zP5rHKDI3N3ekwuyHgUajgfX1dZimiVardd7b2baNSqXC04apULWYhilCx0GSJIjjmIsvDzumaWJjY4NF3fLy8sj1hmFgYWEBWZaNiPRyuYzBYMBirdfrodVqQVGUkZ8sy6AoCuI4RhRFPIiRUjXHjx9HtVqVz6FwQUR8CAdCHMfspUGdJQdhW24YBqrVKn9Ryhfe7GEYBo4dO3ZJt61UKsiyDJ7nAcAFPUKoqLjb7SKOYzSbTfYNOcxQkWgQBMjzfM/ZNp
qm7RJiiqJwMW+tVsPx48cv+DhJkqDT6SCKIpRKJbbB39jYQLlcls+icEFEfAj7ThRF2NzcRJZlbOJ1kFXz5XL5SGw
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
"_track_id = 8701 # random.choice(track_ids)\n",
"_track_id = 3880 # random.choice(track_ids)\n",
"\n",
"# _track_id = 2780\n",
"\n",
"for i in range(100):\n",
" _track_id = random.choice(track_ids)\n",
" plt.plot(\n",
" filtered_data.loc[_track_id,:]['proj_x'],\n",
" filtered_data.loc[_track_id,:]['proj_y'],\n",
" c='grey', alpha=.2\n",
" )\n",
"\n",
"_track_id = random.choice(track_ids)\n",
"# _track_id = 801\n",
"print(_track_id)\n",
"ax = plt.scatter(\n",
" filtered_data.loc[_track_id,:]['proj_x'],\n",
" filtered_data.loc[_track_id,:]['proj_y'],\n",
" marker=\"*\") \n",
"plt.plot(\n",
" filtered_data.loc[_track_id,:]['proj_x'],\n",
" filtered_data.loc[_track_id,:]['proj_y']\n",
")\n",
"\n",
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:5][in_fields].values)\n",
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:10][in_fields].values)\n",
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:50][in_fields].values)\n",
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n",
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}