3625 lines
294 KiB
Text
3625 lines
294 KiB
Text
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Goal of this notebook: implement basic RNN/LSTM/GRU models to _forecast_ trajectories, based on the VIRAT dataset and/or the custom _hof_ dataset."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/home/ruben/suspicion/trap/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|||
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"import torch\n",
|
|||
|
"import matplotlib.pyplot as plt # Visualization \n",
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import pandas_helper_calc # noqa # provides df.calc.derivative()\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import cv2\n",
|
|||
|
"import pathlib\n",
|
|||
|
"from tqdm.autonotebook import tqdm"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"FPS = 12\n",
"\n",
"# Input tracks: exactly one SRC_CSV is active; alternatives stay commented out\n",
"# SRC_CSV = \"EXPERIMENTS/hofext-maskrcnn/all.txt\"\n",
"# SRC_CSV = \"EXPERIMENTS/raw/generated/train/tracks.txt\"\n",
"# SRC_CSV = \"EXPERIMENTS/raw/hof-meter-maskrcnn2/train/tracks.txt\"\n",
"# SRC_CSV = \"EXPERIMENTS/20240426-hof-yolo/train/tracked.txt\"\n",
"SRC_CSV = \"EXPERIMENTS/raw/hof2/train/tracked.txt\"\n",
"\n",
"# Homography file (currently unset)\n",
"# SRC_H = \"../DATASETS/hof/webcam20231103-2-homography.txt\"\n",
"SRC_H = None\n",
"\n",
"CACHE_DIR = \"EXPERIMENTS/cache/hof2/\"\n",
"\n",
"SMOOTHING = True  # hof-yolo is already smoothed, hof2 isn't\n",
"# NOTE(review): in the visible cells SMOOTHING_WINDOW is only referenced by commented-out smoother code\n",
"SMOOTHING_WINDOW = 3  # previously 2"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Model inputs: position (proj_x, proj_y) plus the (vx, vy) and (ax, ay) components\n",
"in_fields = ['proj_x', 'proj_y', 'vx', 'vy', 'ax', 'ay']\n",
"\n",
"# Model outputs: velocity as (vx, vy) components instead of ['v', 'heading'].\n",
"# Speed cannot be negative and heading is circular (modulo), which is harder to optimise\n",
"# than a linear space; with components a simple MSE loss should be usable.\n",
"out_fields = ['vx', 'vy']\n",
"\n",
"# window length in frames: 1.5 seconds at FPS\n",
"window = int(1.5 * FPS)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"cuda\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Set device: use the GPU when torch reports one as available\n",
"device = \"cpu\"\n",
"if torch.cuda.is_available():\n",
"    device = \"cuda\"\n",
"print(device)\n",
"\n",
"# Hyperparameters (input/output widths follow the field lists)\n",
"input_size = len(in_fields)\n",
"output_size = len(out_fields)\n",
"hidden_size = 256\n",
"num_layers = 3\n",
"learning_rate = 0.005  # also tried: 0.01\n",
"batch_size = 256\n",
"num_epochs = 1000"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# make sure the cache directory exists; exist_ok keeps re-runs of this cell harmless\n",
"cache_path = pathlib.Path(CACHE_DIR)\n",
"cache_path.mkdir(exist_ok=True, parents=True)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Samping 1/5, of 412098 items\n",
|
|||
|
"Done sampling kept 83726 items\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from pathlib import Path\n",
"\n",
"from trap.tools import load_tracks_from_csv\n",
"\n",
"# NOTE(review): the positional arguments after FPS are unnamed here; the printed\n",
"# \"Samping 1/5\" suggests 5 is a sampling step. Confirm both against trap.tools.\n",
"# NOTE: the later pd.read_csv cell reassigns `data`, replacing this result when running all cells.\n",
"tracks_csv = Path(SRC_CSV)\n",
"data = load_tracks_from_csv(tracks_csv, FPS, 2, 5)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>diff</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>dx</th>\n",
|
|||
|
" <th>dy</th>\n",
|
|||
|
" <th>vx</th>\n",
|
|||
|
" <th>vy</th>\n",
|
|||
|
" <th>ax</th>\n",
|
|||
|
" <th>ay</th>\n",
|
|||
|
" <th>v</th>\n",
|
|||
|
" <th>a</th>\n",
|
|||
|
" <th>heading</th>\n",
|
|||
|
" <th>d_heading</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>194</th>\n",
|
|||
|
" <td>606.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1593.885864</td>\n",
|
|||
|
" <td>782.814819</td>\n",
|
|||
|
" <td>145.704346</td>\n",
|
|||
|
" <td>195.380432</td>\n",
|
|||
|
" <td>12.897830</td>\n",
|
|||
|
" <td>10.750061</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.201965</td>\n",
|
|||
|
" <td>-0.291350</td>\n",
|
|||
|
" <td>0.484716</td>\n",
|
|||
|
" <td>-0.699240</td>\n",
|
|||
|
" <td>-1.622919</td>\n",
|
|||
|
" <td>-1.732144</td>\n",
|
|||
|
" <td>0.850815</td>\n",
|
|||
|
" <td>1.399195</td>\n",
|
|||
|
" <td>304.729842</td>\n",
|
|||
|
" <td>-101.772559</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>199</th>\n",
|
|||
|
" <td>611.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1563.890015</td>\n",
|
|||
|
" <td>700.710510</td>\n",
|
|||
|
" <td>137.461304</td>\n",
|
|||
|
" <td>190.194855</td>\n",
|
|||
|
" <td>13.099794</td>\n",
|
|||
|
" <td>10.458712</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.201965</td>\n",
|
|||
|
" <td>-0.291350</td>\n",
|
|||
|
" <td>0.484716</td>\n",
|
|||
|
" <td>-0.699240</td>\n",
|
|||
|
" <td>-1.622919</td>\n",
|
|||
|
" <td>-1.732144</td>\n",
|
|||
|
" <td>0.850815</td>\n",
|
|||
|
" <td>1.399195</td>\n",
|
|||
|
" <td>304.729842</td>\n",
|
|||
|
" <td>-101.772559</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>204</th>\n",
|
|||
|
" <td>616.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1529.469727</td>\n",
|
|||
|
" <td>635.622498</td>\n",
|
|||
|
" <td>129.342651</td>\n",
|
|||
|
" <td>194.191528</td>\n",
|
|||
|
" <td>13.020002</td>\n",
|
|||
|
" <td>9.866642</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.079792</td>\n",
|
|||
|
" <td>-0.592069</td>\n",
|
|||
|
" <td>-0.191501</td>\n",
|
|||
|
" <td>-1.420966</td>\n",
|
|||
|
" <td>-1.622919</td>\n",
|
|||
|
" <td>-1.732144</td>\n",
|
|||
|
" <td>1.433812</td>\n",
|
|||
|
" <td>1.399195</td>\n",
|
|||
|
" <td>262.324609</td>\n",
|
|||
|
" <td>-101.772559</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>209</th>\n",
|
|||
|
" <td>621.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1474.449341</td>\n",
|
|||
|
" <td>569.387634</td>\n",
|
|||
|
" <td>128.099854</td>\n",
|
|||
|
" <td>199.766357</td>\n",
|
|||
|
" <td>12.965776</td>\n",
|
|||
|
" <td>9.301442</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.054226</td>\n",
|
|||
|
" <td>-0.565200</td>\n",
|
|||
|
" <td>-0.130143</td>\n",
|
|||
|
" <td>-1.356479</td>\n",
|
|||
|
" <td>0.147259</td>\n",
|
|||
|
" <td>0.154769</td>\n",
|
|||
|
" <td>1.362708</td>\n",
|
|||
|
" <td>-0.170650</td>\n",
|
|||
|
" <td>264.519715</td>\n",
|
|||
|
" <td>5.268254</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>214</th>\n",
|
|||
|
" <td>626.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1443.123535</td>\n",
|
|||
|
" <td>518.907043</td>\n",
|
|||
|
" <td>120.022461</td>\n",
|
|||
|
" <td>202.566772</td>\n",
|
|||
|
" <td>12.642992</td>\n",
|
|||
|
" <td>8.976624</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.322784</td>\n",
|
|||
|
" <td>-0.324818</td>\n",
|
|||
|
" <td>-0.774681</td>\n",
|
|||
|
" <td>-0.779564</td>\n",
|
|||
|
" <td>-1.546892</td>\n",
|
|||
|
" <td>1.384597</td>\n",
|
|||
|
" <td>1.099023</td>\n",
|
|||
|
" <td>-0.632844</td>\n",
|
|||
|
" <td>225.179993</td>\n",
|
|||
|
" <td>-94.415332</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>219</th>\n",
|
|||
|
" <td>631.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1398.944946</td>\n",
|
|||
|
" <td>461.813049</td>\n",
|
|||
|
" <td>106.391357</td>\n",
|
|||
|
" <td>193.476410</td>\n",
|
|||
|
" <td>12.465588</td>\n",
|
|||
|
" <td>8.557788</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.177404</td>\n",
|
|||
|
" <td>-0.418836</td>\n",
|
|||
|
" <td>-0.425771</td>\n",
|
|||
|
" <td>-1.005205</td>\n",
|
|||
|
" <td>0.837386</td>\n",
|
|||
|
" <td>-0.541539</td>\n",
|
|||
|
" <td>1.091659</td>\n",
|
|||
|
" <td>-0.017675</td>\n",
|
|||
|
" <td>247.044148</td>\n",
|
|||
|
" <td>52.473972</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>224</th>\n",
|
|||
|
" <td>636.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1353.237793</td>\n",
|
|||
|
" <td>438.118896</td>\n",
|
|||
|
" <td>91.444336</td>\n",
|
|||
|
" <td>170.930664</td>\n",
|
|||
|
" <td>12.128433</td>\n",
|
|||
|
" <td>8.052323</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.337155</td>\n",
|
|||
|
" <td>-0.505465</td>\n",
|
|||
|
" <td>-0.809172</td>\n",
|
|||
|
" <td>-1.213117</td>\n",
|
|||
|
" <td>-0.920163</td>\n",
|
|||
|
" <td>-0.498987</td>\n",
|
|||
|
" <td>1.458222</td>\n",
|
|||
|
" <td>0.879752</td>\n",
|
|||
|
" <td>236.295957</td>\n",
|
|||
|
" <td>-25.795658</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>229</th>\n",
|
|||
|
" <td>641.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1272.791992</td>\n",
|
|||
|
" <td>408.827759</td>\n",
|
|||
|
" <td>104.274536</td>\n",
|
|||
|
" <td>180.414551</td>\n",
|
|||
|
" <td>11.689648</td>\n",
|
|||
|
" <td>7.684636</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.438785</td>\n",
|
|||
|
" <td>-0.367687</td>\n",
|
|||
|
" <td>-1.053084</td>\n",
|
|||
|
" <td>-0.882448</td>\n",
|
|||
|
" <td>-0.585388</td>\n",
|
|||
|
" <td>0.793604</td>\n",
|
|||
|
" <td>1.373936</td>\n",
|
|||
|
" <td>-0.202286</td>\n",
|
|||
|
" <td>219.961870</td>\n",
|
|||
|
" <td>-39.201809</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>234</th>\n",
|
|||
|
" <td>646.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1198.965820</td>\n",
|
|||
|
" <td>407.952759</td>\n",
|
|||
|
" <td>103.282104</td>\n",
|
|||
|
" <td>167.306580</td>\n",
|
|||
|
" <td>11.207276</td>\n",
|
|||
|
" <td>7.476216</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.482372</td>\n",
|
|||
|
" <td>-0.208420</td>\n",
|
|||
|
" <td>-1.157693</td>\n",
|
|||
|
" <td>-0.500209</td>\n",
|
|||
|
" <td>-0.251064</td>\n",
|
|||
|
" <td>0.917374</td>\n",
|
|||
|
" <td>1.261136</td>\n",
|
|||
|
" <td>-0.270721</td>\n",
|
|||
|
" <td>203.367915</td>\n",
|
|||
|
" <td>-39.825493</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>239</th>\n",
|
|||
|
" <td>651.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1156.309570</td>\n",
|
|||
|
" <td>415.743408</td>\n",
|
|||
|
" <td>97.628784</td>\n",
|
|||
|
" <td>158.774811</td>\n",
|
|||
|
" <td>10.884154</td>\n",
|
|||
|
" <td>7.514692</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.323122</td>\n",
|
|||
|
" <td>0.038476</td>\n",
|
|||
|
" <td>-0.775493</td>\n",
|
|||
|
" <td>0.092343</td>\n",
|
|||
|
" <td>0.917282</td>\n",
|
|||
|
" <td>1.422125</td>\n",
|
|||
|
" <td>0.780971</td>\n",
|
|||
|
" <td>-1.152395</td>\n",
|
|||
|
" <td>173.209381</td>\n",
|
|||
|
" <td>-72.380481</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>244</th>\n",
|
|||
|
" <td>656.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1094.440430</td>\n",
|
|||
|
" <td>443.849915</td>\n",
|
|||
|
" <td>107.938110</td>\n",
|
|||
|
" <td>177.703979</td>\n",
|
|||
|
" <td>10.544492</td>\n",
|
|||
|
" <td>7.870090</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.339661</td>\n",
|
|||
|
" <td>0.355398</td>\n",
|
|||
|
" <td>-0.815187</td>\n",
|
|||
|
" <td>0.852955</td>\n",
|
|||
|
" <td>-0.095267</td>\n",
|
|||
|
" <td>1.825468</td>\n",
|
|||
|
" <td>1.179857</td>\n",
|
|||
|
" <td>0.957326</td>\n",
|
|||
|
" <td>133.703018</td>\n",
|
|||
|
" <td>-94.815270</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>249</th>\n",
|
|||
|
" <td>661.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1072.595093</td>\n",
|
|||
|
" <td>481.461945</td>\n",
|
|||
|
" <td>118.452148</td>\n",
|
|||
|
" <td>205.365173</td>\n",
|
|||
|
" <td>10.486504</td>\n",
|
|||
|
" <td>8.287758</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.057989</td>\n",
|
|||
|
" <td>0.417668</td>\n",
|
|||
|
" <td>-0.139173</td>\n",
|
|||
|
" <td>1.002404</td>\n",
|
|||
|
" <td>1.622435</td>\n",
|
|||
|
" <td>0.358678</td>\n",
|
|||
|
" <td>1.012019</td>\n",
|
|||
|
" <td>-0.402811</td>\n",
|
|||
|
" <td>97.904355</td>\n",
|
|||
|
" <td>-85.916792</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>254</th>\n",
|
|||
|
" <td>666.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1086.627930</td>\n",
|
|||
|
" <td>526.733154</td>\n",
|
|||
|
" <td>105.444458</td>\n",
|
|||
|
" <td>189.750610</td>\n",
|
|||
|
" <td>10.498393</td>\n",
|
|||
|
" <td>8.684043</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.011889</td>\n",
|
|||
|
" <td>0.396285</td>\n",
|
|||
|
" <td>0.028534</td>\n",
|
|||
|
" <td>0.951083</td>\n",
|
|||
|
" <td>0.402496</td>\n",
|
|||
|
" <td>-0.123170</td>\n",
|
|||
|
" <td>0.951511</td>\n",
|
|||
|
" <td>-0.145220</td>\n",
|
|||
|
" <td>88.281546</td>\n",
|
|||
|
" <td>-23.094741</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>259</th>\n",
|
|||
|
" <td>671.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1099.592285</td>\n",
|
|||
|
" <td>584.216675</td>\n",
|
|||
|
" <td>114.395874</td>\n",
|
|||
|
" <td>218.003479</td>\n",
|
|||
|
" <td>10.492767</td>\n",
|
|||
|
" <td>9.267106</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.005626</td>\n",
|
|||
|
" <td>0.583063</td>\n",
|
|||
|
" <td>-0.013502</td>\n",
|
|||
|
" <td>1.399352</td>\n",
|
|||
|
" <td>-0.100887</td>\n",
|
|||
|
" <td>1.075845</td>\n",
|
|||
|
" <td>1.399417</td>\n",
|
|||
|
" <td>1.074975</td>\n",
|
|||
|
" <td>90.552815</td>\n",
|
|||
|
" <td>5.451045</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>264</th>\n",
|
|||
|
" <td>676.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1144.484782</td>\n",
|
|||
|
" <td>642.779582</td>\n",
|
|||
|
" <td>96.750326</td>\n",
|
|||
|
" <td>180.744690</td>\n",
|
|||
|
" <td>10.484691</td>\n",
|
|||
|
" <td>9.582745</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>-0.008077</td>\n",
|
|||
|
" <td>0.315639</td>\n",
|
|||
|
" <td>-0.019384</td>\n",
|
|||
|
" <td>0.757534</td>\n",
|
|||
|
" <td>-0.014116</td>\n",
|
|||
|
" <td>-1.540364</td>\n",
|
|||
|
" <td>0.757782</td>\n",
|
|||
|
" <td>-1.539925</td>\n",
|
|||
|
" <td>91.465753</td>\n",
|
|||
|
" <td>2.191052</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>269</th>\n",
|
|||
|
" <td>681.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>1179.532959</td>\n",
|
|||
|
" <td>682.365540</td>\n",
|
|||
|
" <td>107.764282</td>\n",
|
|||
|
" <td>200.651733</td>\n",
|
|||
|
" <td>10.698373</td>\n",
|
|||
|
" <td>9.950516</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.213682</td>\n",
|
|||
|
" <td>0.367771</td>\n",
|
|||
|
" <td>0.512837</td>\n",
|
|||
|
" <td>0.882650</td>\n",
|
|||
|
" <td>1.277331</td>\n",
|
|||
|
" <td>0.300278</td>\n",
|
|||
|
" <td>1.020820</td>\n",
|
|||
|
" <td>0.631291</td>\n",
|
|||
|
" <td>59.842534</td>\n",
|
|||
|
" <td>-75.895726</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>16 rows × 24 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" frame_id track_id l t w h \\\n",
|
|||
|
"194 606.0 4 1593.885864 782.814819 145.704346 195.380432 \n",
|
|||
|
"199 611.0 4 1563.890015 700.710510 137.461304 190.194855 \n",
|
|||
|
"204 616.0 4 1529.469727 635.622498 129.342651 194.191528 \n",
|
|||
|
"209 621.0 4 1474.449341 569.387634 128.099854 199.766357 \n",
|
|||
|
"214 626.0 4 1443.123535 518.907043 120.022461 202.566772 \n",
|
|||
|
"219 631.0 4 1398.944946 461.813049 106.391357 193.476410 \n",
|
|||
|
"224 636.0 4 1353.237793 438.118896 91.444336 170.930664 \n",
|
|||
|
"229 641.0 4 1272.791992 408.827759 104.274536 180.414551 \n",
|
|||
|
"234 646.0 4 1198.965820 407.952759 103.282104 167.306580 \n",
|
|||
|
"239 651.0 4 1156.309570 415.743408 97.628784 158.774811 \n",
|
|||
|
"244 656.0 4 1094.440430 443.849915 107.938110 177.703979 \n",
|
|||
|
"249 661.0 4 1072.595093 481.461945 118.452148 205.365173 \n",
|
|||
|
"254 666.0 4 1086.627930 526.733154 105.444458 189.750610 \n",
|
|||
|
"259 671.0 4 1099.592285 584.216675 114.395874 218.003479 \n",
|
|||
|
"264 676.0 4 1144.484782 642.779582 96.750326 180.744690 \n",
|
|||
|
"269 681.0 4 1179.532959 682.365540 107.764282 200.651733 \n",
|
|||
|
"\n",
|
|||
|
" x y state diff ... dx dy vx \\\n",
|
|||
|
"194 12.897830 10.750061 2.0 NaN ... 0.201965 -0.291350 0.484716 \n",
|
|||
|
"199 13.099794 10.458712 1.0 5.0 ... 0.201965 -0.291350 0.484716 \n",
|
|||
|
"204 13.020002 9.866642 2.0 5.0 ... -0.079792 -0.592069 -0.191501 \n",
|
|||
|
"209 12.965776 9.301442 2.0 5.0 ... -0.054226 -0.565200 -0.130143 \n",
|
|||
|
"214 12.642992 8.976624 2.0 5.0 ... -0.322784 -0.324818 -0.774681 \n",
|
|||
|
"219 12.465588 8.557788 2.0 5.0 ... -0.177404 -0.418836 -0.425771 \n",
|
|||
|
"224 12.128433 8.052323 2.0 5.0 ... -0.337155 -0.505465 -0.809172 \n",
|
|||
|
"229 11.689648 7.684636 2.0 5.0 ... -0.438785 -0.367687 -1.053084 \n",
|
|||
|
"234 11.207276 7.476216 2.0 5.0 ... -0.482372 -0.208420 -1.157693 \n",
|
|||
|
"239 10.884154 7.514692 2.0 5.0 ... -0.323122 0.038476 -0.775493 \n",
|
|||
|
"244 10.544492 7.870090 2.0 5.0 ... -0.339661 0.355398 -0.815187 \n",
|
|||
|
"249 10.486504 8.287758 2.0 5.0 ... -0.057989 0.417668 -0.139173 \n",
|
|||
|
"254 10.498393 8.684043 2.0 5.0 ... 0.011889 0.396285 0.028534 \n",
|
|||
|
"259 10.492767 9.267106 2.0 5.0 ... -0.005626 0.583063 -0.013502 \n",
|
|||
|
"264 10.484691 9.582745 1.0 5.0 ... -0.008077 0.315639 -0.019384 \n",
|
|||
|
"269 10.698373 9.950516 2.0 5.0 ... 0.213682 0.367771 0.512837 \n",
|
|||
|
"\n",
|
|||
|
" vy ax ay v a heading d_heading \n",
|
|||
|
"194 -0.699240 -1.622919 -1.732144 0.850815 1.399195 304.729842 -101.772559 \n",
|
|||
|
"199 -0.699240 -1.622919 -1.732144 0.850815 1.399195 304.729842 -101.772559 \n",
|
|||
|
"204 -1.420966 -1.622919 -1.732144 1.433812 1.399195 262.324609 -101.772559 \n",
|
|||
|
"209 -1.356479 0.147259 0.154769 1.362708 -0.170650 264.519715 5.268254 \n",
|
|||
|
"214 -0.779564 -1.546892 1.384597 1.099023 -0.632844 225.179993 -94.415332 \n",
|
|||
|
"219 -1.005205 0.837386 -0.541539 1.091659 -0.017675 247.044148 52.473972 \n",
|
|||
|
"224 -1.213117 -0.920163 -0.498987 1.458222 0.879752 236.295957 -25.795658 \n",
|
|||
|
"229 -0.882448 -0.585388 0.793604 1.373936 -0.202286 219.961870 -39.201809 \n",
|
|||
|
"234 -0.500209 -0.251064 0.917374 1.261136 -0.270721 203.367915 -39.825493 \n",
|
|||
|
"239 0.092343 0.917282 1.422125 0.780971 -1.152395 173.209381 -72.380481 \n",
|
|||
|
"244 0.852955 -0.095267 1.825468 1.179857 0.957326 133.703018 -94.815270 \n",
|
|||
|
"249 1.002404 1.622435 0.358678 1.012019 -0.402811 97.904355 -85.916792 \n",
|
|||
|
"254 0.951083 0.402496 -0.123170 0.951511 -0.145220 88.281546 -23.094741 \n",
|
|||
|
"259 1.399352 -0.100887 1.075845 1.399417 1.074975 90.552815 5.451045 \n",
|
|||
|
"264 0.757534 -0.014116 -1.540364 0.757782 -1.539925 91.465753 2.191052 \n",
|
|||
|
"269 0.882650 1.277331 0.300278 1.020820 0.631291 59.842534 -75.895726 \n",
|
|||
|
"\n",
|
|||
|
"[16 rows x 24 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th rowspan=\"5\" valign=\"top\">1</th>\n",
|
|||
|
" <th>342</th>\n",
|
|||
|
" <td>1393.736572</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>67.613647</td>\n",
|
|||
|
" <td>121.391151</td>\n",
|
|||
|
" <td>1363.3164</td>\n",
|
|||
|
" <td>232.92647</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>343</th>\n",
|
|||
|
" <td>1391.775879</td>\n",
|
|||
|
" <td>0.852371</td>\n",
|
|||
|
" <td>78.562622</td>\n",
|
|||
|
" <td>141.050934</td>\n",
|
|||
|
" <td>1359.1885</td>\n",
|
|||
|
" <td>266.06586</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>346</th>\n",
|
|||
|
" <td>1392.164551</td>\n",
|
|||
|
" <td>7.758987</td>\n",
|
|||
|
" <td>85.757324</td>\n",
|
|||
|
" <td>154.357971</td>\n",
|
|||
|
" <td>1355.7444</td>\n",
|
|||
|
" <td>297.67404</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>347</th>\n",
|
|||
|
" <td>1393.844849</td>\n",
|
|||
|
" <td>12.691238</td>\n",
|
|||
|
" <td>86.482910</td>\n",
|
|||
|
" <td>156.264786</td>\n",
|
|||
|
" <td>1355.2312</td>\n",
|
|||
|
" <td>308.20670</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>348</th>\n",
|
|||
|
" <td>1394.839111</td>\n",
|
|||
|
" <td>15.621338</td>\n",
|
|||
|
" <td>84.763428</td>\n",
|
|||
|
" <td>154.584396</td>\n",
|
|||
|
" <td>1354.9246</td>\n",
|
|||
|
" <td>310.09225</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th rowspan=\"5\" valign=\"top\">5030</th>\n",
|
|||
|
" <th>32691</th>\n",
|
|||
|
" <td>1708.213379</td>\n",
|
|||
|
" <td>749.260376</td>\n",
|
|||
|
" <td>133.839966</td>\n",
|
|||
|
" <td>182.405396</td>\n",
|
|||
|
" <td>1402.5426</td>\n",
|
|||
|
" <td>1075.20870</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>32692</th>\n",
|
|||
|
" <td>1707.651855</td>\n",
|
|||
|
" <td>748.997437</td>\n",
|
|||
|
" <td>134.013672</td>\n",
|
|||
|
" <td>182.391296</td>\n",
|
|||
|
" <td>1402.2948</td>\n",
|
|||
|
" <td>1074.97230</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>32720</th>\n",
|
|||
|
" <td>1700.379639</td>\n",
|
|||
|
" <td>750.314697</td>\n",
|
|||
|
" <td>128.792603</td>\n",
|
|||
|
" <td>181.589783</td>\n",
|
|||
|
" <td>1395.7992</td>\n",
|
|||
|
" <td>1074.27320</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>32721</th>\n",
|
|||
|
" <td>1701.722412</td>\n",
|
|||
|
" <td>751.000488</td>\n",
|
|||
|
" <td>125.286865</td>\n",
|
|||
|
" <td>180.867615</td>\n",
|
|||
|
" <td>1395.5424</td>\n",
|
|||
|
" <td>1074.20560</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>32722</th>\n",
|
|||
|
" <td>1702.384766</td>\n",
|
|||
|
" <td>750.754517</td>\n",
|
|||
|
" <td>123.435425</td>\n",
|
|||
|
" <td>180.945618</td>\n",
|
|||
|
" <td>1395.4082</td>\n",
|
|||
|
" <td>1074.06500</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>326960 rows × 7 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" l t w h x \\\n",
|
|||
|
"track_id frame_id \n",
|
|||
|
"1 342 1393.736572 0.000000 67.613647 121.391151 1363.3164 \n",
|
|||
|
" 343 1391.775879 0.852371 78.562622 141.050934 1359.1885 \n",
|
|||
|
" 346 1392.164551 7.758987 85.757324 154.357971 1355.7444 \n",
|
|||
|
" 347 1393.844849 12.691238 86.482910 156.264786 1355.2312 \n",
|
|||
|
" 348 1394.839111 15.621338 84.763428 154.584396 1354.9246 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"5030 32691 1708.213379 749.260376 133.839966 182.405396 1402.5426 \n",
|
|||
|
" 32692 1707.651855 748.997437 134.013672 182.391296 1402.2948 \n",
|
|||
|
" 32720 1700.379639 750.314697 128.792603 181.589783 1395.7992 \n",
|
|||
|
" 32721 1701.722412 751.000488 125.286865 180.867615 1395.5424 \n",
|
|||
|
" 32722 1702.384766 750.754517 123.435425 180.945618 1395.4082 \n",
|
|||
|
"\n",
|
|||
|
" y state \n",
|
|||
|
"track_id frame_id \n",
|
|||
|
"1 342 232.92647 2 \n",
|
|||
|
" 343 266.06586 2 \n",
|
|||
|
" 346 297.67404 2 \n",
|
|||
|
" 347 308.20670 2 \n",
|
|||
|
" 348 310.09225 2 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"5030 32691 1075.20870 2 \n",
|
|||
|
" 32692 1074.97230 2 \n",
|
|||
|
" 32720 1074.27320 2 \n",
|
|||
|
" 32721 1074.20560 2 \n",
|
|||
|
" 32722 1074.06500 2 \n",
|
|||
|
"\n",
|
|||
|
"[326960 rows x 7 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Legacy loading path: re-reads SRC_CSV and overwrites the `data` produced by load_tracks_from_csv.\n",
"data = pd.read_csv(SRC_CSV, delimiter=\"\\t\", index_col=False, header=None)\n",
"# data.columns = ['frame_id', 'track_id', 'pos_x', 'pos_y', 'width', 'height']\n",
"data.columns = ['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state']\n",
"data['frame_id'] = pd.to_numeric(data['frame_id'], downcast='integer')\n",
"data['frame_id'] = data['frame_id'] // 10  # compatibility with Trajectron++\n",
"\n",
"# order rows per track in time; the original row labels are kept\n",
"data = data.sort_values(by=['track_id', 'frame_id'])\n",
"\n",
"# preview only (not assigned): later cells still use frame_id/track_id as columns\n",
"data.set_index(['track_id', 'frame_id']).head()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# cm to meter: scale both coordinate columns at once\n",
"# NOTE: not idempotent; running this cell twice divides by 100 again\n",
"data[['x', 'y']] = data[['x', 'y']] / 100"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# frame distance to the previous row of the same track; NaN on each track's first row,\n",
"# so downcast='integer' cannot apply and the column stays float\n",
"frame_steps = data.groupby(['track_id'])['frame_id'].diff()\n",
"data['diff'] = pd.to_numeric(frame_steps, downcast='integer')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"326960it [06:37, 821.55it/s] "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"was: 326960 added: 85138 new length: 412098\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Insert placeholder rows for frames absent from a track so the next cell can interpolate them.\n",
"# Vectorised: builds all new rows at once instead of growing the frame row by row.\n",
"old_size = len(data)\n",
"# gap to the previous row of the same track, recomputed here so re-running the cell adds nothing new\n",
"frame_gap = data.groupby('track_id')['frame_id'].diff()\n",
"has_gap = frame_gap > 1  # NaN (first row of a track) compares False\n",
"gap_rows = data[has_gap]\n",
"gap_sizes = (frame_gap[has_gap] - 1).astype(int).to_numpy()\n",
"missing = int(gap_sizes.sum())\n",
"\n",
"# a gap of k frames before frame f yields frames f-1, f-2, ..., f-k\n",
"offsets = np.concatenate([np.arange(1, k + 1) for k in gap_sizes]) if missing else np.zeros(0, dtype=int)\n",
"filler = pd.DataFrame(\n",
"    {\n",
"        'frame_id': np.repeat(gap_rows['frame_id'].to_numpy(), gap_sizes) - offsets,\n",
"        'track_id': np.repeat(gap_rows['track_id'].to_numpy(), gap_sizes),\n",
"    },\n",
"    index=pd.RangeIndex(old_size, old_size + missing),  # new labels continue after the existing ones\n",
").reindex(columns=data.columns)  # l, t, w, h, x, y stay NaN for interpolation\n",
"filler['state'] = 1\n",
"filler['diff'] = 1\n",
"data = pd.concat([data, filler]) if missing else data\n",
"\n",
"print('was:', old_size, 'added:', missing, 'new length:', len(data))\n",
"# now sort, so that the added data is in the right place\n",
"data = data.sort_values(by=['track_id', 'frame_id'])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# interpolate missing data: fill the NaN placeholder rows track by track\n",
"df = (\n",
"    data.copy()\n",
"    .groupby('track_id')\n",
"    .apply(lambda track: track.interpolate(method='linear'))\n",
"    .reset_index(drop=True)\n",
")\n",
"data = df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Running smoother\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from trap.tracker import Smoother\n",
"\n",
"if SMOOTHING:\n",
"    df = data.copy()\n",
"    # keep the unsmoothed coordinates; on a re-run the existing *_raw columns are reused,\n",
"    # so smoothing always starts from the raw values rather than already-smoothed ones\n",
"    if 'x_raw' not in df:\n",
"        df['x_raw'] = df['x']\n",
"    if 'y_raw' not in df:\n",
"        df['y_raw'] = df['y']\n",
"\n",
"    print(\"Running smoother\")\n",
"    # alternative tried before: tsmoothie ConvolutionSmoother(window_len=SMOOTHING_WINDOW, window_type='ones')\n",
"    smoother = Smoother(convolution=False)\n",
"\n",
"    def smoothing(series):\n",
"        # return the smoothed values of one track's coordinate series as a list\n",
"        return smoother.smooth(series).tolist()\n",
"\n",
"    # operate smoothing per axis, per track\n",
"    df['x'] = df.groupby('track_id')['x_raw'].transform(smoothing)\n",
"    df['y'] = df.groupby('track_id')['y_raw'].transform(smoothing)\n",
"\n",
"    data = df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>diff</th>\n",
|
|||
|
" <th>x_raw</th>\n",
|
|||
|
" <th>y_raw</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.566162</td>\n",
|
|||
|
" <td>88.795326</td>\n",
|
|||
|
" <td>173.917542</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0.881595</td>\n",
|
|||
|
" <td>7.341152</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.116699</td>\n",
|
|||
|
" <td>88.801704</td>\n",
|
|||
|
" <td>171.334290</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.870703</td>\n",
|
|||
|
" <td>7.309168</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>11.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874573</td>\n",
|
|||
|
" <td>90.596596</td>\n",
|
|||
|
" <td>177.199951</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.901374</td>\n",
|
|||
|
" <td>7.370044</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874268</td>\n",
|
|||
|
" <td>90.928131</td>\n",
|
|||
|
" <td>183.125732</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.924360</td>\n",
|
|||
|
" <td>7.432365</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>13.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>569.931213</td>\n",
|
|||
|
" <td>86.213280</td>\n",
|
|||
|
" <td>180.774292</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.906583</td>\n",
|
|||
|
" <td>7.456334</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320183</th>\n",
|
|||
|
" <td>60159.0</td>\n",
|
|||
|
" <td>3632.0</td>\n",
|
|||
|
" <td>1830.709717</td>\n",
|
|||
|
" <td>651.257446</td>\n",
|
|||
|
" <td>150.202515</td>\n",
|
|||
|
" <td>157.239746</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>15.214551</td>\n",
|
|||
|
" <td>10.027093</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320184</th>\n",
|
|||
|
" <td>60160.0</td>\n",
|
|||
|
" <td>3632.0</td>\n",
|
|||
|
" <td>1834.013672</td>\n",
|
|||
|
" <td>649.612122</td>\n",
|
|||
|
" <td>153.686646</td>\n",
|
|||
|
" <td>160.874023</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.244872</td>\n",
|
|||
|
" <td>10.047117</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320185</th>\n",
|
|||
|
" <td>60161.0</td>\n",
|
|||
|
" <td>3632.0</td>\n",
|
|||
|
" <td>1845.373047</td>\n",
|
|||
|
" <td>651.249756</td>\n",
|
|||
|
" <td>147.178589</td>\n",
|
|||
|
" <td>153.729248</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.318496</td>\n",
|
|||
|
" <td>10.015218</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320186</th>\n",
|
|||
|
" <td>60162.0</td>\n",
|
|||
|
" <td>3632.0</td>\n",
|
|||
|
" <td>1857.388916</td>\n",
|
|||
|
" <td>650.908203</td>\n",
|
|||
|
" <td>136.407349</td>\n",
|
|||
|
" <td>142.354614</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.400203</td>\n",
|
|||
|
" <td>9.935355</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320187</th>\n",
|
|||
|
" <td>60163.0</td>\n",
|
|||
|
" <td>3632.0</td>\n",
|
|||
|
" <td>1862.792725</td>\n",
|
|||
|
" <td>658.719971</td>\n",
|
|||
|
" <td>141.984253</td>\n",
|
|||
|
" <td>149.052307</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.416893</td>\n",
|
|||
|
" <td>10.051785</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>320188 rows × 12 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" frame_id track_id l t w h \\\n",
|
|||
|
"0 9.0 1.0 0.000000 565.566162 88.795326 173.917542 \n",
|
|||
|
"1 10.0 1.0 0.000000 565.116699 88.801704 171.334290 \n",
|
|||
|
"2 11.0 1.0 0.000000 564.874573 90.596596 177.199951 \n",
|
|||
|
"3 12.0 1.0 0.000000 564.874268 90.928131 183.125732 \n",
|
|||
|
"4 13.0 1.0 0.000000 569.931213 86.213280 180.774292 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 60159.0 3632.0 1830.709717 651.257446 150.202515 157.239746 \n",
|
|||
|
"320184 60160.0 3632.0 1834.013672 649.612122 153.686646 160.874023 \n",
|
|||
|
"320185 60161.0 3632.0 1845.373047 651.249756 147.178589 153.729248 \n",
|
|||
|
"320186 60162.0 3632.0 1857.388916 650.908203 136.407349 142.354614 \n",
|
|||
|
"320187 60163.0 3632.0 1862.792725 658.719971 141.984253 149.052307 \n",
|
|||
|
"\n",
|
|||
|
" x y state diff x_raw y_raw \n",
|
|||
|
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 \n",
|
|||
|
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 \n",
|
|||
|
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 \n",
|
|||
|
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 \n",
|
|||
|
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 \n",
|
|||
|
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 \n",
|
|||
|
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 \n",
|
|||
|
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 \n",
|
|||
|
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 \n",
|
|||
|
"\n",
|
|||
|
"[320188 rows x 12 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# del data['diff']\n",
|
|||
|
"# recalculate diff\n",
|
|||
|
"data['diff'] = data.groupby(['track_id'])['frame_id'].diff()\n",
|
|||
|
"data"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>diff</th>\n",
|
|||
|
" <th>x_raw</th>\n",
|
|||
|
" <th>y_raw</th>\n",
|
|||
|
" <th>dt</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.566162</td>\n",
|
|||
|
" <td>88.795326</td>\n",
|
|||
|
" <td>173.917542</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0.881595</td>\n",
|
|||
|
" <td>7.341152</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.116699</td>\n",
|
|||
|
" <td>88.801704</td>\n",
|
|||
|
" <td>171.334290</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.870703</td>\n",
|
|||
|
" <td>7.309168</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>11.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874573</td>\n",
|
|||
|
" <td>90.596596</td>\n",
|
|||
|
" <td>177.199951</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.901374</td>\n",
|
|||
|
" <td>7.370044</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874268</td>\n",
|
|||
|
" <td>90.928131</td>\n",
|
|||
|
" <td>183.125732</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.924360</td>\n",
|
|||
|
" <td>7.432365</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>13.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>569.931213</td>\n",
|
|||
|
" <td>86.213280</td>\n",
|
|||
|
" <td>180.774292</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.906583</td>\n",
|
|||
|
" <td>7.456334</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320183</th>\n",
|
|||
|
" <td>60159.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1830.709717</td>\n",
|
|||
|
" <td>651.257446</td>\n",
|
|||
|
" <td>150.202515</td>\n",
|
|||
|
" <td>157.239746</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>15.214551</td>\n",
|
|||
|
" <td>10.027093</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320184</th>\n",
|
|||
|
" <td>60160.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1834.013672</td>\n",
|
|||
|
" <td>649.612122</td>\n",
|
|||
|
" <td>153.686646</td>\n",
|
|||
|
" <td>160.874023</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.244872</td>\n",
|
|||
|
" <td>10.047117</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320185</th>\n",
|
|||
|
" <td>60161.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1845.373047</td>\n",
|
|||
|
" <td>651.249756</td>\n",
|
|||
|
" <td>147.178589</td>\n",
|
|||
|
" <td>153.729248</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.318496</td>\n",
|
|||
|
" <td>10.015218</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320186</th>\n",
|
|||
|
" <td>60162.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1857.388916</td>\n",
|
|||
|
" <td>650.908203</td>\n",
|
|||
|
" <td>136.407349</td>\n",
|
|||
|
" <td>142.354614</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.400203</td>\n",
|
|||
|
" <td>9.935355</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320187</th>\n",
|
|||
|
" <td>60163.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1862.792725</td>\n",
|
|||
|
" <td>658.719971</td>\n",
|
|||
|
" <td>141.984253</td>\n",
|
|||
|
" <td>149.052307</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.416893</td>\n",
|
|||
|
" <td>10.051785</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>320188 rows × 13 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" frame_id track_id l t w h \\\n",
|
|||
|
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
|||
|
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
|||
|
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
|||
|
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
|||
|
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
|||
|
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
|||
|
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
|||
|
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
|||
|
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
|||
|
"\n",
|
|||
|
" x y state diff x_raw y_raw dt \n",
|
|||
|
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 NaN \n",
|
|||
|
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 0.083333 \n",
|
|||
|
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 0.083333 \n",
|
|||
|
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 0.083333 \n",
|
|||
|
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 0.083333 \n",
|
|||
|
"... ... ... ... ... ... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 NaN \n",
|
|||
|
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 0.083333 \n",
|
|||
|
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 0.083333 \n",
|
|||
|
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 0.083333 \n",
|
|||
|
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 0.083333 \n",
|
|||
|
"\n",
|
|||
|
"[320188 rows x 13 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"\n",
|
|||
|
"# data['node_type'] = 'PEDESTRIAN' # compatibility with Trajectron++\n",
|
|||
|
"# data['node_id'] = data['track_id'].astype(str)\n",
|
|||
|
"data['track_id'] = pd.to_numeric(data['track_id'], downcast='integer')\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"data['dt'] = data['diff'] * (1/FPS)\n",
|
|||
|
"data"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# position into an average coordinate system. (DO THESE NEED TO BE STORED?)\n",
|
|||
|
"# Don't do this, messes up\n",
|
|||
|
"# data['pos_x'] = data['pos_x'] - data['pos_x'].mean()\n",
|
|||
|
"# data['pos_y'] = data['pos_y'] - data['pos_y'].mean()\n",
|
|||
|
" \n",
|
|||
|
"# data"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<AxesSubplot:>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAGdCAYAAAD+JxxnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxGUlEQVR4nO3dfVTUdd7/8RcgDFIO3gXIJRplpZRK4UbTrTfIqJxObm5rN8fIvDm60FnkXFqU4V277uXmXYVx2lLaU5a6p9xSL2TCVStHTZQrtfTqxi53Tw12o6KYwwjz+6Mf35wwZVydCT7Pxzmcs/P9vuc773nzYX31ne8XIvx+v18AAAAGigx3AwAAAOFCEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGKtduBv4JWtsbNSXX36pDh06KCIiItztAACAFvD7/Tp27JiSk5MVGXn2cz4EobP48ssvlZKSEu42AADAefjnP/+p7t27n7WGIHQWHTp0kPTDIO12e5i7CT+fz6eKigplZ2crOjo63O20Wcw5NJhz6DDr0GDOP6qtrVVKSor17/jZEITOounjMLvdThDSDz9kcXFxstvtxv+QXUzMOTSYc+gw69Bgzs215LIWLpYGAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMFa7cDcAAOF03cz18jZEhLuNFvviTznhbgFoUzgjBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADBWUEHo+eefV79+/WS322W32+VwOPTf//3f1v6TJ08qLy9PXbp00aWXXqpRo0appqYm4BgHDx5UTk6O4uLilJCQoKlTp+rUqVMBNRs3btQNN9wgm82mXr16qaysrFkvJSUluvzyyxUbG6vMzExt3749YH9LegEAAGYLKgh1795df/rTn1RVVaUdO3Zo8ODBuuuuu7R3715J0pQpU/T2229r1apV2rRpk7788kvdfffd1vMbGhqUk5Oj+vp6bdmyRS+//LLKyspUXFxs1Rw4cEA5OTkaNGiQqqurVVBQoPHjx2v9+vVWzYoVK1RYWKgZM2Zo586d6t+/v5xOpw4dOmTVnKsXAACAoILQnXfeqREjRuiqq67S1VdfrT/84Q+69NJLtXXrVh09elQvvfSSFixYoMGDBysjI0PLli3Tli1btHXrVklSRUWFPvroI73yyitKT0/X8OHDNWfOHJWUlKi+vl6SVFpaqtTUVM2fP199+vRRfn6+fvOb32jhwoVWHwsWLNCECRM0duxYpaWlqbS0VHFxcVq6dKkktagXAACAduf7xIaGBq1atUp1dXVyOByqqqqSz+dTVlaWVdO7d2/16NFDbrdbN910k9xut/r27avExESrxul0avLkydq7d6+uv/56ud3ugGM01RQUFEiS6uvrVVVVpaKiImt/ZGSksrKy5Ha7JalFvZyJ1+uV1+u1HtfW1kqSfD6ffD7feU6q7WiaAbO4uJhzaDTN1xbpD3MnwWmN64I1HRrM+UfBzCDoILR79245HA6dPHlSl156qd58802lpaWpurpaMTEx6tixY0B9YmKiPB6PJMnj8QSEoKb9TfvOVlNbW6vvv/9ehw8fVkNDwxlr9u3bZx3jXL2cydy5czVr1qxm2ysqKhQXF/ezzzONy+UKdwtGYM6hMWdAY7hbCMq6devC3cJ5Y02HBnOWTpw40eLaoIPQNddco+rqah09elR/+9vflJubq02bNgV7mF+koqIiFRYWWo9ra2uVkpKi7Oxs2e32MHb2y+Dz+eRyuTR06FBFR0eHu502izmHRtOcn9wRKW9jRLjbabE9M53hbiForO
nQYM4/avpEpyWCDkIxMTHq1auXJCkjI0MffPCBFi9erNGjR6u+vl5HjhwJOBNTU1OjpKQkSVJSUlKzu7ua7uQ6veand3fV1NTIbrerffv2ioqKUlRU1BlrTj/GuXo5E5vNJpvN1mx7dHS08YvqdMwjNJhzaHgbI+RtaD1BqDWvCdZ0aDDn4H5O/u3fI9TY2Civ16uMjAxFR0ersrLS2rd//34dPHhQDodDkuRwOLR79+6Au7tcLpfsdrvS0tKsmtOP0VTTdIyYmBhlZGQE1DQ2NqqystKqaUkvAAAAQZ0RKioq0vDhw9WjRw8dO3ZMy5cv18aNG7V+/XrFx8dr3LhxKiwsVOfOnWW32/XII4/I4XBYFydnZ2crLS1NY8aM0bx58+TxeDR9+nTl5eVZZ2ImTZqk5557TtOmTdPDDz+sDRs2aOXKlVq7dq3VR2FhoXJzczVgwADdeOONWrRokerq6jR27FhJalEvAAAAQQWhQ4cO6cEHH9RXX32l+Ph49evXT+vXr9fQoUMlSQsXLlRkZKRGjRolr9crp9OpJUuWWM+PiorSmjVrNHnyZDkcDl1yySXKzc3V7NmzrZrU1FStXbtWU6ZM0eLFi9W9e3e9+OKLcjp//Fx89OjR+vrrr1VcXCyPx6P09HSVl5cHXEB9rl4AAAAi/H5/67p3NIRqa2sVHx+vo0ePcrG0frgQb926dRoxYoTxnz9fTMw5NJrmPG17VKu6RuiLP+WEu4WgsaZDgzn/KJh/v/lbYwAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABiLIAQAAIxFEAIAAMYiCAEAAGMRhAAAgLEIQgAAwFgEIQAAYCyCEAAAMBZBCAAAGIsgBAAAjEUQAgAAxiIIAQAAYxGEAACAsQhCAADAWAQhAABgLIIQAAAwFkEIAAAYiyAEAACMRRACAADGIggBAABjEYQAAICxCEIAAMBYBCEAAGAsghAAADAWQQgAABgrqCA0d+5c/epXv1KHDh2UkJCgkSNHav/+/QE1AwcOVERERMDXpEmTAmoOHjyonJwcxcXFKSEhQVOnTtWpU6cCajZu3KgbbrhBNptNvXr1UllZWbN+SkpKdPnllys2NlaZmZnavn17wP6TJ08qLy9PXbp00aWXXqpRo0appqYmmLcMAADasKCC0KZNm5SXl6etW7fK5XLJ5/MpOztbdXV1AXUTJkzQV199ZX3NmzfP2tfQ0KCcnBzV19dry5Ytevnll1VWVqbi4mKr5sCBA8rJydGgQYNUXV2tgoICjR8/XuvXr7dqVqxYocLCQs2YMUM7d+5U//795XQ6dejQIatmypQpevvtt7Vq1Spt2rRJX375pe6+++6ghwQAANqmdsEUl5eXBzwuKytTQkKCqqqqdPvtt1vb4+LilJSUdMZjVFRU6KOPPtI777yjxMREpaena86cOXr00Uc1c+ZMxcTEqLS0VKmpqZo/f74kqU+fPnrvvfe0cOFCOZ1OSdKCBQs0YcIEjR07VpJUWlqqtWvXaunSpXrsscd09OhRvfTSS1q+fLkGDx4sSVq2bJn69OmjrVu36qabbgrmrQMAgDYoqCD0U0ePHpUkde7cOWD7q6++qldeeUVJSUm688479eSTTyouLk6S5Ha71bdvXyUmJlr1TqdTkydP1t69e3X99dfL7XYrKysr4JhOp1MFBQWSpPr6elVVVamoqMjaHxkZqaysLLndbklSVVWVfD5fwHF69+6tHj16yO12nzEIeb1eeb1e63Ftba0kyefzyefzBT2ftqZpBszi4mLOodE0X1ukP8ydBKc1rgvWdGgw5x8FM4PzDkKNjY0qKCjQLbfcouuuu87afv/996tnz55KTk7Whx9+qEcffVT79+/XG2
+8IUnyeDwBIUiS9djj8Zy1pra2Vt9//70OHz6shoaGM9bs27fPOkZMTIw6duzYrKbpdX5q7ty5mjVrVrPtFRUVVpC
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"data['diff'].hist()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The dataset is a bit crappy because it has an inconsistent frame step: predominantly 1 or 2, but sometimes 3 or 4 as well. This inevitably leads to differences in speed calculations."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"if SRC_H is not None:\n",
|
|||
|
" H = np.loadtxt(SRC_H, delimiter=',')\n",
|
|||
|
"else:\n",
|
|||
|
" H = None"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"No H given, probably already projected data?\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>diff</th>\n",
|
|||
|
" <th>x_raw</th>\n",
|
|||
|
" <th>y_raw</th>\n",
|
|||
|
" <th>dt</th>\n",
|
|||
|
" <th>proj_x</th>\n",
|
|||
|
" <th>proj_y</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.566162</td>\n",
|
|||
|
" <td>88.795326</td>\n",
|
|||
|
" <td>173.917542</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0.881595</td>\n",
|
|||
|
" <td>7.341152</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.116699</td>\n",
|
|||
|
" <td>88.801704</td>\n",
|
|||
|
" <td>171.334290</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.870703</td>\n",
|
|||
|
" <td>7.309168</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>11.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874573</td>\n",
|
|||
|
" <td>90.596596</td>\n",
|
|||
|
" <td>177.199951</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.901374</td>\n",
|
|||
|
" <td>7.370044</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874268</td>\n",
|
|||
|
" <td>90.928131</td>\n",
|
|||
|
" <td>183.125732</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.924360</td>\n",
|
|||
|
" <td>7.432365</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>13.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>569.931213</td>\n",
|
|||
|
" <td>86.213280</td>\n",
|
|||
|
" <td>180.774292</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.906583</td>\n",
|
|||
|
" <td>7.456334</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320183</th>\n",
|
|||
|
" <td>60159.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1830.709717</td>\n",
|
|||
|
" <td>651.257446</td>\n",
|
|||
|
" <td>150.202515</td>\n",
|
|||
|
" <td>157.239746</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>15.214551</td>\n",
|
|||
|
" <td>10.027093</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320184</th>\n",
|
|||
|
" <td>60160.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1834.013672</td>\n",
|
|||
|
" <td>649.612122</td>\n",
|
|||
|
" <td>153.686646</td>\n",
|
|||
|
" <td>160.874023</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.244872</td>\n",
|
|||
|
" <td>10.047117</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320185</th>\n",
|
|||
|
" <td>60161.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1845.373047</td>\n",
|
|||
|
" <td>651.249756</td>\n",
|
|||
|
" <td>147.178589</td>\n",
|
|||
|
" <td>153.729248</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.318496</td>\n",
|
|||
|
" <td>10.015218</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320186</th>\n",
|
|||
|
" <td>60162.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1857.388916</td>\n",
|
|||
|
" <td>650.908203</td>\n",
|
|||
|
" <td>136.407349</td>\n",
|
|||
|
" <td>142.354614</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.400203</td>\n",
|
|||
|
" <td>9.935355</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320187</th>\n",
|
|||
|
" <td>60163.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1862.792725</td>\n",
|
|||
|
" <td>658.719971</td>\n",
|
|||
|
" <td>141.984253</td>\n",
|
|||
|
" <td>149.052307</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.416893</td>\n",
|
|||
|
" <td>10.051785</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>320188 rows × 15 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" frame_id track_id l t w h \\\n",
|
|||
|
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
|||
|
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
|||
|
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
|||
|
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
|||
|
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
|||
|
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
|||
|
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
|||
|
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
|||
|
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
|||
|
"\n",
|
|||
|
" x y state diff x_raw y_raw dt \\\n",
|
|||
|
"0 0.855100 7.136193 2.0 NaN 0.881595 7.341152 NaN \n",
|
|||
|
"1 0.873132 7.235233 2.0 1.0 0.870703 7.309168 0.083333 \n",
|
|||
|
"2 0.890957 7.328989 2.0 1.0 0.901374 7.370044 0.083333 \n",
|
|||
|
"3 0.907784 7.418187 2.0 1.0 0.924360 7.432365 0.083333 \n",
|
|||
|
"4 0.923439 7.505012 2.0 1.0 0.906583 7.456334 0.083333 \n",
|
|||
|
"... ... ... ... ... ... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 2.0 NaN 15.214551 10.027093 NaN \n",
|
|||
|
"320184 15.033432 9.870472 2.0 1.0 15.244872 10.047117 0.083333 \n",
|
|||
|
"320185 15.211560 9.943236 2.0 1.0 15.318496 10.015218 0.083333 \n",
|
|||
|
"320186 15.377673 10.008965 2.0 1.0 15.400203 9.935355 0.083333 \n",
|
|||
|
"320187 15.538255 10.075935 2.0 1.0 15.416893 10.051785 0.083333 \n",
|
|||
|
"\n",
|
|||
|
" proj_x proj_y \n",
|
|||
|
"0 0.855100 7.136193 \n",
|
|||
|
"1 0.873132 7.235233 \n",
|
|||
|
"2 0.890957 7.328989 \n",
|
|||
|
"3 0.907784 7.418187 \n",
|
|||
|
"4 0.923439 7.505012 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 \n",
|
|||
|
"320184 15.033432 9.870472 \n",
|
|||
|
"320185 15.211560 9.943236 \n",
|
|||
|
"320186 15.377673 10.008965 \n",
|
|||
|
"320187 15.538255 10.075935 \n",
|
|||
|
"\n",
|
|||
|
"[320188 rows x 15 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
"# Bring the track coordinates into a common ground-plane frame (proj_x / proj_y).\n",
"if H is None:\n",
"    print(\"No H given, probably already projected data?\")\n",
"    # nothing to project or centre: reuse the tracker coordinates with a zero offset\n",
"    mean_x = 0\n",
"    mean_y = 0\n",
"    data['proj_x'] = data['x']\n",
"    data['proj_y'] = data['y']\n",
"else:\n",
"    print(\"Projecting data\")\n",
"    # NOTE(review): this branch reads pos_x/pos_y/width/height, while the tracks loaded in\n",
"    # this notebook expose l/t/w/h columns. Also, if pos_y is the top edge of the box,\n",
"    # pos_y + 0.5 * height is the box centre rather than the feet. Confirm before using SRC_H.\n",
"    data['foot_x'] = data['pos_x'] + 0.5 * data['width']\n",
"    data['foot_y'] = data['pos_y'] + 0.5 * data['height']\n",
"\n",
"    # map the image points onto the ground plane through the homography H\n",
"    transformed = cv2.perspectiveTransform(np.array([data[['foot_x', 'foot_y']].to_numpy()]), H)[0]\n",
"    # the homography output is in centimetres; convert to metres\n",
"    data['proj_x'] = transformed[:, 0] / 100\n",
"    data['proj_y'] = transformed[:, 1] / 100\n",
"\n",
"    # Centre the projected coordinates on their mean.\n",
"    # These offsets must be stored and reused in a live setting.\n",
"    mean_x = data['proj_x'].mean()\n",
"    mean_y = data['proj_y'].mean()\n",
"    data['proj_x'] = data['proj_x'] - mean_x\n",
"    data['proj_y'] = data['proj_y'] - mean_y\n",
"\n",
"data"
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Deriving displacement, velocity and accelation from x and y\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>diff</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>y_raw</th>\n",
|
|||
|
" <th>dt</th>\n",
|
|||
|
" <th>proj_x</th>\n",
|
|||
|
" <th>proj_y</th>\n",
|
|||
|
" <th>dx</th>\n",
|
|||
|
" <th>dy</th>\n",
|
|||
|
" <th>vx</th>\n",
|
|||
|
" <th>vy</th>\n",
|
|||
|
" <th>ax</th>\n",
|
|||
|
" <th>ay</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.566162</td>\n",
|
|||
|
" <td>88.795326</td>\n",
|
|||
|
" <td>173.917542</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>7.341152</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.116699</td>\n",
|
|||
|
" <td>88.801704</td>\n",
|
|||
|
" <td>171.334290</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>7.309168</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" <td>0.018032</td>\n",
|
|||
|
" <td>0.099039</td>\n",
|
|||
|
" <td>0.216383</td>\n",
|
|||
|
" <td>1.188473</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>11.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874573</td>\n",
|
|||
|
" <td>90.596596</td>\n",
|
|||
|
" <td>177.199951</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>7.370044</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" <td>0.017825</td>\n",
|
|||
|
" <td>0.093756</td>\n",
|
|||
|
" <td>0.213899</td>\n",
|
|||
|
" <td>1.125077</td>\n",
|
|||
|
" <td>-0.029812</td>\n",
|
|||
|
" <td>-0.760753</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874268</td>\n",
|
|||
|
" <td>90.928131</td>\n",
|
|||
|
" <td>183.125732</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>7.432365</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" <td>0.016827</td>\n",
|
|||
|
" <td>0.089198</td>\n",
|
|||
|
" <td>0.201924</td>\n",
|
|||
|
" <td>1.070371</td>\n",
|
|||
|
" <td>-0.143699</td>\n",
|
|||
|
" <td>-0.656466</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>13.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>569.931213</td>\n",
|
|||
|
" <td>86.213280</td>\n",
|
|||
|
" <td>180.774292</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>7.456334</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" <td>0.015655</td>\n",
|
|||
|
" <td>0.086825</td>\n",
|
|||
|
" <td>0.187865</td>\n",
|
|||
|
" <td>1.041902</td>\n",
|
|||
|
" <td>-0.168701</td>\n",
|
|||
|
" <td>-0.341637</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320183</th>\n",
|
|||
|
" <td>60159.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1830.709717</td>\n",
|
|||
|
" <td>651.257446</td>\n",
|
|||
|
" <td>150.202515</td>\n",
|
|||
|
" <td>157.239746</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>10.027093</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320184</th>\n",
|
|||
|
" <td>60160.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1834.013672</td>\n",
|
|||
|
" <td>649.612122</td>\n",
|
|||
|
" <td>153.686646</td>\n",
|
|||
|
" <td>160.874023</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>10.047117</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" <td>0.192955</td>\n",
|
|||
|
" <td>0.083971</td>\n",
|
|||
|
" <td>2.315463</td>\n",
|
|||
|
" <td>1.007656</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320185</th>\n",
|
|||
|
" <td>60161.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1845.373047</td>\n",
|
|||
|
" <td>651.249756</td>\n",
|
|||
|
" <td>147.178589</td>\n",
|
|||
|
" <td>153.729248</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>10.015218</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" <td>0.178128</td>\n",
|
|||
|
" <td>0.072764</td>\n",
|
|||
|
" <td>2.137542</td>\n",
|
|||
|
" <td>0.873173</td>\n",
|
|||
|
" <td>-2.135059</td>\n",
|
|||
|
" <td>-1.613797</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320186</th>\n",
|
|||
|
" <td>60162.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1857.388916</td>\n",
|
|||
|
" <td>650.908203</td>\n",
|
|||
|
" <td>136.407349</td>\n",
|
|||
|
" <td>142.354614</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>9.935355</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" <td>0.166113</td>\n",
|
|||
|
" <td>0.065728</td>\n",
|
|||
|
" <td>1.993352</td>\n",
|
|||
|
" <td>0.788742</td>\n",
|
|||
|
" <td>-1.730279</td>\n",
|
|||
|
" <td>-1.013172</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320187</th>\n",
|
|||
|
" <td>60163.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1862.792725</td>\n",
|
|||
|
" <td>658.719971</td>\n",
|
|||
|
" <td>141.984253</td>\n",
|
|||
|
" <td>149.052307</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>10.051785</td>\n",
|
|||
|
" <td>0.083333</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" <td>0.160582</td>\n",
|
|||
|
" <td>0.066971</td>\n",
|
|||
|
" <td>1.926987</td>\n",
|
|||
|
" <td>0.803649</td>\n",
|
|||
|
" <td>-0.796376</td>\n",
|
|||
|
" <td>0.178886</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>320188 rows × 21 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" frame_id track_id l t w h \\\n",
|
|||
|
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
|||
|
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
|||
|
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
|||
|
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
|||
|
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
|||
|
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
|||
|
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
|||
|
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
|||
|
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
|||
|
"\n",
|
|||
|
" x y state diff ... y_raw dt \\\n",
|
|||
|
"0 0.855100 7.136193 2.0 NaN ... 7.341152 NaN \n",
|
|||
|
"1 0.873132 7.235233 2.0 1.0 ... 7.309168 0.083333 \n",
|
|||
|
"2 0.890957 7.328989 2.0 1.0 ... 7.370044 0.083333 \n",
|
|||
|
"3 0.907784 7.418187 2.0 1.0 ... 7.432365 0.083333 \n",
|
|||
|
"4 0.923439 7.505012 2.0 1.0 ... 7.456334 0.083333 \n",
|
|||
|
"... ... ... ... ... ... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 2.0 NaN ... 10.027093 NaN \n",
|
|||
|
"320184 15.033432 9.870472 2.0 1.0 ... 10.047117 0.083333 \n",
|
|||
|
"320185 15.211560 9.943236 2.0 1.0 ... 10.015218 0.083333 \n",
|
|||
|
"320186 15.377673 10.008965 2.0 1.0 ... 9.935355 0.083333 \n",
|
|||
|
"320187 15.538255 10.075935 2.0 1.0 ... 10.051785 0.083333 \n",
|
|||
|
"\n",
|
|||
|
" proj_x proj_y dx dy vx vy \\\n",
|
|||
|
"0 0.855100 7.136193 NaN NaN NaN NaN \n",
|
|||
|
"1 0.873132 7.235233 0.018032 0.099039 0.216383 1.188473 \n",
|
|||
|
"2 0.890957 7.328989 0.017825 0.093756 0.213899 1.125077 \n",
|
|||
|
"3 0.907784 7.418187 0.016827 0.089198 0.201924 1.070371 \n",
|
|||
|
"4 0.923439 7.505012 0.015655 0.086825 0.187865 1.041902 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 NaN NaN NaN NaN \n",
|
|||
|
"320184 15.033432 9.870472 0.192955 0.083971 2.315463 1.007656 \n",
|
|||
|
"320185 15.211560 9.943236 0.178128 0.072764 2.137542 0.873173 \n",
|
|||
|
"320186 15.377673 10.008965 0.166113 0.065728 1.993352 0.788742 \n",
|
|||
|
"320187 15.538255 10.075935 0.160582 0.066971 1.926987 0.803649 \n",
|
|||
|
"\n",
|
|||
|
" ax ay \n",
|
|||
|
"0 NaN NaN \n",
|
|||
|
"1 NaN NaN \n",
|
|||
|
"2 -0.029812 -0.760753 \n",
|
|||
|
"3 -0.143699 -0.656466 \n",
|
|||
|
"4 -0.168701 -0.341637 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"320183 NaN NaN \n",
|
|||
|
"320184 NaN NaN \n",
|
|||
|
"320185 -2.135059 -1.613797 \n",
|
|||
|
"320186 -1.730279 -1.013172 \n",
|
|||
|
"320187 -0.796376 0.178886 \n",
|
|||
|
"\n",
|
|||
|
"[320188 rows x 21 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
"print(\"Deriving displacement, velocity and acceleration from x and y\")\n",
"# All differences are taken per track, so values never leak across track boundaries.\n",
"# Consequently the first row of every track has NaN displacement/velocity, and the\n",
"# first two rows have NaN acceleration.\n",
"\n",
"# displacement between consecutive samples of the same track\n",
"data['dx'] = data.groupby('track_id')['proj_x'].diff()\n",
"data['dy'] = data.groupby('track_id')['proj_y'].diff()\n",
"\n",
"# velocity: displacement divided by the per-row time step dt\n",
"data['vx'] = data['dx'] / data['dt']\n",
"data['vy'] = data['dy'] / data['dt']\n",
"\n",
"# acceleration: per-track change in velocity divided by dt\n",
"data['ax'] = data.groupby('track_id')['vx'].diff() / data['dt']\n",
"data['ay'] = data.groupby('track_id')['vy'].diff() / data['dt']\n",
"\n",
"data"
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>diff</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>dx</th>\n",
|
|||
|
" <th>dy</th>\n",
|
|||
|
" <th>vx</th>\n",
|
|||
|
" <th>vy</th>\n",
|
|||
|
" <th>ax</th>\n",
|
|||
|
" <th>ay</th>\n",
|
|||
|
" <th>v</th>\n",
|
|||
|
" <th>a</th>\n",
|
|||
|
" <th>heading</th>\n",
|
|||
|
" <th>d_heading</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.566162</td>\n",
|
|||
|
" <td>88.795326</td>\n",
|
|||
|
" <td>173.917542</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.116699</td>\n",
|
|||
|
" <td>88.801704</td>\n",
|
|||
|
" <td>171.334290</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.018032</td>\n",
|
|||
|
" <td>0.099039</td>\n",
|
|||
|
" <td>0.216383</td>\n",
|
|||
|
" <td>1.188473</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>1.208011</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>79.681298</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>11.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874573</td>\n",
|
|||
|
" <td>90.596596</td>\n",
|
|||
|
" <td>177.199951</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.017825</td>\n",
|
|||
|
" <td>0.093756</td>\n",
|
|||
|
" <td>0.213899</td>\n",
|
|||
|
" <td>1.125077</td>\n",
|
|||
|
" <td>-0.029812</td>\n",
|
|||
|
" <td>-0.760753</td>\n",
|
|||
|
" <td>1.145230</td>\n",
|
|||
|
" <td>-0.753373</td>\n",
|
|||
|
" <td>79.235449</td>\n",
|
|||
|
" <td>-5.350188</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874268</td>\n",
|
|||
|
" <td>90.928131</td>\n",
|
|||
|
" <td>183.125732</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.016827</td>\n",
|
|||
|
" <td>0.089198</td>\n",
|
|||
|
" <td>0.201924</td>\n",
|
|||
|
" <td>1.070371</td>\n",
|
|||
|
" <td>-0.143699</td>\n",
|
|||
|
" <td>-0.656466</td>\n",
|
|||
|
" <td>1.089251</td>\n",
|
|||
|
" <td>-0.671740</td>\n",
|
|||
|
" <td>79.316807</td>\n",
|
|||
|
" <td>0.976297</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>13.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>569.931213</td>\n",
|
|||
|
" <td>86.213280</td>\n",
|
|||
|
" <td>180.774292</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.015655</td>\n",
|
|||
|
" <td>0.086825</td>\n",
|
|||
|
" <td>0.187865</td>\n",
|
|||
|
" <td>1.041902</td>\n",
|
|||
|
" <td>-0.168701</td>\n",
|
|||
|
" <td>-0.341637</td>\n",
|
|||
|
" <td>1.058703</td>\n",
|
|||
|
" <td>-0.366576</td>\n",
|
|||
|
" <td>79.778828</td>\n",
|
|||
|
" <td>5.544252</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320183</th>\n",
|
|||
|
" <td>60159.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1830.709717</td>\n",
|
|||
|
" <td>651.257446</td>\n",
|
|||
|
" <td>150.202515</td>\n",
|
|||
|
" <td>157.239746</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320184</th>\n",
|
|||
|
" <td>60160.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1834.013672</td>\n",
|
|||
|
" <td>649.612122</td>\n",
|
|||
|
" <td>153.686646</td>\n",
|
|||
|
" <td>160.874023</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.192955</td>\n",
|
|||
|
" <td>0.083971</td>\n",
|
|||
|
" <td>2.315463</td>\n",
|
|||
|
" <td>1.007656</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>2.525221</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>23.517970</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320185</th>\n",
|
|||
|
" <td>60161.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1845.373047</td>\n",
|
|||
|
" <td>651.249756</td>\n",
|
|||
|
" <td>147.178589</td>\n",
|
|||
|
" <td>153.729248</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.178128</td>\n",
|
|||
|
" <td>0.072764</td>\n",
|
|||
|
" <td>2.137542</td>\n",
|
|||
|
" <td>0.873173</td>\n",
|
|||
|
" <td>-2.135059</td>\n",
|
|||
|
" <td>-1.613797</td>\n",
|
|||
|
" <td>2.309007</td>\n",
|
|||
|
" <td>-2.594562</td>\n",
|
|||
|
" <td>22.219713</td>\n",
|
|||
|
" <td>-15.579091</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320186</th>\n",
|
|||
|
" <td>60162.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1857.388916</td>\n",
|
|||
|
" <td>650.908203</td>\n",
|
|||
|
" <td>136.407349</td>\n",
|
|||
|
" <td>142.354614</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.166113</td>\n",
|
|||
|
" <td>0.065728</td>\n",
|
|||
|
" <td>1.993352</td>\n",
|
|||
|
" <td>0.788742</td>\n",
|
|||
|
" <td>-1.730279</td>\n",
|
|||
|
" <td>-1.013172</td>\n",
|
|||
|
" <td>2.143727</td>\n",
|
|||
|
" <td>-1.983366</td>\n",
|
|||
|
" <td>21.588019</td>\n",
|
|||
|
" <td>-7.580324</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320187</th>\n",
|
|||
|
" <td>60163.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1862.792725</td>\n",
|
|||
|
" <td>658.719971</td>\n",
|
|||
|
" <td>141.984253</td>\n",
|
|||
|
" <td>149.052307</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.160582</td>\n",
|
|||
|
" <td>0.066971</td>\n",
|
|||
|
" <td>1.926987</td>\n",
|
|||
|
" <td>0.803649</td>\n",
|
|||
|
" <td>-0.796376</td>\n",
|
|||
|
" <td>0.178886</td>\n",
|
|||
|
" <td>2.087853</td>\n",
|
|||
|
" <td>-0.670484</td>\n",
|
|||
|
" <td>22.638547</td>\n",
|
|||
|
" <td>12.606340</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>320188 rows × 25 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" frame_id track_id l t w h \\\n",
|
|||
|
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
|||
|
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
|||
|
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
|||
|
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
|||
|
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
|||
|
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
|||
|
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
|||
|
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
|||
|
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
|||
|
"\n",
|
|||
|
" x y state diff ... dx dy vx \\\n",
|
|||
|
"0 0.855100 7.136193 2.0 NaN ... NaN NaN NaN \n",
|
|||
|
"1 0.873132 7.235233 2.0 1.0 ... 0.018032 0.099039 0.216383 \n",
|
|||
|
"2 0.890957 7.328989 2.0 1.0 ... 0.017825 0.093756 0.213899 \n",
|
|||
|
"3 0.907784 7.418187 2.0 1.0 ... 0.016827 0.089198 0.201924 \n",
|
|||
|
"4 0.923439 7.505012 2.0 1.0 ... 0.015655 0.086825 0.187865 \n",
|
|||
|
"... ... ... ... ... ... ... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 2.0 NaN ... NaN NaN NaN \n",
|
|||
|
"320184 15.033432 9.870472 2.0 1.0 ... 0.192955 0.083971 2.315463 \n",
|
|||
|
"320185 15.211560 9.943236 2.0 1.0 ... 0.178128 0.072764 2.137542 \n",
|
|||
|
"320186 15.377673 10.008965 2.0 1.0 ... 0.166113 0.065728 1.993352 \n",
|
|||
|
"320187 15.538255 10.075935 2.0 1.0 ... 0.160582 0.066971 1.926987 \n",
|
|||
|
"\n",
|
|||
|
" vy ax ay v a heading d_heading \n",
|
|||
|
"0 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
"1 1.188473 NaN NaN 1.208011 NaN 79.681298 NaN \n",
|
|||
|
"2 1.125077 -0.029812 -0.760753 1.145230 -0.753373 79.235449 -5.350188 \n",
|
|||
|
"3 1.070371 -0.143699 -0.656466 1.089251 -0.671740 79.316807 0.976297 \n",
|
|||
|
"4 1.041902 -0.168701 -0.341637 1.058703 -0.366576 79.778828 5.544252 \n",
|
|||
|
"... ... ... ... ... ... ... ... \n",
|
|||
|
"320183 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
"320184 1.007656 NaN NaN 2.525221 NaN 23.517970 NaN \n",
|
|||
|
"320185 0.873173 -2.135059 -1.613797 2.309007 -2.594562 22.219713 -15.579091 \n",
|
|||
|
"320186 0.788742 -1.730279 -1.013172 2.143727 -1.983366 21.588019 -7.580324 \n",
|
|||
|
"320187 0.803649 -0.796376 0.178886 2.087853 -0.670484 22.638547 12.606340 \n",
|
|||
|
"\n",
|
|||
|
"[320188 rows x 25 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
"# speed: magnitude of the velocity vector\n",
"data['v'] = np.sqrt(data['vx'].pow(2) + data['vy'].pow(2))\n",
"# rate of change of speed per track (negative when slowing down)\n",
"data['a'] = data.groupby(['track_id'])['v'].diff().div(data['dt'], axis=0)\n",
"\n",
"# heading in degrees in [0, 360): angle of the velocity vector relative to the +x axis.\n",
"# NOTE: a stationary sample (vx == vy == 0) gets heading 0 because arctan2(0, 0) == 0.\n",
"data['heading'] = (np.arctan2(data['vy'], data['vx']) * 180 / np.pi) % 360\n",
"\n",
"# Rate of change of heading (degrees per second). The raw per-track difference is\n",
"# wrapped into [-180, 180) so that crossing 0/360 deg (e.g. 359 -> 1) counts as a\n",
"# +2 deg turn instead of -358 deg. NaN at track starts propagates unchanged.\n",
"data['d_heading'] = (\n",
"    (data.groupby(['track_id'])['heading'].diff() + 180) % 360 - 180\n",
").div(data['dt'], axis=0)\n",
"\n",
"data"
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>track_id</th>\n",
|
|||
|
" <th>l</th>\n",
|
|||
|
" <th>t</th>\n",
|
|||
|
" <th>w</th>\n",
|
|||
|
" <th>h</th>\n",
|
|||
|
" <th>x</th>\n",
|
|||
|
" <th>y</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>diff</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>dx</th>\n",
|
|||
|
" <th>dy</th>\n",
|
|||
|
" <th>vx</th>\n",
|
|||
|
" <th>vy</th>\n",
|
|||
|
" <th>ax</th>\n",
|
|||
|
" <th>ay</th>\n",
|
|||
|
" <th>v</th>\n",
|
|||
|
" <th>a</th>\n",
|
|||
|
" <th>heading</th>\n",
|
|||
|
" <th>d_heading</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.566162</td>\n",
|
|||
|
" <td>88.795326</td>\n",
|
|||
|
" <td>173.917542</td>\n",
|
|||
|
" <td>0.855100</td>\n",
|
|||
|
" <td>7.136193</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>1.208011</td>\n",
|
|||
|
" <td>-0.753373</td>\n",
|
|||
|
" <td>79.681298</td>\n",
|
|||
|
" <td>-5.350188</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>565.116699</td>\n",
|
|||
|
" <td>88.801704</td>\n",
|
|||
|
" <td>171.334290</td>\n",
|
|||
|
" <td>0.873132</td>\n",
|
|||
|
" <td>7.235233</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.018032</td>\n",
|
|||
|
" <td>0.099039</td>\n",
|
|||
|
" <td>0.216383</td>\n",
|
|||
|
" <td>1.188473</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>1.208011</td>\n",
|
|||
|
" <td>-0.753373</td>\n",
|
|||
|
" <td>79.681298</td>\n",
|
|||
|
" <td>-5.350188</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>11.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874573</td>\n",
|
|||
|
" <td>90.596596</td>\n",
|
|||
|
" <td>177.199951</td>\n",
|
|||
|
" <td>0.890957</td>\n",
|
|||
|
" <td>7.328989</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.017825</td>\n",
|
|||
|
" <td>0.093756</td>\n",
|
|||
|
" <td>0.213899</td>\n",
|
|||
|
" <td>1.125077</td>\n",
|
|||
|
" <td>-0.029812</td>\n",
|
|||
|
" <td>-0.760753</td>\n",
|
|||
|
" <td>1.145230</td>\n",
|
|||
|
" <td>-0.753373</td>\n",
|
|||
|
" <td>79.235449</td>\n",
|
|||
|
" <td>-5.350188</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>564.874268</td>\n",
|
|||
|
" <td>90.928131</td>\n",
|
|||
|
" <td>183.125732</td>\n",
|
|||
|
" <td>0.907784</td>\n",
|
|||
|
" <td>7.418187</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.016827</td>\n",
|
|||
|
" <td>0.089198</td>\n",
|
|||
|
" <td>0.201924</td>\n",
|
|||
|
" <td>1.070371</td>\n",
|
|||
|
" <td>-0.143699</td>\n",
|
|||
|
" <td>-0.656466</td>\n",
|
|||
|
" <td>1.089251</td>\n",
|
|||
|
" <td>-0.671740</td>\n",
|
|||
|
" <td>79.316807</td>\n",
|
|||
|
" <td>0.976297</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>13.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.000000</td>\n",
|
|||
|
" <td>569.931213</td>\n",
|
|||
|
" <td>86.213280</td>\n",
|
|||
|
" <td>180.774292</td>\n",
|
|||
|
" <td>0.923439</td>\n",
|
|||
|
" <td>7.505012</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.015655</td>\n",
|
|||
|
" <td>0.086825</td>\n",
|
|||
|
" <td>0.187865</td>\n",
|
|||
|
" <td>1.041902</td>\n",
|
|||
|
" <td>-0.168701</td>\n",
|
|||
|
" <td>-0.341637</td>\n",
|
|||
|
" <td>1.058703</td>\n",
|
|||
|
" <td>-0.366576</td>\n",
|
|||
|
" <td>79.778828</td>\n",
|
|||
|
" <td>5.544252</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320183</th>\n",
|
|||
|
" <td>60159.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1830.709717</td>\n",
|
|||
|
" <td>651.257446</td>\n",
|
|||
|
" <td>150.202515</td>\n",
|
|||
|
" <td>157.239746</td>\n",
|
|||
|
" <td>14.840476</td>\n",
|
|||
|
" <td>9.786501</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>2.525221</td>\n",
|
|||
|
" <td>-2.594562</td>\n",
|
|||
|
" <td>23.517970</td>\n",
|
|||
|
" <td>-15.579091</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320184</th>\n",
|
|||
|
" <td>60160.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1834.013672</td>\n",
|
|||
|
" <td>649.612122</td>\n",
|
|||
|
" <td>153.686646</td>\n",
|
|||
|
" <td>160.874023</td>\n",
|
|||
|
" <td>15.033432</td>\n",
|
|||
|
" <td>9.870472</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.192955</td>\n",
|
|||
|
" <td>0.083971</td>\n",
|
|||
|
" <td>2.315463</td>\n",
|
|||
|
" <td>1.007656</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>2.525221</td>\n",
|
|||
|
" <td>-2.594562</td>\n",
|
|||
|
" <td>23.517970</td>\n",
|
|||
|
" <td>-15.579091</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320185</th>\n",
|
|||
|
" <td>60161.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1845.373047</td>\n",
|
|||
|
" <td>651.249756</td>\n",
|
|||
|
" <td>147.178589</td>\n",
|
|||
|
" <td>153.729248</td>\n",
|
|||
|
" <td>15.211560</td>\n",
|
|||
|
" <td>9.943236</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.178128</td>\n",
|
|||
|
" <td>0.072764</td>\n",
|
|||
|
" <td>2.137542</td>\n",
|
|||
|
" <td>0.873173</td>\n",
|
|||
|
" <td>-2.135059</td>\n",
|
|||
|
" <td>-1.613797</td>\n",
|
|||
|
" <td>2.309007</td>\n",
|
|||
|
" <td>-2.594562</td>\n",
|
|||
|
" <td>22.219713</td>\n",
|
|||
|
" <td>-15.579091</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320186</th>\n",
|
|||
|
" <td>60162.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1857.388916</td>\n",
|
|||
|
" <td>650.908203</td>\n",
|
|||
|
" <td>136.407349</td>\n",
|
|||
|
" <td>142.354614</td>\n",
|
|||
|
" <td>15.377673</td>\n",
|
|||
|
" <td>10.008965</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.166113</td>\n",
|
|||
|
" <td>0.065728</td>\n",
|
|||
|
" <td>1.993352</td>\n",
|
|||
|
" <td>0.788742</td>\n",
|
|||
|
" <td>-1.730279</td>\n",
|
|||
|
" <td>-1.013172</td>\n",
|
|||
|
" <td>2.143727</td>\n",
|
|||
|
" <td>-1.983366</td>\n",
|
|||
|
" <td>21.588019</td>\n",
|
|||
|
" <td>-7.580324</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>320187</th>\n",
|
|||
|
" <td>60163.0</td>\n",
|
|||
|
" <td>3632</td>\n",
|
|||
|
" <td>1862.792725</td>\n",
|
|||
|
" <td>658.719971</td>\n",
|
|||
|
" <td>141.984253</td>\n",
|
|||
|
" <td>149.052307</td>\n",
|
|||
|
" <td>15.538255</td>\n",
|
|||
|
" <td>10.075935</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.160582</td>\n",
|
|||
|
" <td>0.066971</td>\n",
|
|||
|
" <td>1.926987</td>\n",
|
|||
|
" <td>0.803649</td>\n",
|
|||
|
" <td>-0.796376</td>\n",
|
|||
|
" <td>0.178886</td>\n",
|
|||
|
" <td>2.087853</td>\n",
|
|||
|
" <td>-0.670484</td>\n",
|
|||
|
" <td>22.638547</td>\n",
|
|||
|
" <td>12.606340</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>320188 rows × 25 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" frame_id track_id l t w h \\\n",
|
|||
|
"0 9.0 1 0.000000 565.566162 88.795326 173.917542 \n",
|
|||
|
"1 10.0 1 0.000000 565.116699 88.801704 171.334290 \n",
|
|||
|
"2 11.0 1 0.000000 564.874573 90.596596 177.199951 \n",
|
|||
|
"3 12.0 1 0.000000 564.874268 90.928131 183.125732 \n",
|
|||
|
"4 13.0 1 0.000000 569.931213 86.213280 180.774292 \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"320183 60159.0 3632 1830.709717 651.257446 150.202515 157.239746 \n",
|
|||
|
"320184 60160.0 3632 1834.013672 649.612122 153.686646 160.874023 \n",
|
|||
|
"320185 60161.0 3632 1845.373047 651.249756 147.178589 153.729248 \n",
|
|||
|
"320186 60162.0 3632 1857.388916 650.908203 136.407349 142.354614 \n",
|
|||
|
"320187 60163.0 3632 1862.792725 658.719971 141.984253 149.052307 \n",
|
|||
|
"\n",
|
|||
|
" x y state diff ... dx dy vx \\\n",
|
|||
|
"0 0.855100 7.136193 2.0 NaN ... NaN NaN NaN \n",
|
|||
|
"1 0.873132 7.235233 2.0 1.0 ... 0.018032 0.099039 0.216383 \n",
|
|||
|
"2 0.890957 7.328989 2.0 1.0 ... 0.017825 0.093756 0.213899 \n",
|
|||
|
"3 0.907784 7.418187 2.0 1.0 ... 0.016827 0.089198 0.201924 \n",
|
|||
|
"4 0.923439 7.505012 2.0 1.0 ... 0.015655 0.086825 0.187865 \n",
|
|||
|
"... ... ... ... ... ... ... ... ... \n",
|
|||
|
"320183 14.840476 9.786501 2.0 NaN ... NaN NaN NaN \n",
|
|||
|
"320184 15.033432 9.870472 2.0 1.0 ... 0.192955 0.083971 2.315463 \n",
|
|||
|
"320185 15.211560 9.943236 2.0 1.0 ... 0.178128 0.072764 2.137542 \n",
|
|||
|
"320186 15.377673 10.008965 2.0 1.0 ... 0.166113 0.065728 1.993352 \n",
|
|||
|
"320187 15.538255 10.075935 2.0 1.0 ... 0.160582 0.066971 1.926987 \n",
|
|||
|
"\n",
|
|||
|
" vy ax ay v a heading d_heading \n",
|
|||
|
"0 NaN NaN NaN 1.208011 -0.753373 79.681298 -5.350188 \n",
|
|||
|
"1 1.188473 NaN NaN 1.208011 -0.753373 79.681298 -5.350188 \n",
|
|||
|
"2 1.125077 -0.029812 -0.760753 1.145230 -0.753373 79.235449 -5.350188 \n",
|
|||
|
"3 1.070371 -0.143699 -0.656466 1.089251 -0.671740 79.316807 0.976297 \n",
|
|||
|
"4 1.041902 -0.168701 -0.341637 1.058703 -0.366576 79.778828 5.544252 \n",
|
|||
|
"... ... ... ... ... ... ... ... \n",
|
|||
|
"320183 NaN NaN NaN 2.525221 -2.594562 23.517970 -15.579091 \n",
|
|||
|
"320184 1.007656 NaN NaN 2.525221 -2.594562 23.517970 -15.579091 \n",
|
|||
|
"320185 0.873173 -2.135059 -1.613797 2.309007 -2.594562 22.219713 -15.579091 \n",
|
|||
|
"320186 0.788742 -1.730279 -1.013172 2.143727 -1.983366 21.588019 -7.580324 \n",
|
|||
|
"320187 0.803649 -0.796376 0.178886 2.087853 -0.670484 22.638547 12.606340 \n",
|
|||
|
"\n",
|
|||
|
"[320188 rows x 25 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 20,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# we can backfill the v and a, so that our model can make estimations\n",
|
|||
|
"# based on these assumed values\n",
|
|||
|
"data['v'] = data.groupby(['track_id'])['v'].bfill()\n",
|
|||
|
"data['a'] = data.groupby(['track_id'])['a'].bfill()\n",
|
|||
|
"\n",
|
|||
|
"data['heading'] = data.groupby(['track_id'])['heading'].bfill()\n",
|
|||
|
"data['d_heading'] = data.groupby(['track_id'])['d_heading'].bfill()\n",
|
|||
|
"data"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"312423 items in filtered set, out of 320188 in total set\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"filtered_data = data.groupby(['track_id']).filter(lambda group: len(group) >= window+1) # a lenght of 3 is neccessary to have all relevant derivatives of position\n",
|
|||
|
"filtered_data = filtered_data.set_index(['track_id', 'frame_id']) # use for quick access\n",
|
|||
|
"print(filtered_data.shape[0], \"items in filtered set, out of\", data.shape[0], \"in total set\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1263 training tracks, 316 test tracks\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"track_ids = filtered_data.index.unique('track_id').to_numpy()\n",
|
|||
|
"np.random.shuffle(track_ids)\n",
|
|||
|
"test_offset_idx = int(len(track_ids) * .8)\n",
|
|||
|
"training_ids, test_ids = track_ids[:test_offset_idx], track_ids[test_offset_idx:]\n",
|
|||
|
"print(f\"{len(training_ids)} training tracks, {len(test_ids)} test tracks\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Here, draw a sample track to check that it looks right. **Unfortunately, the background image isn't mapped properly.**"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1058\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"0 0\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3wcZ53/3zOzXbvqvcuW5W7LPXGL03F6AZIAoYSSXDg4yh0cEDjg4McFOEghcIQEkpCQRpqdnjhO3Fss925Vq3dp+87M8/tjVCxLtuXEVX7er5e8q5lnZp5Zr2Y+862KEEIgkUgkEolEcppQz/QEJBKJRCKRnF9I8SGRSCQSieS0IsWHRCKRSCSS04oUHxKJRCKRSE4rUnxIJBKJRCI5rUjxIZFIJBKJ5LQixYdEIpFIJJLTihQfEolEIpFITiu2Mz2BIzFNk7q6Onw+H4qinOnpSCQSiUQiGQZCCLq7u8nOzkZVj23bOOvER11dHXl5eWd6GhKJRCKRSD4CNTU15ObmHnPMWSc+fD4fYE0+Pj7+DM9GIpFIJBLJcOjq6iIvL6/vPn4szjrx0etqiY+Pl+JDIpFIJJJzjOGETMiAU4lEIpFIJKeVExYfK1as4NprryU7OxtFUXj55ZePOvauu+5CURTuu+++jzFFiUQikUgkI4kTFh+BQICpU6fy0EMPHXPcSy+9xLp168jOzv7Ik5NIJBKJRDLyOOGYj8WLF7N48eJjjqmtreUb3/gGb731FldfffVHnpxEIpFIJJKRx0kPODVNk9tvv53/+I//YOLEiccdH4lEiEQifb93dXWd7ClJJBKJRCI5izjpAaf33nsvNpuNb37zm8Ma/6tf/YqEhIS+H1njQyKRSCSSkc1JFR8ffvgh999/P4899tiwq5P+4Ac/oLOzs++npqbmZE5JIpFIJBLJWcZJFR8rV66kqamJ/Px8bDYbNpuNqqoqvvvd71JYWDjkNk6ns6+mh6ztIZFIJBLJyOekxnzcfvvtXHbZZQOWXXnlldx+++186UtfOpmHkkgkEsl5jhCCSHknzlEJshfYOcYJiw+/38+BAwf6fq+oqGDLli0kJyeTn59PSkrKgPF2u53MzEzGjh378WcrkUgkEgmW8GhfVknw3UMk3j4O78S0Mz0lyQlwwm6XTZs2MW3aNKZNmwbAd77zHaZNm8ZPfvKTkz45iUQikUiGIry3jeC7hwDofH4/QogzPCPJiXDClo9Fixad0H9yZWXliR5CIpFIJJJBCFMQWF+PGdIJfNjYvzxs0PbcXuxpHlS3jbg5WSiqdMOczZx1jeUkEonkXEMIQd2+DrJLEmXswSlCmIKOD6oIvDV0RmSorJlQz3vH5GQcXtfpm5zkhJGN5SQSieRjUrWjlZd/X0b1ztYzPZURixGOHVV4HElgZR3ClG6YsxkpPiQSieRjsvW9mgGvkpNPtLxz2GMDa+oxI/opnI3k4yLdLhKJRHKCCFOwY0UtkaB1gzu0p73ntYNNr1cC4PTYmLQwR8YefExMw6TztQqilcMXH8jg07MeKT4kEonkBImGdda8eAA9ag5YLkzB+iXlANgcKiWzM3B67Gdiiuc8vcGl4cpOwltbTnj74MZGvPOl+DtbkW4XiUQi+QgcL+tPpn5+PETMoPPtqo8kPNAFne9USdfLWYwUHxKJRHKCKKqCph378qlpqnzq/ggIU+BfW4d/dR1xMzPgo36EUvyd1UjxIZFIJMdACEH4YMcAS4bDZeOWe2bjjh/apWJ3aUxYkMPedQ0fO+tCCEFg/YZBlhQhBLV720echaXX4tH1dhX+lbXwMU4vuLFRZr2cpUjxIZFIJMcgvK+dlr9sJ7KvfcByp8dGqDs25DaxsMGWd6pZ8+IBouGPZ/oPrFxJ9Re+QGDVqgHLq3e2Wem9u9o+1v7PNlSnjYx/m4ajwHfsgc7jmESk6+WsRooPiUQiOQah7VbMQXB7ywALyHBcKqYp2L
W6/mM9fXe9+daAV7CsHtuWW2m9Bzc3feR9n63YEl2kfXUKiuOIW5QGznFJADjHJB1/R0JI68dZisx2kUgkksM4vIS3EIJgmXVzD21tJripEc/0dBx5PuLmZJE/IZnqnW0IBF2KIEEMvFmaumDD0nImzMsadtaLME3an34as7sbYRh0vvYaAF2vv05bzIcRM+mOOaiOTQVF5eCHTcSnuIGRld4bPdSNOCKbCAN8F+USV5qOluikvTOGXtN99J3ogq7lNcTNyURxytvd2YT835BIJJLD6I05ECH9iOXWjTC4uYnQ7jY809OZemkeNXvaKFdM3vREuavLhXpEhKQRM9m5qo5pl+UPz1oSDNH8wIOYnQPrWohwGHXJY2hAd97lMKoUgFjU6EvvdXpsjL0gE4fr3L+0h3Zb7iTXxBQSryqi47VywrvaCO9pJ3FxEQC2ZOcxxYfiseGdl4Ni107LnCXDR7pdJBKJ5DD6Yg7yjx5z4BiVQMfScpI6wuSPS2Kf3SCgwj67MWisELDupYNsfqdqWOZ/1eMm+Qufx5aVNXhHQMSdTEPmbOjpISN6jAPxqS4+/cNZI0J4ALjHJ5N861hSPjceW4qblNsnkHzrWNzjk/vGxM04djaMCOp0L6/Bv6pWul7OMkbGt1QikUhOIrZEFylfmED9f68fcn1oZwsvESOgQaPPyV6HJTre9MTI69KIEwodikmCUFBQEALWv1KOqqqUXpp3TAuIGQzR9vgTgywfvThDbRiae9DyQGcEp2fkXNKdhQkDflcUBU9p+oBl7pJk7GMSie3rOPqOdJPONysRQuBbkDsiXFIjAWn5kEgkkiHQm0NHXRcGHiXMw0aYVzo6ifbcz2IKvBAXQUeQKFQOt4MIE9YvKWfLspqjPoUL06TzlZdJvPkm1JQUgke4CxSg05tLxDU42FJRlPPyxho/P+f4g0xB17vVMvPlLEKKD4lEIhmC3piDoa6SHhQeJY6sIWz+jZpgldNKwdUAcVihCiNmDhIgQgjWHmzFNE3W7jhE/e8foO2vf6MhGqI+cbDrJ+TOGHK+Y+eMjFiPE8Vdkox30TAEyAirh3KuI8WHRCKRHIEQgp1ZlSTdUkLSrWOHHJOGiv8oFbA2ugzKHDrKEOLEiJlsWFreV//j/X3N3PaXdfxh+QE++9RW/p57IeGkNBoS4tBMc9D2TenTBy1TFCgqTT2RUxxRmEdmxUjOeqT4kEgkkh6EEGxs2MjK2pXctfWbbE7bx4aKdZb14ggdYUPhPxkce9E7bqMjRpdiHlWA7FxVhzAFb2yvB+ClzbVM69jKzM4yauwxahO9GKqCAFrjXOxPT6QyJZ6oI2HQ/lAga3Tixzv5cxhnSeLxB5lixFWDPZc5/2x0EolEcgRCCDY1biKkh/j6sq8zN3suAE/ufpI1HWv47byfMnF1+qDtLsLOpcRYxuBYgk4bPB8X4Q6/a5AAEQLWvHyQ5zZW83x3FwDVzV2UhqupTPehCAGKwt7sVNK6Q7TFudiflYJNN7D58nr2YaAoVkxIcmYcdsf5mU5q6iaBdQ3DGAhm1ADPqZ+T5PhIy4dEIjkv6bVyCCFYcWgFd7x1B/duuBeADQ0brNd66/Xhzif4v4zn0QrjBu3nB7gpPvJSqoDbgNG6xl6bMSDuo2+ICdk1UZSeVYZq46Ws63kq51ZCirN3kngjUQ54ShBRB3PK20G1gTAxjWjfvrpaw8Sig9N8RzpG1KDpwc1E97QffzCgqvKWd7YgLR8SieS8ZFXtKu5edje3jL2FLU1bAKjurgZANy1Lhi6s131GOfuSywk4dEqSsrm6YwGqUDEQuFD4JR4+h5/DO72ENNioGbhMg9FdGkPVN9VQKNJVmlUTQ/gJaR7aHUlsSZzKhR0bUQ0FVYAnFia3uxYNFwCKMNA0S6DYHCrTrsg/Ly0fsfYQeuPRs5IOJ+Urk7DFO0/xjCTDRcpAiURy3mEKkz9v/TMAz+97nr3te4e13bvRlTyWtR
THJ/MBK7yjDZMcVG7DMWi8z4AZEdsxn/JuCjj5UpfKZ2qf4zOHnsFuRnEIS8YEFS8AGeEWDmQmo2vWnmx6GK3H8qF
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import random\n",
|
|||
|
"if H:\n",
|
|||
|
" img_src = \"../DATASETS/hof/webcam20231103-2.png\"\n",
|
|||
|
" # dst = cv2.warpPerspective(img_src,H,(2500,1920))\n",
|
|||
|
" src_img = cv2.imread(img_src)\n",
|
|||
|
" print(src_img.shape)\n",
|
|||
|
" h1,w1 = src_img.shape[:2]\n",
|
|||
|
" corners = np.float32([[0,0], [w1, 0], [0, h1], [w1, h1]])\n",
|
|||
|
"\n",
|
|||
|
" print(corners)\n",
|
|||
|
" corners_projected = cv2.perspectiveTransform(corners.reshape((-1,4,2)), H)[0]\n",
|
|||
|
" print(corners_projected)\n",
|
|||
|
" [xmin, ymin] = np.int32(corners_projected.min(axis=0).ravel() - 0.5)\n",
|
|||
|
" [xmax, ymax] = np.int32(corners_projected.max(axis=0).ravel() + 0.5)\n",
|
|||
|
" print(xmin, xmax, ymin, ymax)\n",
|
|||
|
"\n",
|
|||
|
" dst = cv2.warpPerspective(src_img,H, (xmax, ymax))\n",
|
|||
|
" def plot_track(track_id: int):\n",
|
|||
|
" plt.gca().invert_yaxis()\n",
|
|||
|
"\n",
|
|||
|
" plt.imshow(dst, origin='lower', extent=[xmin/100-mean_x, xmax/100-mean_x, ymin/100-mean_y, ymax/100-mean_y])\n",
|
|||
|
" # plot scatter plot with x and y data \n",
|
|||
|
" \n",
|
|||
|
" ax = plt.scatter(\n",
|
|||
|
" filtered_data.loc[track_id,:]['proj_x'],\n",
|
|||
|
" filtered_data.loc[track_id,:]['proj_y'],\n",
|
|||
|
" marker=\"*\") \n",
|
|||
|
" ax.axes.invert_yaxis()\n",
|
|||
|
" plt.plot(\n",
|
|||
|
" filtered_data.loc[track_id,:]['proj_x'],\n",
|
|||
|
" filtered_data.loc[track_id,:]['proj_y']\n",
|
|||
|
" )\n",
|
|||
|
"else:\n",
|
|||
|
" def plot_track(track_id: int):\n",
|
|||
|
" ax = plt.scatter(\n",
|
|||
|
" filtered_data.loc[track_id,:]['x'],\n",
|
|||
|
" filtered_data.loc[track_id,:]['y'],\n",
|
|||
|
" marker=\"*\") \n",
|
|||
|
" plt.plot(\n",
|
|||
|
" filtered_data.loc[track_id,:]['proj_x'],\n",
|
|||
|
" filtered_data.loc[track_id,:]['proj_y']\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
|
|||
|
"# _track_id = 2188\n",
|
|||
|
"_track_id = random.choice(track_ids)\n",
|
|||
|
"print(_track_id)\n",
|
|||
|
"plot_track(_track_id)\n",
|
|||
|
"\n",
|
|||
|
"for track_id in random.choices(track_ids, k=100):\n",
|
|||
|
" plot_track(track_id)\n",
|
|||
|
" \n",
|
|||
|
"print(mean_x, mean_y)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Now make the dataset:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# a=filtered_data.loc[1]\n",
|
|||
|
"# min(a.index.tolist())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" 0%| | 0/1263 [00:00<?, ?it/s]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"100%|██████████| 1263/1263 [01:42<00:00, 12.36it/s]\n",
|
|||
|
"100%|██████████| 316/316 [00:26<00:00, 11.75it/s]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"def create_dataset(data, track_ids, window):\n",
|
|||
|
" X, y, = [], []\n",
|
|||
|
" for track_id in tqdm(track_ids):\n",
|
|||
|
" df = data.loc[track_id]\n",
|
|||
|
" start_frame = min(df.index.tolist())\n",
|
|||
|
" for step in range(len(df)-window-1):\n",
|
|||
|
" i = int(start_frame) + step\n",
|
|||
|
" # print(step, int(start_frame), i)\n",
|
|||
|
" feature = df.loc[i:i+window][in_fields]\n",
|
|||
|
" # target = df.loc[i+1:i+window+1][out_fields]\n",
|
|||
|
" target = df.loc[i+window+1][out_fields]\n",
|
|||
|
" X.append(feature.values)\n",
|
|||
|
" y.append(target.values)\n",
|
|||
|
" \n",
|
|||
|
" return torch.tensor(np.array(X), dtype=torch.float), torch.tensor(np.array(y), dtype=torch.float)\n",
|
|||
|
"\n",
|
|||
|
"X_train, y_train = create_dataset(filtered_data, training_ids, window)\n",
|
|||
|
"X_test, y_test = create_dataset(filtered_data, test_ids, window)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"X_train, y_train = X_train.to(device=device), y_train.to(device=device)\n",
|
|||
|
"X_test, y_test = X_test.to(device=device), y_test.to(device=device)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from torch.utils.data import TensorDataset, DataLoader\n",
|
|||
|
"dataset_train = TensorDataset(X_train, y_train)\n",
|
|||
|
"loader_train = DataLoader(dataset_train, shuffle=True, batch_size=8)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The LSTM produces an output for every timestep (training on all of them could improve learning), but the model below only passes the output of the last timestep to the final linear layer and uses it as the prediction."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class LSTMModel(nn.Module):\n",
|
|||
|
" # input_size : number of features in input at each time step\n",
|
|||
|
" # hidden_size : Number of LSTM units \n",
|
|||
|
" # num_layers : number of LSTM layers \n",
|
|||
|
" def __init__(self, input_size, hidden_size, num_layers): \n",
|
|||
|
" super(LSTMModel, self).__init__() #initializes the parent class nn.Module\n",
|
|||
|
" self.lin1 = nn.Linear(input_size, hidden_size//2)\n",
|
|||
|
" self.lstm = nn.LSTM(hidden_size//2, hidden_size, num_layers, batch_first=True)\n",
|
|||
|
" self.linear = nn.Linear(hidden_size, output_size)\n",
|
|||
|
" # self.activation_v = nn.LeakyReLU(.01)\n",
|
|||
|
" # self.activation_heading = torch.remainder()\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x): # defines forward pass of the neural network\n",
|
|||
|
" out = self.lin1(x)\n",
|
|||
|
" out, h0 = self.lstm(out)\n",
|
|||
|
" # extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\n",
|
|||
|
" # print(out.shape)\n",
|
|||
|
" out = out[:, -1,:]\n",
|
|||
|
" # print(out.shape)\n",
|
|||
|
" out = self.linear(out)\n",
|
|||
|
" \n",
|
|||
|
" # torch.remainder(out[1], 360)\n",
|
|||
|
" # print('o',out.shape)\n",
|
|||
|
" return out\n",
|
|||
|
"\n",
|
|||
|
"model = LSTMModel(input_size, hidden_size, num_layers).to(device)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
|||
|
"loss_fn = nn.MSELoss()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def evaluate():\n",
|
|||
|
" # toggle evaluation mode\n",
|
|||
|
" model.eval()\n",
|
|||
|
" with torch.no_grad():\n",
|
|||
|
" y_pred = model(X_train.to(device=device))\n",
|
|||
|
" train_rmse = torch.sqrt(loss_fn(y_pred, y_train))\n",
|
|||
|
" y_pred = model(X_test.to(device=device))\n",
|
|||
|
" test_rmse = torch.sqrt(loss_fn(y_pred, y_test))\n",
|
|||
|
" print(\"Epoch %d: train RMSE %.4f, test RMSE %.4f\" % (epoch, train_rmse, test_rmse))\n",
|
|||
|
"\n",
|
|||
|
"def load_most_recent():\n",
|
|||
|
" paths = list(cache_path.glob(\"checkpoint_*.pt\"))\n",
|
|||
|
" if len(paths) < 1:\n",
|
|||
|
" print('Nothing found to load')\n",
|
|||
|
" return None, None\n",
|
|||
|
" paths.sort()\n",
|
|||
|
"\n",
|
|||
|
" print(f\"Loading {paths[-1]}\")\n",
|
|||
|
" return load_cache(path=paths[-1])\n",
|
|||
|
"\n",
|
|||
|
"def load_cache(epoch=None, path=None):\n",
|
|||
|
" if path is None:\n",
|
|||
|
" if epoch is None:\n",
|
|||
|
" raise RuntimeError(\"Either path or epoch must be given\")\n",
|
|||
|
" path = cache_path / f\"checkpoint_{epoch:05d}.pt\"\n",
|
|||
|
" else:\n",
|
|||
|
" print (path.stem)\n",
|
|||
|
" epoch = int(path.stem[-5:])\n",
|
|||
|
"\n",
|
|||
|
" cached = torch.load(path)\n",
|
|||
|
" \n",
|
|||
|
" optimizer.load_state_dict(cached['optimizer_state_dict'])\n",
|
|||
|
" model.load_state_dict(cached['model_state_dict'])\n",
|
|||
|
" return epoch, cached['loss']\n",
|
|||
|
" \n",
|
|||
|
"\n",
|
|||
|
"def cache(epoch, loss):\n",
|
|||
|
" path = cache_path / f\"checkpoint_{epoch:05d}.pt\"\n",
|
|||
|
" print(f\"Cache to {path}\")\n",
|
|||
|
" torch.save({\n",
|
|||
|
" 'epoch': epoch,\n",
|
|||
|
" 'model_state_dict': model.state_dict(),\n",
|
|||
|
" 'optimizer_state_dict': optimizer.state_dict(),\n",
|
|||
|
" 'loss': loss,\n",
|
|||
|
" }, path)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Loading EXPERIMENTS/cache/hof2/checkpoint_00005.pt\n",
|
|||
|
"checkpoint_00005\n",
|
|||
|
"starting from epoch 5 (loss: nan)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" 0%| | 0/95 [00:00<?, ?it/s]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" 4%|▍ | 4/95 [07:37<2:53:27, 114.37s/it]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Cache to EXPERIMENTS/cache/hof2/checkpoint_00010.pt\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"ename": "RuntimeError",
|
|||
|
"evalue": "CUDA out of memory. Tried to allocate 28.40 GiB (GPU 0; 23.59 GiB total capacity; 15.01 GiB already allocated; 7.20 GiB free; 15.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|||
|
"Cell \u001b[0;32mIn[31], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m 30\u001b[0m cache(epoch, loss)\n\u001b[0;32m---> 31\u001b[0m \u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m evaluate()\n",
|
|||
|
"Cell \u001b[0;32mIn[30], line 5\u001b[0m, in \u001b[0;36mevaluate\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 5\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m train_rmse \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msqrt(loss_fn(y_pred, y_train))\n\u001b[1;32m 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m model(X_test\u001b[38;5;241m.\u001b[39mto(device\u001b[38;5;241m=\u001b[39mdevice))\n",
|
|||
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
|||
|
"Cell \u001b[0;32mIn[28], line 15\u001b[0m, in \u001b[0;36mLSTMModel.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x): \u001b[38;5;66;03m# defines forward pass of the neural network\u001b[39;00m\n\u001b[1;32m 14\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlin1(x)\n\u001b[0;32m---> 15\u001b[0m out, h0 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;66;03m# extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\u001b[39;00m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# print(out.shape)\u001b[39;00m\n\u001b[1;32m 18\u001b[0m out \u001b[38;5;241m=\u001b[39m out[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:]\n",
|
|||
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
|||
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/rnn.py:769\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_forward_args(\u001b[38;5;28minput\u001b[39m, hx, batch_sizes)\n\u001b[1;32m 768\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 769\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 771\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 772\u001b[0m result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\u001b[38;5;28minput\u001b[39m, batch_sizes, hx, 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flat_weights, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias,\n\u001b[1;32m 773\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional)\n",
|
|||
|
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 28.40 GiB (GPU 0; 23.59 GiB total capacity; 15.01 GiB already allocated; 7.20 GiB free; 15.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"start_epoch, loss = load_most_recent()\n",
|
|||
|
"if start_epoch is None:\n",
|
|||
|
" start_epoch = 0\n",
|
|||
|
"else:\n",
|
|||
|
" print(f\"starting from epoch {start_epoch} (loss: {loss})\")\n",
|
|||
|
"\n",
|
|||
|
"# Train Network\n",
|
|||
|
"for epoch in tqdm(range(start_epoch+1,num_epochs+1)):\n",
|
|||
|
" # toggle train mode\n",
|
|||
|
" model.train()\n",
|
|||
|
" for batch_idx, (data, targets) in enumerate(loader_train):\n",
|
|||
|
" # Get data to cuda if possible\n",
|
|||
|
" data = data.to(device=device).squeeze(1)\n",
|
|||
|
" targets = targets.to(device=device)\n",
|
|||
|
"\n",
|
|||
|
" # forward\n",
|
|||
|
" scores = model(data)\n",
|
|||
|
" loss = loss_fn(scores, targets)\n",
|
|||
|
"\n",
|
|||
|
" # backward\n",
|
|||
|
" optimizer.zero_grad()\n",
|
|||
|
" loss.backward()\n",
|
|||
|
"\n",
|
|||
|
" # gradient descent update step/adam step\n",
|
|||
|
" optimizer.step()\n",
|
|||
|
"\n",
|
|||
|
" if epoch % 5 != 0:\n",
|
|||
|
" continue\n",
|
|||
|
"\n",
|
|||
|
" cache(epoch, loss)\n",
|
|||
|
" evaluate()\n",
|
|||
|
"\n",
|
|||
|
"evaluate()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"ename": "NameError",
|
|||
|
"evalue": "name 'model' is not defined",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|||
|
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 3\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m model(X_train\u001b[38;5;241m.\u001b[39mto(device\u001b[38;5;241m=\u001b[39mdevice))\n",
|
|||
|
"\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model.eval()\n",
|
|||
|
"with torch.no_grad():\n",
|
|||
|
" y_pred = model(X_train.to(device=device))\n",
|
|||
|
" print(y_pred.shape, y_train.shape)\n",
|
|||
|
"y_train, y_pred"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(0, 0)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 34,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"mean_x, mean_y"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import scipy\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"def predict_and_plot(feature, steps = 50):\n",
|
|||
|
"    \"\"\"Roll the model forward autoregressively from an observed track prefix and plot it.\n",
|
|||
|
"\n",
|
|||
|
"    Args:\n",
|
|||
|
"        feature: 2D array of observed rows whose columns are, in order,\n",
|
|||
|
"            proj_x, proj_y, v, heading, a, d_heading (assumed to match in_fields; TODO confirm).\n",
|
|||
|
"        steps: number of future rows to predict (default 50).\n",
|
|||
|
"\n",
|
|||
|
"    Each step feeds the whole, growing sequence to the global `model` as a batch of one,\n",
|
|||
|
"    uses the first two output values as an (x, y) velocity integrated over dt = 1/FPS,\n",
|
|||
|
"    derives v, heading, a and d_heading from it and appends the new row. The observed\n",
|
|||
|
"    prefix is drawn in orange and the prediction in red; nothing is returned.\n",
|
|||
|
"    \"\"\"\n",
|
|||
|
"    length = feature.shape[0]\n",
|
|||
|
"    dt = 1 / FPS\n",
|
|||
|
"    with torch.no_grad():\n",
|
|||
|
"        for _ in range(steps):\n",
|
|||
|
"            # batch size of one: add a leading batch axis -> (1, seq_len, n_features)\n",
|
|||
|
"            X = torch.tensor(feature[np.newaxis], dtype=torch.float).to(device)\n",
|
|||
|
"            # to numpy so the derived features below are plain floats, not tensors\n",
|
|||
|
"            s = model(X)[0].cpu().numpy()\n",
|
|||
|
"\n",
|
|||
|
"            # columns: proj_x proj_y v heading a d_heading\n",
|
|||
|
"            prev = feature[-1]\n",
|
|||
|
"            v = np.sqrt(s[0]**2 + s[1]**2)\n",
|
|||
|
"            heading = (np.arctan2(s[1], s[0]) * 180 / np.pi) % 360\n",
|
|||
|
"            a = (v - prev[2]) / dt\n",
|
|||
|
"            # change relative to the previous heading (column 3), not the previous d_heading (column 5)\n",
|
|||
|
"            # NOTE(review): unlike `a` this is neither divided by dt nor wrapped to +/-180 degrees;\n",
|
|||
|
"            # it must match how d_heading was computed for the training data - confirm.\n",
|
|||
|
"            d_heading = heading - prev[3]\n",
|
|||
|
"            next_row = [prev[0] + s[0]*dt, prev[1] + s[1]*dt, v, heading, a, d_heading]\n",
|
|||
|
"            feature = np.append(feature, [next_row], axis=0)\n",
|
|||
|
"\n",
|
|||
|
"    plt.plot(feature[:length, 0], feature[:length, 1], c='orange')\n",
|
|||
|
"    plt.plot(feature[length-1:, 0], feature[length-1:, 1], c='red')\n",
|
|||
|
"    plt.scatter(feature[length:, 0], feature[length:, 1], c='red')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"2515\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAACPu0lEQVR4nO39eZRkd3nfj7/vvtRe1dv0aEYIicUIgYkxmBBiEfS1rAMYLDssJphgx45tYQIiGHMSIHhTMN8QASbEcI4Njg14ScA23xiHsMn+GTACY4ONtYCWkWZ6eqv13rr7/f3R/Txzq7tnk7pr6X5e5/SZ6Vq6PlV1qz7v+yzvR8nzPIcgCIIgCMKYUCe9AEEQBEEQjhYiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCsiPgRBEARBGCv6pBewkyzLcPr0aVQqFSiKMunlCIIgCIJwCeR5jn6/j+XlZajqhWMbUyc+Tp8+jRMnTkx6GYIgCIIgPAJOnTqFK6644oK3mTrxUalUAGwtvlqtTng1giAIgiBcCr1eDydOnOB9/EJMnfigVEu1WhXxIQiCIAgzxqWUTEjBqSAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY0XEhyAIgiAIY+Wyxccdd9yBF77whVheXoaiKPjEJz5x3tv+zM/8DBRFwe233/4oligIgiAIwmHissWH53l46lOfive9730XvN3HP/5xfOlLX8Ly8vIjXpwgCIIgCIePyx4sd9NNN+Gmm2664G0efvhh/PzP/zz+/M//HM9//vMf8eIEQRAmSZ7niKIIURRBURTous4/giA8cvb9E5RlGV75ylfijW98I6699tr9/vOCIAgHCgmO4XCIIAiQZdmu25AQMQyDxYiqqjBNcwIrFoTZY9/Fxzve8Q7ouo7Xvva1l3T7MAwRhiH/3uv19ntJgiAIFyXPc3ieh8FgMCI4VFWFZVkAgCRJkCQJ8jxHHMeI43jX35FUsyBcnH0VH1/96lfx7ne/G1/72tegKMol3ee2227D29/+9v1chiAIwmXT6XQwHA4BbAkOx3Fg2zZM0xz5PsvzHGmaIo5jFiN0P0EQLo19bbX9i7/4C6yuruLkyZMcinzggQfwhje8AY95zGP2vM+b3/xmdLtd/jl16tR+LkkQBOGSKJVKUFUV9Xodi4uLqNVqsCxr14kUpVwcx0GlUkGj0cDS0hIMw0C9Xp/M4gVhxtjXyMcrX/lK3HDDDSOX3XjjjXjlK1+JV7/61Xvex7IsDmkKgiBMCtM0sbi4eMlR2yKqqmJ+fv4AViUIh5PLFh+DwQD33nsv/37ffffh61//OprNJk6ePIlWqzVye8MwsLS0hCc84QmPfrWCIAgHyCMRHoIgXD6XLT7uvPNOPPe5z+Xfb731VgDAq171KnzoQx/at4UJgiAIgnA4uWzxcf311yPP80u+/f3333+5DyEIgiAIwiFGnHIEQZha8jzHmTNn+HfbttFsNie4IkEQ9gMZLCcIwtSSpunI71KTIQiHA4l8CIIwtei6jmazicFgAMdxpDNOEA4JIj4EQZhqbNuGbduTXoYgCPuIpF0EQRAEQRgrIj4EQRAEQRgrIj4EQRAEQRgrUvMhCMLMEwQBPM+DaZqoVCqTXo4gCBdBIh+CIMw8eZ4jDEMEQTDppQiCcAmI+BAEYeYxTRMAEMcxsiyb8GoEQbgYIj4EQZh5NE2Drm9lkaMomvBqBEG4GCI+BEE4FFD0Q8SHIEw/Ij4EQTgUiPgQhNlBxIcgCIcCsl6P4/iyJm8LgjB+RHwIgnAoUBQFURQhCAKJfgjClCPiQxCEQ0EcxwjDEL7vIwzDSS9HEIQLICZjgiBMnCiKkCQJoihCmqao1+vQNO2S75
/nOTY2NmAYBoIgEPEhCFOORD4EQZgoWZZhfX0dnU6Hoxbdbvey/gbVeBiGAQBot9tS9yEIU4yID0EQJoqiKLsuC4IAaZpe8t9IkgTAlt+HqqpI01TqPgRhihHxIQjCRFEUhVMs5XKZoxeXIx7IZMwwDFQqFZTLZREfgjDFSM2HIAgTZ3Fxkf/f6/W4eNRxnEu6v6ZpWFhYAACUSiV0Oh2EYShD5gRhSpHIhyAIU8WjNQsjv48oimTOiyBMKSI+BGEHSZLA8zz0+32ZkjoBSHwkSXJZdR+EpmmcupGuF0GYTkR8CMI2WZah3W5jdXUV3W4X/X5fNq8DJssyDAaDkc4UVVUftXig6Ie8f4IwnUjNhyBga5PqdDp8pm1ZFnRd57NwYX8ZDodsCAZsCQ7Xdfl627YRxzGGw+HI5ZeKZVkYDAYiPgRhShHxIRx5giDA5uYmAEDXddTrdREdB0y32x2px/A8b0RkOI7DkackSaDrl/dVZZomFEVBmqaP6P6CIBwsknYRjjRRFKHdbgPY2vDm5+dFeEyAnYWhuq6z/we9P5eDoij8Pkr0QxCmDxEfwpElSRJsbm4iz3PYto16vb6n4ZWw/9i2DQBwXRdzc3OYm5vbdZtarQbg3MyWy4XqPqRoWBCmDxEfwpEkTVNsbGwgyzKYpolGoyHCY4yQ+AjDEKZp7jnHxXVdlEolAMBgMHjEjyEtt4IwfYj4ECZOnucYDodje7wsy7C5uYk0TaHrOprNpgiPMWNZFtdkXMjPo1wuA9gSKXEcX9Zj6LoOXdeR57lEPwRhypAqLGFi5HkOz/PgeR7SNIWqqhwq3y+SJIHv+9zKSZNT8zyHpmlotVpQVdHg40ZRFNi2jeFwiCAIzltno2kaHMfBcDiE53mo1+uX/Bh5nnP30uUKF0EQDhYRH8JE2Nnaqmnavk4hHQwGCIIAcRzv+XdVVUWz2bysse3C/lIUH9Vq9by3K5VKGA6HGA6HqFQqu96zKIoQxzG3RxP03iuKwumbo0AcxyzmVVWFpmkc2aP0kwhuYdKI+BDGSpZl6PV67O+gaRoqlQocx9m31Ee/30e/3+ffbdvm7gkaQKbrunwBTxjbtqEoCpIkQbvdhmVZex4HpmnCNE1EUQTf90fmtQwGA/R6Pf792LFjfH+adGua5pFptU3TFGtra7sup2OdxEej0Tjv3Jx2u40oitBqtY7M6yaMHzmyhLExHA5H/B1KpRKq1eq+iY4sy9DpdDi/XyqV4Louu2UK04WiKLAsC0EQcGTD9300m81dwrBUKiGKInieh3K5zMdMUXjshCJeR0lkFgtrVVVFnufI83xXwW273eboCNXckCiPoghpmmJ1dXVEzAnCfiLiQxgLvV6POxYMw0CtVttXP40kSbCxsYE0TaEoCqrV6pEKtc8qpVJppBg0iiIEQbDL1dS2bWiahjRN0ev1uA3Xsixuw52fnx/ZKHee7R8FDMOArutIkgTVahWu6yLLMv7RNA2bm5uI4/i8wq3f78MwDGiahocffhjz8/MsYJIk4Rod27a5cFgQLhcRH8KBkuc52u02bzDlchmVSmVfox0rKyv8u67raDQaEu2YEahOg1Ikqqpyi2wRRVFQq9WwubnJBcqu63KnEtV2FDmK4gPYEmok9l3X5doPol6vY319HXmec0pLURT4vo80TVEul7G+vs633ytyRIXcpmmiXq+zKJQ6KuFSEfEhHBhpmvJZlqIoqNfr580zP5rHKDI3N3ekwuyHgUajgfX1dZimiVardd7b2baNSqXC04apULWYhilCx0GSJIjjmIsvDzumaWJjY4NF3fLy8sj1hmFgYWEBWZaNiPRyuYzBYMBirdfrodVqQVGUkZ8sy6AoCuI4RhRFPIiRUjXHjx9HtVqVz6FwQUR8CAdCHMfspUGdJQdhW24YBqrVKn9Ryhfe7GEYBo4dO3ZJt61UKsiyDJ7nAcAFPUKoqLjb7SKOYzSbTfYNOcxQkWgQBMjzfM/ZNp
qm7RJiiqJwMW+tVsPx48cv+DhJkqDT6SCKIpRKJbbB39jYQLlcls+icEFEfAj7ThRF2NzcRJZlbOJ1kFXz5XL5SGw
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"\n",
|
|||
|
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
|
|||
|
"_track_id = 8701 # random.choice(track_ids)\n",
|
|||
|
"_track_id = 3880 # random.choice(track_ids)\n",
|
|||
|
"\n",
|
|||
|
"# _track_id = 2780\n",
|
|||
|
"\n",
|
|||
|
"for i in range(100):\n",
|
|||
|
" _track_id = random.choice(track_ids)\n",
|
|||
|
" plt.plot(\n",
|
|||
|
" filtered_data.loc[_track_id,:]['proj_x'],\n",
|
|||
|
" filtered_data.loc[_track_id,:]['proj_y'],\n",
|
|||
|
" c='grey', alpha=.2\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
"_track_id = random.choice(track_ids)\n",
|
|||
|
"# _track_id = 801\n",
|
|||
|
"print(_track_id)\n",
|
|||
|
"ax = plt.scatter(\n",
|
|||
|
" filtered_data.loc[_track_id,:]['proj_x'],\n",
|
|||
|
" filtered_data.loc[_track_id,:]['proj_y'],\n",
|
|||
|
" marker=\"*\") \n",
|
|||
|
"plt.plot(\n",
|
|||
|
" filtered_data.loc[_track_id,:]['proj_x'],\n",
|
|||
|
" filtered_data.loc[_track_id,:]['proj_y']\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:5][in_fields].values)\n",
|
|||
|
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:10][in_fields].values)\n",
|
|||
|
"predict_and_plot(filtered_data.loc[_track_id,:].iloc[:50][in_fields].values)\n",
|
|||
|
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n",
|
|||
|
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": ".venv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.10.4"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|