trap/test_custom_rnn.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Goal of this notebook: implement some basic RNN/LSTM/GRU to _forecast_ trajectories based on VIRAT and/or the custom _hof_ dataset."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ruben/suspicion/trap/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt # Visualization \n",
"import torch.nn as nn\n",
"import pandas_helper_calc # noqa # provides df.calc.derivative()\n",
"import pandas as pd\n",
"import cv2\n",
"import pathlib\n",
"from tqdm.autonotebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"FPS = 12\n",
"# SRC_CSV = \"EXPERIMENTS/hofext-maskrcnn/all.txt\"\n",
"# SRC_CSV = \"EXPERIMENTS/raw/generated/train/tracks.txt\"\n",
"SRC_CSV = \"EXPERIMENTS/raw/hof-meter-maskrcnn2/train/tracks.txt\"\n",
"SRC_CSV = \"EXPERIMENTS/20240426-hof-yolo/train/tracked.txt\"\n",
"SRC_CSV = \"EXPERIMENTS/raw/hof2/train/tracked.txt\"\n",
"# SRC_H = \"../DATASETS/hof/webcam20231103-2-homography.txt\"\n",
"SRC_H = None\n",
"CACHE_DIR = \"EXPERIMENTS/cache/hof2/\"\n",
"# SMOOTHING = True # hof-yolo is already smoothed, hof2 isn't\n",
"# SMOOTHING_WINDOW=3 #2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"in_fields = ['x', 'y', 'vx', 'vy', 'ax', 'ay'] #, 'dt'] (WARNING: dt column contains NaN)\n",
"# out_fields = ['v', 'heading']\n",
"# velocity cannot be negative, and heading is circular (modulo), this makes it harder to optimise than a linear space, so try to use components\n",
"# an we can use simple MSE loss (I guess?)\n",
"out_fields = ['dx', 'dy']\n",
"SAMPLE_STEP = 5 # 1/5, for 12fps leads to effectively 12/5=2.4fps\n",
"GRID_SIZE = 2 # round items on a grid of 2 points per meter (None to disable rounding)\n",
"window = 8 #int(FPS*1.5 / SAMPLE_STEP)"
]
},
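{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small added sketch (not part of the original run; the numbers are made up): recovering `v` and `heading` from the dx/dy components shows why the component representation is the easier target — heading wraps around at 360°, so an MSE loss would wrongly treat 359° and 1° as far apart, while dx/dy live in a plain linear space. The same `GRID_SIZE` rounding is used later in `predict_and_plot` (uses `np` and `GRID_SIZE` from earlier cells)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dx, dy = 0.6, -0.4 # hypothetical displacement over one sample step (metres)\n",
"v = np.hypot(dx, dy) # speed recovered as the vector norm (never negative)\n",
"heading = np.degrees(np.arctan2(dy, dx)) % 360 # circular quantity\n",
"x_grid = round(13.26 * GRID_SIZE) / GRID_SIZE # snap a coordinate to the 2-points-per-meter grid -> 13.5\n",
"print(v, heading, x_grid)"
]
},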
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# Set device\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(device)\n",
"\n",
"# Hyperparameters\n",
"input_size = len(in_fields) #in_d\n",
"hidden_size = 64 # hidden_d\n",
"num_layers = 1 # num_hidden\n",
"output_size = len(out_fields) # out_d\n",
"learning_rate = 0.005 #0.01 #0.005\n",
"batch_size = 512\n",
"num_epochs = 1000\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"cache_path = pathlib.Path(CACHE_DIR)\n",
"cache_path.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Samping 1/5, of 412098 items\n",
"Done sampling kept 83726 items\n"
]
}
],
"source": [
"from pathlib import Path\n",
"from trap.tools import load_tracks_from_csv\n",
"from trap.tools import filter_short_tracks, normalise_position\n",
"\n",
"data= load_tracks_from_csv(Path(SRC_CSV), FPS, GRID_SIZE, SAMPLE_STEP )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# create x-norm, y_norm columns\n",
"data, mu, std = normalise_position(data)\n",
"data = filter_short_tracks(data, window+1)"
]
},
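{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick added sanity check: if `normalise_position` standardises the positions (an assumption based on the returned `mu` and `std`; the actual implementation lives in `trap.tools`), then `x_norm`/`y_norm` should be roughly zero-mean with unit variance:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(data[['x_norm', 'y_norm']].describe().loc[['mean', 'std']])"
]
},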
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset is a bit crappy because it has different frame step: ranging from predominantly 1 and 2 to sometimes have 3 and 4 as well. This inevitabily leads to difference in speed caluclations"
]
},
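{
"cell_type": "markdown",
"metadata": {},
"source": [
"An added diagnostic sketch for the claim above, assuming the `diff` column holds the frame step between consecutive (resampled) detections, as the table further down suggests:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# distribution of frame steps: uneven sampling shows up as a spread of values\n",
"data['diff'].value_counts(dropna=True).sort_index()"
]
},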
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1606 training tracks, 402 test tracks\n"
]
}
],
"source": [
"track_ids = data.index.unique('track_id').to_numpy()\n",
"np.random.shuffle(track_ids)\n",
"test_offset_idx = int(len(track_ids) * .8)\n",
"training_ids, test_ids = track_ids[:test_offset_idx], track_ids[test_offset_idx:]\n",
"print(f\"{len(training_ids)} training tracks, {len(test_ids)} test tracks\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"here, draw out a sample track to see if it looks alright. **unfortunately the imate isn't mapped properly**."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4789\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGdCAYAAAAvwBgXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3wd1bW2n5nTm456lyzLsuTeG9iAbTC9mh5qgFBCKqR8yeUGuClAIJQAKYQaamimg23ce7dlucmyqtV7O/3M/v4YVUs2dnCRzX74CVlT9xwdnXln7bXepQghBBKJRCKRSCQnAPVED0AikUgkEsl3FylEJBKJRCKRnDCkEJFIJBKJRHLCkEJEIpFIJBLJCUMKEYlEIpFIJCcMKUQkEolEIpGcMKQQkUgkEolEcsKQQkQikUgkEskJw3iiB3AoNE2joqICl8uFoignejgSiUQikUgOAyEEra2tJCcno6qHjnkMaCFSUVFBWlraiR6GRCKRSCSS/4KysjJSU1MPuc2AFiIulwvQLyQiIuIEj0YikUgkEsnh0NLSQlpaWtd9/FAMaCHSOR0TEREhhYhEIpFIJCcZh5NWIZNVJRKJRCKRnDCkEJFIJBKJRHLCkEJEIpFIJBLJCUMKEYlEIpFIJCcMKUQkEolEIpGcMI6pEHnkkUeYPHkyLpeL+Ph4Lr/8cvbs2XMsTymRSCQSieQk4pgKkWXLlnHvvfeydu1aFi5cSDAY5Nxzz6W9vf1YnlYikUgkEslJgiKEEMfrZLW1tcTHx7Ns2TLOPPPMb9y+paUFt9tNc3Oz9BGRSCQSieQk4Uju38c1R6S5uRmA6Ojoftf7/X5aWlp6fUkkEslARQjBhqoNHMfnOYnklOO4CRFN0/jZz37G9OnTGTVqVL/bPPLII7jd7q4v2WdGIpEMZFaWr+S2+bexqmLViR6KRHLSctyEyL333kteXh7vvPPOQbf5zW9+Q3Nzc9dXWVnZ8RqeRCKRHDELSxYCsKB4wQkeiURy8nJces386Ec/4rPPPmP58uWH7MJnsViwWCzHY0gSiURyxGhC4z97/kNroBWAL4u+7Pqe6tI/21xmF9fmXIuqSHcEieRwOKZCRAjBj3/8Y+bNm8fSpUsZPHjwsTydRCKRHFO8IS/Pb3me5kBzr+X+sJ/ntjyHQOA2u7l0yKU4TI4TNEqJ5OTimEr2e++9lzfeeIO33noLl8tFVVUVVVVVeL3eY3laiUQiOSbYjDZuHHEjiY7EXstFx3+JjkRuGnETNqPtBI1QIjn5OKbluwdr//vKK69w6623fuP+snxXIpEMJNqD7Zz3/nl9IiI9cZvdzL9qvoyISL7TDJjyXSFEv1+HI0IkEolkoOEwOXjvkvfIiszqd31WZBbvX/r+txIhQghWNbbKkmDJdwaZTSWRSCQdHI4ISHImMT15er/rpqdM7zNtc6Qsbmjlyq37WNLQ+q2OI5GcLEghIpFIJB0crgjorJA5kDTXt/c++qy2qdd3ieRU57iU70okEsnJQE8RMDum97x2z9LdVeXdBmZGxUhaRBpFzUV8XfI1V2dffUSlu5oQvFpeR0soDMAn1Y0AfFrTRLrVDECE0cCtKbGoB8m7k0hOZqQQkUgk31mORAQcrHQ3LMIUNRcBsL12O96Q94hyRDxhjT8XVdHUMYZO2sMajxVVIYBIo4FrEqNxGg3/7aVKJAOW49r07kiRVTMSieRY0hYKM2nNzj4iwABo0CUCNp42AqfRQGVbJb9a/iu21m7t2jamWVAfAePix/P4WY//Vzki+30B7tlRzIYWT59145w2Xho9mJQOYQR6LsvqpjZOj3QetDpRIjmRDJiqGYlEIhnIOI0Gvp6cw+QIe6/lYXQRMjnCzqLJOV2RiCRnEi+f93Ivn5ChFYIpJWZePv/l/zpRNdVq5oPxWdjVvh/J+R4/r1fU0xwMdS2TCa29EULQvm59v0nGmqbx+s7XCYfDskHhAEUKEYlE8p3mYCLApih8OH5or0gEwPY6ffqlk8pohfE7fWyv3f6tbnRbWjx4NK3Pco+m8XRJNVPW7uLZkmraw2GZ0HoA7StWUHrLLbSvXNln3T9z/8mfN/yZe76+p1eDwoN1TpYdlY8/UohIJJLvPP2JAK8QbGlp77VMaBqFLz/PFas0bt+bwsgSjYvXa0zdJVj08sPcNv82vnrjD4h+BMU3saC+BYALYt2snTac82PdAExx28mxW2gOhfljYSVjVu3gg6ruXJani6t4uriKl/fXon1Hb54t8+f3+t6T9/LfA2BN1Rqgu0HhwTony47Kxx+ZrCqRSL7z9BQBZ8e4+MWe/QC8XlnPlEhn13aax8uID7cxslUjZChlcIIguwJaLVBWWwDJKvPzPuRcz30YnEdmanZeTAQjnTauiI9EURReGZXBH/ZV8nxZDf8aMYif7C7Fqwnaw90ix/MdTWgVmkbj22+jtepTUy1fftX1XU1N5g+mhbQagzS4FGq9tb32/ajgI3JrczEZTIAuTGakzOha39lReWHJwl7LJccOKUQkEsl3ngNFwMrGNj6qaWJjcztBTWBS9YRQg9PB0I8/ofz+X9C+ZQsbhipkVwhqI2D5SH2bdWPMvFT4JnBknXinRDqZ0uNnRVFoCOl5IS/sryXCaMAbCPXapzPFdnKEnX+MzPhuiBAhaFu+gtq/PovW3LuCSbS3s/fF51j2IxVQoJ8UGoFgX/O+rp+/KPyCem89AS2AxWBhfeV6QBcoKc4UQHZUPtbIqhmJRCI5gNpAkDPX7aYxFOaBzCR+NCiha50mNN7Ne4sh33+M+27V+PvzGh+dpvCfs/qKgCPtO9OznFgIwdMlNfh7fEQ7DCresEbPiR+7qrLnjNFdYulUp235csruvIukRx+h4d+v49+5s9d6rxnu+okJn+nb3doMigFNaF0dlWX/oCNDVs1IJBLJtyDObOLBrGQAniiuosjj77J/9wQ9fPXFszjbQkzcB1szFS7aILD6e9/4RseOPuK+M52eIo8WVfFYcXUvEQK6t8iB2SceTeuTy3Iq05kH0rpkCeG6uj7r7QYbH1/1GdmR2Ud0XKPSe4IgLMIIBOPixn3r/kGSQyOFiEQikfTDtYnRnBHlxKcJfrmnjEX1LVy5dR/rWzV+F74QgMhWwcqRCtYgXLmqWyJYVAuvXfDaEZfzHqycuJNEs36zPDChdX5HjsupiNA0Gt58k7p//IO6f/yD1q90IdK2YCGhmpq+23u9RBbWMXfo3CM6T0iE+iyzGW3fqixbcnjIHBGJRCLpB0VReDwnjVnrd7OyqY1gR3Tis9om/nD+JfhGjmV91T9pEXV4zW1csl7w/gyB36zg1/zk1eUxPn78EZ+3s5w4c1kuPW+NdlXlbyMGURUI9UponVfTRKrFdJSueuChebzd+SCKon8B9IwWGY1kvP8+dc8+S9uiRbQtXsy7WSu+9bm9IS+5tblMTJj4rY8lOTgyIiKRSCT9oAnB4voWpnVUzaxr1qc/Pq1p4oXoZB4dko7XOgcTTtZnK6gCZm/tjorMK5j3X
597S4uHA5/PPZqGUVGYmxDV5aaqdPw8JdJ5WJ2DT0YMTgeZ8z7ENn68Lj4OKI02DRrEkAXzsQ3LIfW5Z0l+/HGcs2eTYE84yBGPjN+t+h0hrW+0RHL0kEJEIpFI+qEzX+NA99L2cJhHiyr5d61Kq3ME2U1jWDlS/yi9ZpVCiqLnlhQ3F//X517QY6rlj1nJhzUFcyq7rZqSkxn02qsoNluv5YrZzJDPPsWcrL/miqLgvuRi7BMmcM+4e7h77N2ckXLGtzp3aWspDb6Gb3UMyaGRQkQikUj64WD5GhoKCDh9by4Xb19Lki+K5th0Gh3g8GmcvTURg2JgS80W1lWu+6/OfV5MBDa1O+rxyqgM/jZiEOfFHLz64FR3W/Vu347wenstE4EA3u3b+91+fPx47h13Lz8Y8wPOSj0LgNnps/n8is9Jd6Uf1jntRjtvXvgm8fb4bzd4ySGROSISiURyEDrzNYatyOvlvGoUGiMrCrt+HtKSzeoRxVy0QXC5z4bHfRbvNy3miY1P8M5F72BQj8z
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import random\n",
"# if H:\n",
"# img_src = \"../DATASETS/hof/webcam20231103-2.png\"\n",
"# # dst = cv2.warpPerspective(img_src,H,(2500,1920))\n",
"# src_img = cv2.imread(img_src)\n",
"# print(src_img.shape)\n",
"# h1,w1 = src_img.shape[:2]\n",
"# corners = np.float32([[0,0], [w1, 0], [0, h1], [w1, h1]])\n",
"\n",
"# print(corners)\n",
"# corners_projected = cv2.perspectiveTransform(corners.reshape((-1,4,2)), H)[0]\n",
"# print(corners_projected)\n",
"# [xmin, ymin] = np.int32(corners_projected.min(axis=0).ravel() - 0.5)\n",
"# [xmax, ymax] = np.int32(corners_projected.max(axis=0).ravel() + 0.5)\n",
"# print(xmin, xmax, ymin, ymax)\n",
"\n",
"# dst = cv2.warpPerspective(src_img,H, (xmax, ymax))\n",
"# def plot_track(track_id: int):\n",
"# plt.gca().invert_yaxis()\n",
"\n",
"# plt.imshow(dst, origin='lower', extent=[xmin/100-mean_x, xmax/100-mean_x, ymin/100-mean_y, ymax/100-mean_y])\n",
"# # plot scatter plot with x and y data \n",
" \n",
"# ax = plt.scatter(\n",
"# filtered_data.loc[track_id,:]['proj_x'],\n",
"# filtered_data.loc[track_id,:]['proj_y'],\n",
"# marker=\"*\") \n",
"# ax.axes.invert_yaxis()\n",
"# plt.plot(\n",
"# filtered_data.loc[track_id,:]['proj_x'],\n",
"# filtered_data.loc[track_id,:]['proj_y']\n",
"# )\n",
"# else:\n",
"def plot_track(track_id: int):\n",
" ax = plt.scatter(\n",
" data.loc[track_id,:]['x_norm'],\n",
" data.loc[track_id,:]['y_norm'],\n",
" marker=\"*\") \n",
" plt.plot(\n",
" data.loc[track_id,:]['x_norm'],\n",
" data.loc[track_id,:]['y_norm']\n",
" )\n",
"\n",
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
"# _track_id = 2188\n",
"_track_id = random.choice(track_ids)\n",
"print(_track_id)\n",
"plot_track(_track_id)\n",
"\n",
"for track_id in random.choices(track_ids, k=100):\n",
" plot_track(track_id)\n",
" \n",
"# print(mean_x, mean_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now make the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>l</th>\n",
" <th>t</th>\n",
" <th>w</th>\n",
" <th>h</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>state</th>\n",
" <th>diff</th>\n",
" <th>x_raw</th>\n",
" <th>y_raw</th>\n",
" <th>...</th>\n",
" <th>vx</th>\n",
" <th>vy</th>\n",
" <th>ax</th>\n",
" <th>ay</th>\n",
" <th>v</th>\n",
" <th>a</th>\n",
" <th>heading</th>\n",
" <th>d_heading</th>\n",
" <th>x_norm</th>\n",
" <th>y_norm</th>\n",
" </tr>\n",
" <tr>\n",
" <th>track_id</th>\n",
" <th>frame_id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">1</th>\n",
" <th>342.0</th>\n",
" <td>1393.736572</td>\n",
" <td>0.000000</td>\n",
" <td>67.613647</td>\n",
" <td>121.391151</td>\n",
" <td>13.244408</td>\n",
" <td>2.414339</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>13.5</td>\n",
" <td>2.5</td>\n",
" <td>...</td>\n",
" <td>6.143418e-01</td>\n",
" <td>1.389160</td>\n",
" <td>-1.422342e+00</td>\n",
" <td>0.453213</td>\n",
" <td>1.518941</td>\n",
" <td>0.142097</td>\n",
" <td>66.143088</td>\n",
" <td>5.536579e+01</td>\n",
" <td>0.353449</td>\n",
" <td>-1.768217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>347.0</th>\n",
" <td>1393.844849</td>\n",
" <td>12.691238</td>\n",
" <td>86.482910</td>\n",
" <td>156.264786</td>\n",
" <td>13.500384</td>\n",
" <td>2.993156</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>13.5</td>\n",
" <td>3.0</td>\n",
" <td>...</td>\n",
" <td>6.143418e-01</td>\n",
" <td>1.389160</td>\n",
" <td>-1.422342e+00</td>\n",
" <td>0.453213</td>\n",
" <td>1.518941</td>\n",
" <td>0.142097</td>\n",
" <td>66.143088</td>\n",
" <td>5.536579e+01</td>\n",
" <td>0.414443</td>\n",
" <td>-1.574517</td>\n",
" </tr>\n",
" <tr>\n",
" <th>352.0</th>\n",
" <td>1405.273438</td>\n",
" <td>36.675903</td>\n",
" <td>90.329956</td>\n",
" <td>176.461975</td>\n",
" <td>13.509425</td>\n",
" <td>3.650656</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>13.5</td>\n",
" <td>3.5</td>\n",
" <td>...</td>\n",
" <td>2.169933e-02</td>\n",
" <td>1.577999</td>\n",
" <td>-1.422342e+00</td>\n",
" <td>0.453213</td>\n",
" <td>1.578149</td>\n",
" <td>0.142097</td>\n",
" <td>89.212166</td>\n",
" <td>5.536579e+01</td>\n",
" <td>0.416598</td>\n",
" <td>-1.354485</td>\n",
" </tr>\n",
" <tr>\n",
" <th>357.0</th>\n",
" <td>1421.215698</td>\n",
" <td>76.261253</td>\n",
" <td>91.465088</td>\n",
" <td>181.133682</td>\n",
" <td>13.500221</td>\n",
" <td>4.282279</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>13.5</td>\n",
" <td>4.5</td>\n",
" <td>...</td>\n",
" <td>-2.209058e-02</td>\n",
" <td>1.515896</td>\n",
" <td>-1.050958e-01</td>\n",
" <td>-0.149049</td>\n",
" <td>1.516057</td>\n",
" <td>-0.149020</td>\n",
" <td>90.834891</td>\n",
" <td>3.894540e+00</td>\n",
" <td>0.414404</td>\n",
" <td>-1.143113</td>\n",
" </tr>\n",
" <tr>\n",
" <th>362.0</th>\n",
" <td>1438.374268</td>\n",
" <td>115.362549</td>\n",
" <td>84.298584</td>\n",
" <td>172.143616</td>\n",
" <td>13.499658</td>\n",
" <td>4.743787</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>13.5</td>\n",
" <td>4.5</td>\n",
" <td>...</td>\n",
" <td>-1.349331e-03</td>\n",
" <td>1.107618</td>\n",
" <td>4.977900e-02</td>\n",
" <td>-0.979866</td>\n",
" <td>1.107619</td>\n",
" <td>-0.980250</td>\n",
" <td>90.069799</td>\n",
" <td>-1.836220e+00</td>\n",
" <td>0.414270</td>\n",
" <td>-0.988670</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">5030</th>\n",
" <th>32702.0</th>\n",
" <td>1705.054635</td>\n",
" <td>749.467887</td>\n",
" <td>132.149004</td>\n",
" <td>182.105042</td>\n",
" <td>14.000000</td>\n",
" <td>10.495261</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>14.0</td>\n",
" <td>10.5</td>\n",
" <td>...</td>\n",
" <td>1.654143e-12</td>\n",
" <td>-0.029145</td>\n",
" <td>4.967546e-11</td>\n",
" <td>0.657171</td>\n",
" <td>0.029145</td>\n",
" <td>-0.657171</td>\n",
" <td>270.000000</td>\n",
" <td>1.644812e-08</td>\n",
" <td>0.533492</td>\n",
" <td>0.936057</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32707.0</th>\n",
" <td>1703.756025</td>\n",
" <td>749.703112</td>\n",
" <td>131.216670</td>\n",
" <td>181.961914</td>\n",
" <td>14.000000</td>\n",
" <td>10.499609</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>14.0</td>\n",
" <td>10.5</td>\n",
" <td>...</td>\n",
" <td>7.418066e-13</td>\n",
" <td>0.010435</td>\n",
" <td>-2.189608e-12</td>\n",
" <td>0.094992</td>\n",
" <td>0.010435</td>\n",
" <td>-0.044905</td>\n",
" <td>90.000000</td>\n",
" <td>-4.320000e+02</td>\n",
" <td>0.533492</td>\n",
" <td>0.937512</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32712.0</th>\n",
" <td>1702.457415</td>\n",
" <td>749.938337</td>\n",
" <td>130.284337</td>\n",
" <td>181.818787</td>\n",
" <td>14.000000</td>\n",
" <td>10.500165</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>14.0</td>\n",
" <td>10.5</td>\n",
" <td>...</td>\n",
" <td>-4.263256e-14</td>\n",
" <td>0.001334</td>\n",
" <td>-1.882654e-12</td>\n",
" <td>-0.021841</td>\n",
" <td>0.001334</td>\n",
" <td>-0.021841</td>\n",
" <td>90.000000</td>\n",
" <td>1.416898e-08</td>\n",
" <td>0.533492</td>\n",
" <td>0.937698</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32717.0</th>\n",
" <td>1701.158805</td>\n",
" <td>750.173562</td>\n",
" <td>129.352003</td>\n",
" <td>181.675659</td>\n",
" <td>14.000000</td>\n",
" <td>10.500019</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>14.0</td>\n",
" <td>10.5</td>\n",
" <td>...</td>\n",
" <td>-2.984279e-14</td>\n",
" <td>-0.000350</td>\n",
" <td>3.069545e-14</td>\n",
" <td>-0.004042</td>\n",
" <td>0.000350</td>\n",
" <td>-0.002362</td>\n",
" <td>270.000000</td>\n",
" <td>4.320000e+02</td>\n",
" <td>0.533492</td>\n",
" <td>0.937649</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32722.0</th>\n",
" <td>1702.384766</td>\n",
" <td>750.754517</td>\n",
" <td>123.435425</td>\n",
" <td>180.945618</td>\n",
" <td>14.000000</td>\n",
" <td>10.499985</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>14.0</td>\n",
" <td>10.5</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>-0.000082</td>\n",
" <td>7.162271e-14</td>\n",
" <td>0.000644</td>\n",
" <td>0.000082</td>\n",
" <td>-0.000644</td>\n",
" <td>270.000000</td>\n",
" <td>1.172430e-08</td>\n",
" <td>0.533492</td>\n",
" <td>0.937638</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>80035 rows × 24 columns</p>\n",
"</div>"
],
"text/plain": [
" l t w h x \\\n",
"track_id frame_id \n",
"1 342.0 1393.736572 0.000000 67.613647 121.391151 13.244408 \n",
" 347.0 1393.844849 12.691238 86.482910 156.264786 13.500384 \n",
" 352.0 1405.273438 36.675903 90.329956 176.461975 13.509425 \n",
" 357.0 1421.215698 76.261253 91.465088 181.133682 13.500221 \n",
" 362.0 1438.374268 115.362549 84.298584 172.143616 13.499658 \n",
"... ... ... ... ... ... \n",
"5030 32702.0 1705.054635 749.467887 132.149004 182.105042 14.000000 \n",
" 32707.0 1703.756025 749.703112 131.216670 181.961914 14.000000 \n",
" 32712.0 1702.457415 749.938337 130.284337 181.818787 14.000000 \n",
" 32717.0 1701.158805 750.173562 129.352003 181.675659 14.000000 \n",
" 32722.0 1702.384766 750.754517 123.435425 180.945618 14.000000 \n",
"\n",
" y state diff x_raw y_raw ... vx \\\n",
"track_id frame_id ... \n",
"1 342.0 2.414339 2.0 NaN 13.5 2.5 ... 6.143418e-01 \n",
" 347.0 2.993156 2.0 5.0 13.5 3.0 ... 6.143418e-01 \n",
" 352.0 3.650656 2.0 5.0 13.5 3.5 ... 2.169933e-02 \n",
" 357.0 4.282279 2.0 5.0 13.5 4.5 ... -2.209058e-02 \n",
" 362.0 4.743787 2.0 5.0 13.5 4.5 ... -1.349331e-03 \n",
"... ... ... ... ... ... ... ... \n",
"5030 32702.0 10.495261 1.0 5.0 14.0 10.5 ... 1.654143e-12 \n",
" 32707.0 10.499609 1.0 5.0 14.0 10.5 ... 7.418066e-13 \n",
" 32712.0 10.500165 1.0 5.0 14.0 10.5 ... -4.263256e-14 \n",
" 32717.0 10.500019 1.0 5.0 14.0 10.5 ... -2.984279e-14 \n",
" 32722.0 10.499985 2.0 5.0 14.0 10.5 ... 0.000000e+00 \n",
"\n",
" vy ax ay v a \\\n",
"track_id frame_id \n",
"1 342.0 1.389160 -1.422342e+00 0.453213 1.518941 0.142097 \n",
" 347.0 1.389160 -1.422342e+00 0.453213 1.518941 0.142097 \n",
" 352.0 1.577999 -1.422342e+00 0.453213 1.578149 0.142097 \n",
" 357.0 1.515896 -1.050958e-01 -0.149049 1.516057 -0.149020 \n",
" 362.0 1.107618 4.977900e-02 -0.979866 1.107619 -0.980250 \n",
"... ... ... ... ... ... \n",
"5030 32702.0 -0.029145 4.967546e-11 0.657171 0.029145 -0.657171 \n",
" 32707.0 0.010435 -2.189608e-12 0.094992 0.010435 -0.044905 \n",
" 32712.0 0.001334 -1.882654e-12 -0.021841 0.001334 -0.021841 \n",
" 32717.0 -0.000350 3.069545e-14 -0.004042 0.000350 -0.002362 \n",
" 32722.0 -0.000082 7.162271e-14 0.000644 0.000082 -0.000644 \n",
"\n",
" heading d_heading x_norm y_norm \n",
"track_id frame_id \n",
"1 342.0 66.143088 5.536579e+01 0.353449 -1.768217 \n",
" 347.0 66.143088 5.536579e+01 0.414443 -1.574517 \n",
" 352.0 89.212166 5.536579e+01 0.416598 -1.354485 \n",
" 357.0 90.834891 3.894540e+00 0.414404 -1.143113 \n",
" 362.0 90.069799 -1.836220e+00 0.414270 -0.988670 \n",
"... ... ... ... ... \n",
"5030 32702.0 270.000000 1.644812e-08 0.533492 0.936057 \n",
" 32707.0 90.000000 -4.320000e+02 0.533492 0.937512 \n",
" 32712.0 90.000000 1.416898e-08 0.533492 0.937698 \n",
" 32717.0 270.000000 4.320000e+02 0.533492 0.937649 \n",
" 32722.0 270.000000 1.172430e-08 0.533492 0.937638 \n",
"\n",
"[80035 rows x 24 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# a=filtered_data.loc[1]\n",
"# min(a.index.tolist())\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x False\n",
"y False\n",
"vx False\n",
"vy False\n",
"ax False\n",
"ay False\n",
"dx False\n",
"dy False\n"
]
}
],
"source": [
"for field in in_fields + out_fields:\n",
" print(field, data[field].isnull().values.any())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1606 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1606/1606 [00:27<00:00, 58.17it/s]\n",
"100%|██████████| 402/402 [00:06<00:00, 60.38it/s]\n"
]
}
],
"source": [
"def create_dataset(data, track_ids, window, only_last=False):\n",
" X, y, = [], []\n",
" factor = SAMPLE_STEP if SAMPLE_STEP is not None else 1\n",
" for track_id in tqdm(track_ids):\n",
" df = data.loc[track_id]\n",
" # print(df)\n",
" start_frame = min(df.index.tolist())\n",
" for step in range(len(df)-window-1):\n",
" i = int(start_frame) + (step*factor)\n",
" # print(step, int(start_frame), i)\n",
" feature = df.loc[i:i+(window*factor)][in_fields]\n",
" # target = df.loc[i+1:i+window+1][out_fields]\n",
" # print(i, window*factor, factor, i+window*factor+factor, df['idx_in_track'])\n",
" # print(i+window*factor+factor)\n",
" if only_last:\n",
" target = df.loc[i+window*factor+factor][out_fields]\n",
" else:\n",
" target = df.loc[i+factor:i+window*factor+factor][out_fields]\n",
"\n",
" X.append(feature.values)\n",
" y.append(target.values)\n",
" \n",
" return torch.tensor(np.array(X), dtype=torch.float), torch.tensor(np.array(y), dtype=torch.float)\n",
"\n",
"X_train, y_train = create_dataset(data, training_ids, window)\n",
"X_test, y_test = create_dataset(data, test_ids, window)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train = X_train.to(device=device), y_train.to(device=device)\n",
"X_test, y_test = X_test.to(device=device), y_test.to(device=device)"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import TensorDataset, DataLoader\n",
"dataset_train = TensorDataset(X_train, y_train)\n",
"loader_train = DataLoader(dataset_train, shuffle=True, batch_size=batch_size)\n",
"dataset_test = TensorDataset(X_test, y_test)\n",
"loader_test = DataLoader(dataset_test, shuffle=False, batch_size=batch_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Model give output for all timesteps, this should improve training. But we use only the last timestep for the prediction process"
]
},
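{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal added sketch of \"use only the last timestep\" (`seq_out` is a stand-in tensor, not a real model output):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"seq_out = torch.randn(4, window + 1, len(out_fields)) # stand-in for a (batch, seq, features) model output\n",
"last_step = seq_out[:, -1, :] # during prediction only the final step feeds the rollout\n",
"print(last_step.shape) # torch.Size([4, 2])"
]
},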
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RNN"
2024-11-17 19:39:32 +01:00
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class SimpleRnn(nn.Module):\n",
" def __init__(self, in_d=2, out_d=2, hidden_d=4, num_hidden=1):\n",
" super(SimpleRnn, self).__init__()\n",
" self.rnn = nn.RNN(input_size=in_d, hidden_size=hidden_d, num_layers=num_hidden)\n",
" self.fc = nn.Linear(hidden_d, out_d)\n",
"\n",
" def forward(self, x, h0):\n",
" r, h = self.rnn(x, h0)\n",
" # r = r[:, -1,:]\n",
" y = self.fc(r) # no activation on the output\n",
" return y, h\n",
"rnn = SimpleRnn(input_size, output_size, hidden_size, num_layers).to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LSTM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For optional LSTM-GAN, see https://discuss.pytorch.org/t/how-to-use-lstm-to-construct-gan/12419\n",
"\n",
"Or VAE (variational Auto encoder):\n",
"\n",
"> The only constraint on the latent vector representation for traditional autoencoders is that latent vectors should be easily decodable back into the original image. As a result, the latent space $Z$ can become disjoint and non-continuous. Variational autoencoders try to solve this problem. [Alexander van de Kleut](https://avandekleut.github.io/vae/)\n",
"\n",
"For LSTM based generative VAE: https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n",
"\n",
"http://web.archive.org/web/20210119121802/https://towardsdatascience.com/time-series-generation-with-vae-lstm-5a6426365a1c?gi=29d8b029a386\n",
"\n",
"https://youtu.be/qJeaCHQ1k2w?si=30aAdqqwvz0DpR-x&t=687 VAE generate mu and sigma of a Normal distribution. Thus, they don't map the input to a single point, but a gausian distribution."
]
},
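{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal added sketch of the reparameterization trick mentioned above (`mu` and `logvar` are stand-in tensors): sampling z = mu + eps * sigma keeps the sampling step differentiable with respect to mu and sigma, which is what makes a VAE trainable by backpropagation. The same pattern appears in `LSTMVAE.reparametize` further down."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mu, logvar = torch.zeros(3, 4), torch.zeros(3, 4) # stand-ins for encoder outputs\n",
"std = torch.exp(0.5 * logvar) # logvar -> sigma\n",
"eps = torch.randn_like(std) # noise is sampled outside the gradient path\n",
"z = mu + eps * std # differentiable w.r.t. mu and logvar\n",
"print(z.shape)"
]
},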
{
"cell_type": "code",
"execution_count": 328,
"metadata": {},
"outputs": [],
"source": [
"class LSTMModel(nn.Module):\n",
" # input_size : number of features in input at each time step\n",
" # hidden_size : Number of LSTM units \n",
" # num_layers : number of LSTM layers \n",
" def __init__(self, input_size, hidden_size, num_layers): \n",
" super(LSTMModel, self).__init__() #initializes the parent class nn.Module\n",
" # We _could_ train the h0: https://discuss.pytorch.org/t/learn-initial-hidden-state-h0-for-rnn/10013 \n",
" # self.lin1 = nn.Linear(input_size, hidden_size)\n",
" self.num_layers = num_layers\n",
" self.hidden_size = hidden_size\n",
" self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n",
" self.linear = nn.Linear(hidden_size, output_size)\n",
" # self.activation_v = nn.LeakyReLU(.01)\n",
" # self.activation_heading = torch.remainder()\n",
"\n",
" \n",
" def get_hidden_state(self, batch_size, device):\n",
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
" return (h, c)\n",
"\n",
" def forward(self, x, hidden_state): # defines forward pass of the neural network\n",
" # out = self.lin1(x)\n",
" \n",
" out, hidden_state = self.lstm(x, hidden_state)\n",
" # extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\n",
" # print(out.shape)\n",
" # TODO)) Might want to remove this below: as it might improve training\n",
" # out = out[:, -1,:]\n",
" # print(out.shape)\n",
" out = self.linear(out)\n",
" \n",
" # torch.remainder(out[1], 360)\n",
" # print('o',out.shape)\n",
" return out, hidden_state\n",
"\n",
"lstm = LSTMModel(input_size, hidden_size, num_layers).to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 329,
"metadata": {},
"outputs": [],
"source": [
"# model = rnn\n",
"model = lstm\n"
]
},
{
"cell_type": "code",
"execution_count": 330,
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"loss_fn = nn.MSELoss()"
]
},
{
"cell_type": "code",
"execution_count": 331,
"metadata": {},
"outputs": [],
"source": [
"def evaluate():\n",
" # toggle evaluation mode\n",
" model.eval()\n",
" with torch.no_grad():\n",
" batch_size, seq_len, feature_dim = X_train.shape\n",
" y_pred, _ = model(\n",
" X_train.to(device=device),\n",
" model.get_hidden_state(batch_size, device)\n",
" )\n",
" train_rmse = torch.sqrt(loss_fn(y_pred, y_train))\n",
" # print(y_pred)\n",
"\n",
" batch_size, seq_len, feature_dim = X_test.shape\n",
" y_pred, _ = model(\n",
" X_test.to(device=device),\n",
" model.get_hidden_state(batch_size, device)\n",
" )\n",
" # print(loss_fn(y_pred, y_test))\n",
" test_rmse = torch.sqrt(loss_fn(y_pred, y_test))\n",
" print(\"Epoch ??: train RMSE %.4f, test RMSE %.4f\" % ( train_rmse, test_rmse))\n",
"\n",
"def load_most_recent():\n",
" paths = list(cache_path.glob(f\"checkpoint-{model._get_name()}_*.pt\"))\n",
" if len(paths) < 1:\n",
" print('Nothing found to load')\n",
" return None, None\n",
" paths.sort()\n",
"\n",
" print(f\"Loading {paths[-1]}\")\n",
" return load_cache(path=paths[-1])\n",
"\n",
"def load_cache(epoch=None, path=None):\n",
" if path is None:\n",
" if epoch is None:\n",
" raise RuntimeError(\"Either path or epoch must be given\")\n",
" path = cache_path / f\"checkpoint-{model._get_name()}_{epoch:05d}.pt\"\n",
" else:\n",
" print (path.stem)\n",
" epoch = int(path.stem[-5:])\n",
"\n",
" cached = torch.load(path)\n",
" \n",
" optimizer.load_state_dict(cached['optimizer_state_dict'])\n",
" model.load_state_dict(cached['model_state_dict'])\n",
" return epoch, cached['loss']\n",
" \n",
"\n",
"def cache(epoch, loss):\n",
" path = cache_path / f\"checkpoint-{model._get_name()}_{epoch:05d}.pt\"\n",
" print(f\"Cache to {path}\")\n",
" torch.save({\n",
" 'epoch': epoch,\n",
" 'model_state_dict': model.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" 'loss': loss,\n",
" }, path)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TODO)) See [this notebook](https://www.cs.toronto.edu/~lczhang/aps360_20191/lec/w08/rnn.html) For initialization (with random or not) and the use of GRU"
]
},
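{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged GRU sketch for the TODO above (added; untrained and not used in the cells below). It is a drop-in alternative to `SimpleRnn`/`LSTMModel`; the main difference is that a GRU keeps a single hidden tensor instead of the LSTM's (h, c) tuple:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class GRUModel(nn.Module):\n",
"    def __init__(self, input_size, hidden_size, num_layers, output_size):\n",
"        super(GRUModel, self).__init__()\n",
"        self.num_layers = num_layers\n",
"        self.hidden_size = hidden_size\n",
"        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)\n",
"        self.linear = nn.Linear(hidden_size, output_size)\n",
"\n",
"    def get_hidden_state(self, batch_size, device):\n",
"        # a single tensor, unlike the LSTM's (h, c) tuple\n",
"        return torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
"\n",
"    def forward(self, x, hidden_state):\n",
"        out, hidden_state = self.gru(x, hidden_state)\n",
"        return self.linear(out), hidden_state\n",
"\n",
"# gru = GRUModel(input_size, hidden_size, num_layers, output_size).to(device)"
]
},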
{
"cell_type": "code",
"execution_count": 332,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading EXPERIMENTS/cache/hof2/checkpoint-LSTMModel_01000.pt\n",
"checkpoint-LSTMModel_01000\n",
"starting from epoch 1000 (loss: 0.014368701726198196)\n",
"Epoch ??: train RMSE 0.0849, test RMSE 0.0866\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch ??: train RMSE 0.0849, test RMSE 0.0866\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"start_epoch, loss = load_most_recent()\n",
"if start_epoch is None:\n",
" start_epoch = 0\n",
"else:\n",
" print(f\"starting from epoch {start_epoch} (loss: {loss})\")\n",
" evaluate()\n",
"\n",
"loss_log = []\n",
"# Train Network\n",
"for epoch in tqdm(range(start_epoch+1,num_epochs+1)):\n",
" # toggle train mode\n",
" model.train()\n",
" for batch_idx, (x, targets) in enumerate(loader_train):\n",
" # Get x to cuda if possible\n",
" x = x.to(device=device).squeeze(1)\n",
" targets = targets.to(device=device)\n",
"\n",
" # forward\n",
" scores, _ = model(\n",
" x,\n",
" torch.zeros(num_layers, x.shape[2], hidden_size, dtype=torch.float).to(device=device),\n",
" torch.zeros(num_layers, x.shape[2], hidden_size, dtype=torch.float).to(device=device)\n",
" )\n",
" # print(scores)\n",
" loss = loss_fn(scores, targets)\n",
"\n",
" # backward\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" # gradient descent update step/adam step\n",
" optimizer.step()\n",
"\n",
" loss_log.append(loss.item())\n",
"\n",
" if epoch % 5 != 0:\n",
" continue\n",
"\n",
" cache(epoch, loss)\n",
" evaluate()\n",
"\n",
"evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 333,
"metadata": {},
"outputs": [],
"source": [
"# print(loss)\n",
"# print(len(loss_log))\n",
"# plt.plot(loss_log)\n",
"# plt.ylabel('Loss')\n",
"# plt.xlabel('iteration')\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 335,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([49999, 9, 2]) torch.Size([49999, 9, 2])\n"
]
}
],
"source": [
"model.eval()\n",
"\n",
"with torch.no_grad():\n",
" y_pred, _ = model(X_train.to(device=device),\n",
" model.get_hidden_state(X_train.shape[0], device))\n",
" \n",
" print(y_pred.shape, y_train.shape)\n",
"# y_train, y_pred"
]
},
{
"cell_type": "code",
"execution_count": 336,
"metadata": {},
"outputs": [],
"source": [
"import scipy\n",
"\n",
"def ceil_away_from_0(a):\n",
" return np.sign(a) * np.ceil(np.abs(a))\n"
]
},
{
"cell_type": "code",
"execution_count": 343,
"metadata": {},
"outputs": [],
"source": [
"def predict_and_plot(model, feature, steps = 50):\n",
" lenght = feature.shape[0]\n",
"\n",
" dt = (1/ FPS) * SAMPLE_STEP\n",
"\n",
" trajectory = feature\n",
"\n",
" # feature = filtered_data.loc[_track_id,:].iloc[:5][in_fields].values\n",
" # nxt = filtered_data.loc[_track_id,:].iloc[5][out_fields]\n",
" with torch.no_grad():\n",
" # h = torch.zeros(num_layers, window+1, hidden_size, dtype=torch.float).to(device=device)\n",
" # c = torch.zeros(num_layers, window+1, hidden_size, dtype=torch.float).to(device=device)\n",
" h = torch.zeros(num_layers, 1, hidden_size, dtype=torch.float).to(device=device)\n",
" c = torch.zeros(num_layers, 1, hidden_size, dtype=torch.float).to(device=device)\n",
" hidden_state = (h, c)\n",
" # X = torch.tensor([feature], dtype=torch.float).to(device)\n",
" # y, (h, c) = model(X, h, c)\n",
" for i in range(steps):\n",
" # predict_f = scipy.ndimage.uniform_filter(feature)\n",
" # predict_f = scipy.interpolate.splrep(feature[:][0], feature[:][1],)\n",
" # predict_f = scipy.signal.spline_feature(feature, lmbda=.1)\n",
" # bathc size of one, so feature as single item in array\n",
" # print(X.shape)\n",
" X = torch.tensor([feature], dtype=torch.float).to(device)\n",
" # print(type(model))\n",
" y, hidden_state, *_ = model(X, hidden_state)\n",
" # print(hidden_state.shape)\n",
"\n",
" s = y[-1][-1].cpu()\n",
"\n",
" # proj_x proj_y v heading a d_heading\n",
" # next_step = feature\n",
"\n",
" dx, dy = s\n",
" \n",
" dx = (dx * GRID_SIZE).round() / GRID_SIZE\n",
" dy = (dy * GRID_SIZE).round() / GRID_SIZE\n",
" vx, vy = dx / dt, dy / dt\n",
"\n",
" v = np.sqrt(s[0]**2 + s[1]**2)\n",
" heading = (np.arctan2(s[1], s[0]) * 180 / np.pi) % 360\n",
" # a = (v - feature[-1][2]) / dt\n",
" ax = (vx - feature[-1][2]) / dt\n",
" ay = (vx - feature[-1][3]) / dt\n",
" # d_heading = (heading - feature[-1][5])\n",
" # print(s)\n",
" # ['x', 'y', 'vx', 'vy', 'ax', 'ay'] \n",
" x = feature[-1][0] + dx\n",
" y = feature[-1][1] + dy\n",
" if GRID_SIZE is not None:\n",
" # put points back on grid\n",
" x = (x*GRID_SIZE).round() / GRID_SIZE\n",
" y = (y*GRID_SIZE).round() / GRID_SIZE\n",
"\n",
" feature = [[x, y, vx, vy, ax, ay]]\n",
" \n",
" trajectory = np.append(trajectory, feature, axis=0)\n",
" # f = [feature[-1][0] + s[0]*dt, feature[-1][1] + s[1]*dt, v, heading, a, d_heading ]\n",
" # feature = np.append(feature, [feature], axis=0)\n",
" \n",
" # print(next_step, nxt)\n",
" # print(trajectory)\n",
" plt.plot(trajectory[:lenght,0], trajectory[:lenght,1], c='orange')\n",
" plt.plot(trajectory[lenght-1:,0], trajectory[lenght-1:,1], c='red')\n",
" plt.scatter(trajectory[lenght:,0], trajectory[lenght:,1], c='red', marker='x')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1301\n",
"(10, 6) (10, 6)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGfCAYAAAD/BbCUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC9jklEQVR4nOz9d5gkZ3nuj9+VqzrHiTubFFYrgSWEBCIZYSSwwAKMZQw2Ahvb+BgbEHLAYGOTDD98DhySTTo/E4zBOCCMjYQlgkSSACGUV9qgnd3J07mrQ+X6/tH7vFs9aWd2Z8PMvJ/rmmt3Zqqrq7tr6r3rCfcjhGEYgsPhcDgcDucMIZ7tA+BwOBwOh7O14OKDw+FwOBzOGYWLDw6Hw+FwOGcULj44HA6Hw+GcUbj44HA4HA6Hc0bh4oPD4XA4HM4ZhYsPDofD4XA4ZxQuPjgcDofD4ZxRuPjgcDgcDodzRuHig8PhcDgczhlFXusDvve97+F//+//jZ/97GeYmZnBLbfcgpe97GVLbvu//tf/wqc+9Sn83//7f3HTTTetav9BEGB6ehrJZBKCIKz18DgcDofD4ZwFwjCEaZoYGRmBKK4c21iz+Gi327j00kvxute9Di9/+cuX3e6WW27BPffcg5GRkTXtf3p6GmNjY2s9LA6Hw+FwOOcAExMT2LZt24rbrFl8XHfddbjuuutW3GZqagpvfOMb8T//8z948YtfvKb9J5NJAL2DT6VSaz08DofD4XA4Z4Fms4mxsTG2jq/EmsXHiQiCADfeeCP+7M/+DJdccskJt7dtG7Zts+9N0wQApFIpLj44HA6Hw9lgrKZkYt0LTj/wgQ9AlmW86U1vWtX273//+5FOp9kXT7lwOBwOh7O5WVfx8bOf/Qwf+chH8LnPfW7VxaJve9vb0Gg02NfExMR6HhKHw+FwOJxzjHUVH9///vcxPz+P7du3Q5ZlyLKMI0eO4E/+5E+wc+fOJR+jaRpLsfBUC4fD4XA4m591rfm48cYbcc011/T97IUvfCFuvPFG/M7v/M56PhWHw+FwOJwNyprFR6vVwsGDB9n3hw8fxv33349cLoft27cjn8/3ba8oCoaGhrBnz55TP1oOh8PhcDgbnjWLj3vvvRfPe97z2Pc333wzAOC1r30tPve5z63bgXE4HA6Hw9mcrFl8XH311QjDcNXbj4+Pr/UpOBwOh8PhbGL4bBcOh8PhcDhnFC4+OBwOh8PhnFG4+OBwOBwOh3NG4eKDw+FwOBzOGYWLDw6Hw+FwOGeUdR8sx+FwOJsV3/cxNzcHTdOgaRoAsO6/hV2AhmFAUZQzfowczkaAiw8Oh8NZJVNTU2g2mxAEAdlsdsVtO50O8vk8FyAczhJw8cHhcDirJJvNolarQRAEGIYBQRD6hmjS/x3HgeM4qFarKBQKkCTpbB0yh3NOwsUHh8PhrJJ4PI5isQgAyGQyy07vDoIAlUoFruuiUqmgUChAFHmJHYdD8L8GDofDWSXRuo6VnJ5FUUQul4MkSfA8D51O50wcHoezYeDig8PhcFZJq9WC7/tQVfWEkQxJkmAYBoBeoSqHwzkOT7twOBzOKnAcB5OTk/B9n6VeTgQJlCAITuehcTgbDh754HA4nFXQ6XQQhiEURVm1mKBCUy4+OJx+uPjgcDicE0B1G7quIxaLodvtrupxFPngaRcOpx8uPjgcDucEmKYJAEin01AUBbZtr0pQ8LQLh7M0XHxwOBzOMoRhiFarxSIdmUwGqqoCACzLOuHjo2mXlbpjOJytBhcfHA6Hs4AwDNFutzE/P49mswkAiMViUBSFdbCspn022hFj2/bpOVgOZwPCu104HA7nGGEYotPpsJZaoBe9SCaTTHQYhoFmswnXdeF5HmR55cuooihwXRfVahWqqiKZTLK5MEuxmn1yOBsdfoZzzlnCMEQYhsu6SK4HFA6PhsXJMpueN/r96TwWztkhCAIIgrCi6Ih+7qIoQtM0WJaFbreLZDK54v5zuRxarRY6nQ4cx0GlUoGqqkgkEtB1vW9bz/MwMzODeDyObDbLzzfOpoWLD845SRAEmJ2dPduH0Ycsy9wme5Nh2zYqlUrfzyRJQiKRQCwWW3bxNwxj1eJDkiSk02kkEgm02220220290WWZQiCwARwpVKB53kwTfOEg+sI3/fhuu4iIcPhnMtw8cHZ8oiiCFEU2UJDERf6op95nod2u33CxYazcVhYBJpOp1cUHYSu6xAEAZ7nwXEcVoS6EpIkIZVKIR6PMxHieR47jmazCc/zmFhZbdTDNE10Oh3E43Gk0+lVPYbDOdtw8cE5JxFFEcPDw6ct7UL7Xe2+LctCtVpFq9VCLBbjU0o3CbquI5fLoVqtAui5mMbj8RM+jqbadjoddLvdVYkPgkRIIpGA4zgAgLm5OTiOA0EQMDQ0hEKhsKp9eZ7HOnGoJoXD2Qjw+DHnnEUQBBaRWO+vaKRjNei6DlVVWeslZ/Og6zry+TwEQUC320Wj0VjV42ix73a7CIJgzV4eoihC13W0223U63UAwPDw8KqFB9CLeoRhyM5PDmejwMUHh7NKUqkUgF6LJYXLOZsDTdOQyWQAAO12m5mKrYSqqpBlGUEQoFqtYm5uDvV6fU1upo1GA1NTUwCAYrG46pkxAOC6Lot68FQgZ6PBxQeHs0pUVYWu6yw/z9lcGIbBaiZM00S73V5xe0EQmGBpNBqwbRudTgfz8/NotVonNBXrdruYmJhAGIZIp9MYHh5mxaOrETAkkAzDgKIoq3iFHM65AxcfHM4aoOiHZVksX8/ZPMTjcRZFaDQaJ5zhQi2zhmFAlmXIsszE6fz8/LKP9zwPExMT8H0f8Xgc27dvhyAIME0TpVLphM/rOA4sy4IgCDzqwdmQcPHB4awBWZYRi8UAgEc/NinJZJIVndbr9RM6kyaTSciyDFEUoaoqstksJEmC7/uo1Wool8t9abowDJkwkWUZY2Nja27fjkY9uCEZZyPCxQeHs0aSySRzquRsTlKpFAzDQBiGqFarK0a5oumXTqcDQRAwMDCAZDIJQRDgOE5fkbIgCJAkCYqiYGhoaEW306WwbRu2bfOoB2dDw8UHZ2OjqsByXSuC0Pv9OiNJEgqFwpoXDc7GgQSFpmlMgKxUZEzpF6CXrgnDEMlkkomDaP0H1XRkMpmT8uWIzprhLd+cjQoXH5yNi6oCrtv7/0IBQt+77mkRIJzNjyAIyGazUBSFdbSs1E5L6Rff95lAWKqdOyoe1poysSwLruvyqAdnw8PFB2fjQsKDoAv9wgv+wu04nFUiiiJyuRwkSYLneahWq8t2sSxMv1iWxX5Hj3Ech6VMKFKyWqJdVolEgtv8czY0/OzlbFyWWgSWSsGcoOWRw1kJSZKQz+chiiIcx0GtVltWgCyVfoluS4WiiqKg1WqtyZjMsix4ngdRFFflwrqZoVk40YGQnI0FL5PmbGzCcPmaD/o9h3OKyLKMXC6HSqUCy7JgWdaydubJZJIJhSeeeAJhGEJVVcTjcRb18H0fjuOwybmrgYpeDcPY0lEPy
7JQr9f7hNtSLsbT09MAeg62a0lRZbNZblV/Bti6ZzBn87CcwODCg7OOqKrKFqXlTMDCMIRt24jH4wiCoK9LhqIesViM+cW02+0137lvZeHRarWWrL2hSAiZtEUnYkfTX6thtdv7vr9mS33OcXjkg7PxWanbhQsQzjpCBaTLLTqNRgOdTgeGYbC0SxAEcF2XDY6jeg1ZluF5HptIeyK2enrBNE0m4OLxOBNw0QnUvu+jXC4jmUyyGT07duxAp9NZ1XNE63ZWIgxDzM3NAQBGRkZO4tVwuPjgbGxONByOCxDOOkJRh+WEgKZpaLfbbGGiCbbUEmsYBvt/PB5Ho9Fgk5JXO+jwdEx53giQ3T1NBCbo/fA8D41GA4IgQFVVFItFDA8PLysobNtGrVZDEAQQRRGZTAa6rq/qWKKRr9M1eXuzs3Xjd5yNz1J/8Memg55wOw7nJKBFZqWC02azCdu2IYoiuzun9EvUGyYWi0EURfi+v6pQ/1aPfNB7v5S/jm3bi5xkBwYGlhUFrVYLlUoFQRB
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
"_track_id = 8701 # random.choice(track_ids)\n",
"_track_id = 3880 # random.choice(track_ids)\n",
"\n",
"# _track_id = 2780\n",
"\n",
"for batch_idx in range(100):\n",
" _track_id = random.choice(track_ids)\n",
" plt.plot(\n",
" data.loc[_track_id,:]['x'],\n",
" data.loc[_track_id,:]['y'],\n",
" c='grey', alpha=.2\n",
" )\n",
"\n",
"_track_id = random.choice(track_ids)\n",
"# _track_id = 1096\n",
"_track_id = 1301\n",
"print(_track_id)\n",
"ax = plt.scatter(\n",
" data.loc[_track_id,:]['x'],\n",
" data.loc[_track_id,:]['y'],\n",
" marker=\"*\") \n",
"plt.plot(\n",
" data.loc[_track_id,:]['x'],\n",
" data.loc[_track_id,:]['y']\n",
")\n",
"\n",
"X = data.loc[_track_id,:].iloc[:][in_fields].values\n",
"# Adding randomness might be a cheat to get multiple features from current position\n",
"rnd = np.random.random_sample(X.shape) / 10\n",
"# print(rnd)\n",
"\n",
"print(X[:10].shape, (X[:10] + rnd[:10]).shape)\n",
"\n",
"# predict_and_plot(data.loc[_track_id,:].iloc[:5][in_fields].values)\n",
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:5][in_fields].values, 50)\n",
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:10][in_fields].values, 50)\n",
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:20][in_fields].values)\n",
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:30][in_fields].values)\n",
"predict_and_plot(model, X[:12])\n",
"predict_and_plot(model, X[:12] + rnd[:12])\n",
"predict_and_plot(model, data.loc[_track_id,:].iloc[:][in_fields].values)\n",
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n",
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## VAE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# From https://github.com/CUN-bjy/lstm-vae-torch/blob/main/src/models.py (MIT)\n",
"\n",
"from typing import Optional\n",
"\n",
"\n",
"class Encoder(nn.Module):\n",
" def __init__(self, input_size=4096, hidden_size=1024, num_layers=2):\n",
" super(Encoder, self).__init__()\n",
" self.hidden_size = hidden_size\n",
" self.num_layers = num_layers\n",
" self.lstm = nn.LSTM(\n",
" input_size,\n",
" hidden_size,\n",
" num_layers,\n",
" batch_first=True,\n",
" bidirectional=False,\n",
" )\n",
"\n",
" def get_hidden_state(self, batch_size, device):\n",
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
" return (h, c)\n",
"\n",
" def forward(self, x, hidden_state):\n",
" # x: tensor of shape (batch_size, seq_length, hidden_size)\n",
" outputs, (hidden, cell) = self.lstm(x, hidden_state)\n",
" return outputs, (hidden, cell)\n",
"\n",
"\n",
"class Decoder(nn.Module):\n",
" def __init__(\n",
" self, input_size=4096, hidden_size=1024, output_size=4096, num_layers=2\n",
" ):\n",
" super(Decoder, self).__init__()\n",
" self.hidden_size = hidden_size\n",
" self.output_size = output_size\n",
" self.num_layers = num_layers\n",
" self.lstm = nn.LSTM(\n",
" input_size,\n",
" hidden_size,\n",
" num_layers,\n",
" batch_first=True,\n",
" bidirectional=False,\n",
" )\n",
" self.fc = nn.Linear(hidden_size, output_size)\n",
" \n",
" def get_hidden_state(self, batch_size, device):\n",
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
" return (h, c)\n",
"\n",
" def forward(self, x, hidden):\n",
" # x: tensor of shape (batch_size, seq_length, hidden_size)\n",
" output, (hidden, cell) = self.lstm(x, hidden)\n",
" prediction = self.fc(output)\n",
" return prediction, (hidden, cell)\n",
"\n",
"\n",
"class LSTMVAE(nn.Module):\n",
" \"\"\"LSTM-based Variational Auto Encoder\"\"\"\n",
"\n",
" def __init__(\n",
" self, input_size, output_size, hidden_size, latent_size, device=torch.device(\"cuda\")\n",
" ):\n",
" \"\"\"\n",
" input_size: int, batch_size x sequence_length x input_dim\n",
" hidden_size: int, output size of LSTM AE\n",
" latent_size: int, latent z-layer size\n",
" num_lstm_layer: int, number of layers in LSTM\n",
" \"\"\"\n",
" super(LSTMVAE, self).__init__()\n",
" self.device = device\n",
"\n",
" # dimensions\n",
" self.input_size = input_size\n",
" self.hidden_size = hidden_size\n",
" self.latent_size = latent_size\n",
" self.num_layers = 1\n",
"\n",
" # lstm ae\n",
" self.lstm_enc = Encoder(\n",
" input_size=input_size, hidden_size=hidden_size, num_layers=self.num_layers\n",
" )\n",
" self.lstm_dec = Decoder(\n",
" input_size=latent_size,\n",
" output_size=output_size,\n",
" hidden_size=hidden_size,\n",
" num_layers=self.num_layers,\n",
" )\n",
"\n",
" self.fc21 = nn.Linear(self.hidden_size, self.latent_size)\n",
" self.fc22 = nn.Linear(self.hidden_size, self.latent_size)\n",
" self.fc3 = nn.Linear(self.latent_size, self.hidden_size)\n",
"\n",
" def reparametize(self, mu, logvar):\n",
" std = torch.exp(0.5 * logvar)\n",
" noise = torch.randn_like(std).to(self.device)\n",
"\n",
" z = mu + noise * std\n",
" return z\n",
"\n",
" def forward(self, x, hidden_state_encoder: Optional[tuple]=None):\n",
" batch_size, seq_len, feature_dim = x.shape\n",
"\n",
" if hidden_state_encoder is None:\n",
" hidden_state_encoder = self.lstm_enc.get_hidden_state(batch_size, self.device)\n",
"\n",
" # encode input space to hidden space\n",
" prediction, hidden_state = self.lstm_enc(x, hidden_state_encoder)\n",
" enc_h = hidden_state[0].view(batch_size, self.hidden_size).to(self.device)\n",
"\n",
" # extract latent variable z(hidden space to latent space)\n",
" mean = self.fc21(enc_h)\n",
" logvar = self.fc22(enc_h)\n",
" z = self.reparametize(mean, logvar) # batch_size x latent_size\n",
"\n",
" # initialize hidden state as inputs\n",
" h_ = self.fc3(z)\n",
" \n",
" # decode latent space to input space\n",
" z = z.repeat(1, seq_len, 1)\n",
" z = z.view(batch_size, seq_len, self.latent_size).to(self.device)\n",
"\n",
" # initialize hidden state\n",
" hidden = (h_.contiguous(), h_.contiguous())\n",
" # TODO)) the above is not the right dimensions, but this changes architecture\n",
" hidden = self.lstm_dec.get_hidden_state(batch_size, self.device)\n",
" reconstruct_output, hidden = self.lstm_dec(z, hidden)\n",
"\n",
" x_hat = reconstruct_output\n",
" \n",
" return x_hat, hidden_state, mean, logvar\n",
"\n",
" def loss_function(self, *args, **kwargs) -> dict:\n",
" \"\"\"\n",
" Computes the VAE loss function.\n",
" KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n",
" :param args:\n",
" :param kwargs:\n",
" :return:\n",
" \"\"\"\n",
" predicted = args[0]\n",
" target = args[1]\n",
" mu = args[2]\n",
" log_var = args[3]\n",
"\n",
" kld_weight = 0.00025 # Account for the minibatch samples from the dataset\n",
" recons_loss = torch.nn.functional.mse_loss(predicted, target=target)\n",
"\n",
" kld_loss = torch.mean(\n",
" -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0\n",
" )\n",
"\n",
" loss = recons_loss + kld_weight * kld_loss\n",
" return {\n",
" \"loss\": loss,\n",
" \"Reconstruction_Loss\": recons_loss.detach(),\n",
" \"KLD\": -kld_loss.detach(),\n",
" }\n"
]
},
{
"cell_type": "code",
"execution_count": 303,
"metadata": {},
"outputs": [],
"source": [
"vae = LSTMVAE(input_size, output_size, hidden_size, 1024, device=device)\n",
"vae.to(device)\n",
"vae_optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": 304,
"metadata": {},
"outputs": [],
"source": [
"def train_vae(model, train_loader, test_loader, max_iter, learning_rate):\n",
" # optimizer\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
" ## interation setup\n",
" epochs = tqdm(range(max_iter // len(train_loader) + 1))\n",
" epochs = tqdm(range(max_iter))\n",
"\n",
" ## training\n",
" count = 0\n",
" for epoch in epochs:\n",
" model.train()\n",
" optimizer.zero_grad()\n",
" train_iterator = tqdm(\n",
" enumerate(train_loader), total=len(train_loader), desc=\"training\"\n",
" )\n",
"\n",
" # if count > max_iter:\n",
" # print(\"max!\")\n",
" # return model\n",
"\n",
" for batch_idx, (x, targets) in train_iterator:\n",
" count += 1\n",
"\n",
" # future_data, past_data = batch_data\n",
"\n",
" ## reshape\n",
" batch_size = x.size(0)\n",
" example_size = x.size(1)\n",
" # image_size = past_data.size(2), past_data.size(3)\n",
" # past_data = (\n",
" # past_data.view(batch_size, example_size, -1).float().to(args.device) # flattens image, we don't need this\n",
" # )\n",
" # future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\n",
"\n",
" y, hidden_state, mean, logvar = model(x)\n",
"\n",
" # calculate vae loss\n",
" # print(y.shape, targets.shape)\n",
" losses = model.loss_function(y, targets, mean, logvar)\n",
" mloss, recon_loss, kld_loss = (\n",
" losses[\"loss\"],\n",
" losses[\"Reconstruction_Loss\"],\n",
" losses[\"KLD\"],\n",
" )\n",
"\n",
" # Backward and optimize\n",
" optimizer.zero_grad()\n",
" mloss.mean().backward()\n",
" optimizer.step()\n",
"\n",
" train_iterator.set_postfix({\"train_loss\": float(mloss.mean())})\n",
" print(\"train_loss\", float(mloss.mean()), epoch)\n",
"\n",
" model.eval()\n",
" eval_loss = 0\n",
" test_iterator = tqdm(\n",
" enumerate(test_loader), total=len(test_loader), desc=\"testing\"\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" for batch_idx, (x, targets) in test_iterator:\n",
" # future_data, past_data = batch_data\n",
"\n",
" ## reshape\n",
" batch_size = x.size(0)\n",
" example_size = x.size(1)\n",
" # past_data = (\n",
" # past_data.view(batch_size, example_size, -1).float().to(args.device)\n",
" # )\n",
" # future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\n",
"\n",
" y, hidden_state, mean, logvar = model(x)\n",
"\n",
" # calculate vae loss\n",
" losses = model.loss_function(y, targets, mean, logvar)\n",
" mloss, recon_loss, kld_loss = (\n",
" losses[\"loss\"],\n",
" losses[\"Reconstruction_Loss\"],\n",
" losses[\"KLD\"],\n",
" )\n",
"\n",
" eval_loss += mloss.mean().item()\n",
"\n",
" test_iterator.set_postfix({\"eval_loss\": float(mloss.mean())})\n",
"\n",
" # if batch_idx == 0:\n",
" # nhw_orig = past_data[0].view(example_size, image_size[0], -1)\n",
" # nhw_recon = recon_x[0].view(example_size, image_size[0], -1)\n",
" # imshow(nhw_orig.cpu(), f\"orig{epoch}\")\n",
" # imshow(nhw_recon.cpu(), f\"recon{epoch}\")\n",
" # writer.add_images(f\"original{i}\", nchw_orig, epoch)\n",
" # writer.add_images(f\"reconstructed{i}\", nchw_recon, epoch)\n",
"\n",
" eval_loss = eval_loss / len(test_loader)\n",
" # writer.add_scalar(\"eval_loss\", float(eval_loss), epoch)\n",
" print(\"Evaluation Score : [{}]\".format(eval_loss))\n",
"\n",
" print(\"Done :-)\")\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 305,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/11 [00:00<?, ?it/s]\n",
"training: 0%| | 0/98 [00:00<?, ?it/s]\n",
" 0%| | 0/1000 [00:00<?, ?it/s]\n"
]
},
{
"ename": "RuntimeError",
"evalue": "For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[305], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrain_vae\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloader_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloader_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[304], line 36\u001b[0m, in \u001b[0;36mtrain_vae\u001b[0;34m(model, train_loader, test_loader, max_iter, learning_rate)\u001b[0m\n\u001b[1;32m 29\u001b[0m example_size \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# image_size = past_data.size(2), past_data.size(3)\u001b[39;00m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;66;03m# past_data = (\u001b[39;00m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# past_data.view(batch_size, example_size, -1).float().to(args.device) # flattens image, we don't need this\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m# future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\u001b[39;00m\n\u001b[0;32m---> 36\u001b[0m y, hidden_state, mean, logvar \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# calculate vae loss\u001b[39;00m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# print(y.shape, targets.shape)\u001b[39;00m\n\u001b[1;32m 40\u001b[0m losses \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mloss_function(y, targets, mean, logvar)\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[302], line 128\u001b[0m, in \u001b[0;36mLSTMVAE.forward\u001b[0;34m(self, x, hidden_state_encoder)\u001b[0m\n\u001b[1;32m 125\u001b[0m hidden \u001b[38;5;241m=\u001b[39m (h_\u001b[38;5;241m.\u001b[39mcontiguous(), h_\u001b[38;5;241m.\u001b[39mcontiguous())\n\u001b[1;32m 126\u001b[0m \u001b[38;5;66;03m# TODO)) the above is not the right dimensions, but this changes architecture\u001b[39;00m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;66;03m# hidden = self.lstm_dec.get_hidden_state(batch_size, self.device)\u001b[39;00m\n\u001b[0;32m--> 128\u001b[0m reconstruct_output, hidden \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm_dec\u001b[49m\u001b[43m(\u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 130\u001b[0m x_hat \u001b[38;5;241m=\u001b[39m reconstruct_output\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x_hat, hidden_state, mean, logvar\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[302], line 54\u001b[0m, in \u001b[0;36mDecoder.forward\u001b[0;34m(self, x, hidden)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, hidden):\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# x: tensor of shape (batch_size, seq_length, hidden_size)\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m output, (hidden, cell) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 55\u001b[0m prediction \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc(output)\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m prediction, (hidden, cell)\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/rnn.py:755\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 752\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (hx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m hx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m3\u001b[39m):\n\u001b[1;32m 753\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFor batched 3-D input, hx and cx should \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 754\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124malso be 3-D but got (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-D, \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-D) tensors\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 755\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg)\n\u001b[1;32m 756\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 757\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m hx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n",
"\u001b[0;31mRuntimeError\u001b[0m: For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors"
]
}
],
"source": [
"train_vae(vae, loader_train, loader_test, num_epochs, learning_rate*10)"
]
},
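{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `RuntimeError` above comes from the decoder's initial state: `LSTMVAE.forward` builds `hidden` from a 2-D `h_` of shape `(batch_size, hidden_size)`, while `nn.LSTM` expects `(num_layers, batch_size, hidden_size)` for batched 3-D input. A minimal sketch of a fix, assuming `num_layers == 1`:\n",
"\n",
"```python\n",
"h_ = h_.unsqueeze(0)  # -> (1, batch_size, hidden_size)\n",
"hidden = (h_.contiguous(), h_.contiguous())\n",
"```"
]
},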
{
"cell_type": "code",
"execution_count": 300,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1301\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAADF6ElEQVR4nOz9d5gkZ3nuj9+VujrH6Z6wO7NJq9VKKGcESBwLkODIiGQTDwfbB4N9CAcwBtuYY4x/AgeMbTjY4K9NsADbJGMDQiJIgHIO7Eqbdyf2TOeurq6u+Puj9nmnesLuzG7vzM7M+7muvmamu6a7OtV71xPuR/A8zwOHw+FwOBzOCiGu9g5wOBwOh8PZWHDxweFwOBwOZ0Xh4oPD4XA4HM6KwsUHh8PhcDicFYWLDw6Hw+FwOCsKFx8cDofD4XBWFC4+OBwOh8PhrChcfHA4HA6Hw1lR5NXegbm4rouJiQkkEgkIgrDau8PhcDgcDmcJeJ6HZrOJoaEhiOKJYxtnnfiYmJjA8PDwau8Gh8PhcDicU2B0dBSbN28+4TZnnfhIJBIA/J1PJpOrvDccDofD4XCWQqPRwPDwMFvHT8RZJz4o1ZJMJrn44HA4HA5njbGUkglecMrhcDgcDmdF4eKDw+FwOBzOisLFB4fD4XA4nBWFiw8Oh8PhcDgrChcfHA6Hw+FwVhQuPjgcDofD4awoXHxwOBwOh8NZUbj44HA4HA6Hs6Jw8cHhcDgcDmdFWbb4+NnPfoZbbrkFQ0NDEAQB3/nOdxbd9h3veAcEQcCnP/3p09hFDofD4XA464lli49Wq4WLL74Yn/3sZ0+43be//W088MADGBoaOuWd43A4HA6Hs/5Y9myXm2++GTfffPMJtxkfH8e73vUu/PCHP8QrXvGKU945DofD4XA464+eD5ZzXRdvectb8Hu/93u44IILTrp9p9NBp9NhfzcajV7vEofD4XBOgmVZ0HUdsiwjFout9u5w1jk9Lzj95Cc/CVmW8e53v3tJ2992221IpVLsMjw83Otd4nA4HM5JsG0brVYLhmGs9q5wNgA9FR+PPvoo/uZv/gZf/OIXlzRSFwA+/OEPo16vs8vo6Ggvd4nD4XA4HM5ZRk/Fx89//nNMT09jZGQEsixDlmUcPXoU73//+7F169YF/0dVVSSTya4Lh8PhcDic9UtPaz7e8pa34MYbb+y67mUvexne8pa34G1ve1svH4rD4XA4HM4aZdniQ9M0HDhwgP19+PBhPPHEE8hmsxgZGUEul+vaXlEUDAwMYNeuXae/txwOh8PhcNY8yxYfjzzyCF784hezv9/3vvcBAN761rfii1/8Ys92jMPhcDgczvpk2eLjhhtugOd5S97+yJEjy30IDofD4XA46xg+24XD4XA4HM6KwsUHh8PhcDicFYWLDw6Hw+FwOCsKFx8cDodzlmOaJur1+mrvBofTM3o+24XD4XA4vcN1XVQqFbiuC0EQzrgR43IaCji9Z2ZmBq7rIpfLQZbX7xLNIx8cDodzFiOKIhMcmqah2WyekcehAZ+WZZ2R++csDcdx4DjOuheBXHxwOBzOWU40GkUqlQIANJtNtFqtnt5/p9OBrusAgEgk0tP75nAWgosPDofDWQPEYjEkEgkAQL1eR7vd7sn9mqaJarUKAAiHw0zkcDhnEi4+OBwOZ42QSCQQi8UAALVaDYZhnNb9WZbVVU/iOA4ajUYvdpXDOSFcfHA4HM4aIpVKIRKJwPM8VKtVmKZ5Svdj2zYTHqFQCMlkEpZlwbbtHu8xhzMfLj44HA5njZFOpxEOh+F5HiqVyrKLRB3HQblchuM4UBQF2WwWosiXA87KwT9tHA6Hs8LYto3x8XEcPnz4lP5fEARkMhmEQiG4rotyubzkiAVt7zgOZFnmwoOzKvBPHIfD4aww5N3RaDRYl8lyEQQB2WwWiqJ0CYqTUa1WYds2JElCLpeDJEmn9PgczunAxQeHw+GsMKFQiHWuTE9Pn/L9iKKIbDYLWZZZKsV13RP+TzKZhKIoXHhwVhUuPjgcDmcVKBQKAHzfjlMtGgXQFcGwbfuk3SqKoiCfz69r90zO2Q8XHxwOh7MKRKNRxONxuK57WtEPwBcgmUwGAKDrOncp5Zz1cPHB4XA4q0Q+nwcwW4dxOoRCIYTDYQA4Ixbs7XYbrVZrSXUlHM7J4HE3zrql3W5jamoKqqqecBhXvV5HMpk847bSruvC87yunLwgCOwS/FuSJHYdp3e4rnvSmgjAH67meR4EQYAoiuxnr0kmkwiHwzAMAzMzMxgcHDzt+zMMA4ZhwDRNhEKhHu2pH1Gh+S9kdMbpPUv5fK4HuPjgrEscx8Hjjz9+0vCzbdvMpvp5z3teTw/Wp4OiKOjr6+MCpMeczlwUSZIQjUYRiUR6Wi+Rz+cxOjqKarWK/v7+0xI5siwjGo1C13U0Gg309fX1ZB89z2N1Kaqq9uQ+OSfGtm0oirLau3HG4OKDsy6RJGme8FjooB5c3IMRiF5D901n0YIgsKmVdJZNf7uuC8uyoOs6P8PsMSd6j+deL4oii1R5ngfHcdBsNtFsNhEKhZgQOd3PTDqdxtTUFLM6P13BkEgk0G63YZomDMNgqZjTwbIseJ4HURR5oeoKsd69V/iniLMuGR0dZb9nMhlccMEFa6atsNVqoV6vo9lsIhKJrPuD0EqSTCZPmIJbDNd12eTXTqcD0zRhmibq9ToikQgikcgpRwREUUQul8PU1BRKpdJpiw9JkhCLxaBpGprNJlRVPW2BRFGPsyUyuJ4RRRGu66777z0XH5x1x+HDh3Hs2DEAfkh7165da0Z4AGBhc8uy0Gw2+ZTRswBRFJnIcBwH7XYbuq7Dtm3ous6Mwii6FYlEkEgklrzo53I5zMzMoNPpoFarIZ1On9b+xuNx9hkyDOO065l4yoXTa7j44Kwr9u/fj4mJCQDA4OAgzj333FXeo+UjCAKSySTK5TJ0XUc0Gl3Xud+1hiRJiMfjiMfjME0Tuq7DMIyu9IymaXAcB+l0elEBQgKZUkG2baNer8MwDIyMjLAz32BRsiAIiEQiiEajJ9xHURQRjUahado88eE4Dur1OlKp1JJEebDeg0c+OL2Ciw/OusBxHDz33HOYmZkBAIyMjGDbtm2rvFenjqqqrAui0Wggl8ut9i5xFiAUCiEUCnXVhlA6pt1uw3VdZDKZBUPo9Xq9q7MhWFPiuu4JBcbIyAjz9VgMVVWhado8A7NyuYxOpwPLspg9+4mwbZulAXi9B6dX8E8SZ83jOA727NmDSqUCANi+fTuGh4dXea9On2QyiU6ng06n07PCQc6ZgdqjAb/jRJIkVKtVdDodlMtlZLPZeVGGQqHAxAcVHIuiiFKpBMMwutpu6XbHcaDrOsbHxxGJRE74mQiFQhAEAY7jwLbtLuFA9UQkjk50PyReFEXh3VecnsHFB2fNo2kaarUaBEHArl270N/fv9q71BNkWWaFg61Wi4uPNYSqqsjlcmzcfalUmhdlWOhz2tfXh0ceeQS2bUMQBGzZsqXrdtd1cfDgQ
ei6jtHRUezYsWPRwkRBEBAKhZiAlWWZRTBSqRQajQZM04Trukgmk4tGNXi9B+dMsL7LaTkbglQqhV27duH8889fN8KDiMfjSCaTyGazq70rnGVCXi3BoW8nm+GiaRpyuRxkWWb1H0FEUcSWLVsgSRJ0Xcfk5OQJ748EA5mDUWdNIpFAMpmEIAioVquoVquo1Wqs3ZvwPI/9L6/34PQSLj4464JCodAzQ6WzCVEUEY/Hebh7jSJJEvr6+hAKhdjYezK1m0ur1UKn00EqlUI+n4fneRgfH5+3XSgUwubNmwEApVIJ9Xp90ccnwWCaJhMWVNDc19eHVCqFSCSCer2OarWKer3eZZ+u6zpc14Usy7zomdNTuPjgcDicMwhFG8LhMDzPQ7VanTd7JTiNNplMsm4XTdNQrVbn3Wc6nWZFyGNjY4tGVBRFYb4Rc033VFVFPp9HJpNBMplEu91GtVrtmopLbrCxWIwLYE5P4eKDw+FwzjCCICCTyTDH2maziUqlwgpOq9UqPM+DJEkIh8MIh8M
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# VAE Predict and plot\n",
"_track_id = 8701 # random.choice(track_ids)\n",
"_track_id = 3880 # random.choice(track_ids)\n",
"\n",
"# _track_id = 2780\n",
"\n",
"for batch_idx in range(100):\n",
" _track_id = random.choice(track_ids)\n",
" plt.plot(\n",
" data.loc[_track_id,:]['x'],\n",
" data.loc[_track_id,:]['y'],\n",
" c='grey', alpha=.2\n",
" )\n",
"\n",
"_track_id = random.choice(track_ids)\n",
"# _track_id = 1096\n",
"_track_id = 1301\n",
"print(_track_id)\n",
"ax = plt.scatter(\n",
" data.loc[_track_id,:]['x'],\n",
" data.loc[_track_id,:]['y'],\n",
" marker=\"*\") \n",
"plt.plot(\n",
" data.loc[_track_id,:]['x'],\n",
" data.loc[_track_id,:]['y']\n",
")\n",
"\n",
"# predict_and_plot(data.loc[_track_id,:].iloc[:5][in_fields].values)\n",
"predict_and_plot(vae, data.loc[_track_id,:].iloc[:5][in_fields].values, 50)\n",
"# predict_and_plot(vae, data.loc[_track_id,:].iloc[:10][in_fields].values, 50)\n",
"# predict_and_plot(vae, data.loc[_track_id,:].iloc[:20][in_fields].values)\n",
"# predict_and_plot(vae, data.loc[_track_id,:].iloc[:30][in_fields].values)\n",
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n",
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'LSTM_VAE' object has no attribute 'embed_size'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[103], line 268\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# Statistics.\u001b[39;00m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;66;03m# if batch_num % 20 ==0:\u001b[39;00m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;66;03m# print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\u001b[39;00m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m test_losses\n\u001b[0;32m--> 268\u001b[0m vae \u001b[38;5;241m=\u001b[39m \u001b[43mLSTM_VAE\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m16\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 270\u001b[0m vae_loss \u001b[38;5;241m=\u001b[39m VAE_Loss()\n\u001b[1;32m 271\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39moptim\u001b[38;5;241m.\u001b[39mAdam(vae\u001b[38;5;241m.\u001b[39mparameters(), lr\u001b[38;5;241m=\u001b[39m learning_rate)\n",
"Cell \u001b[0;32mIn[103], line 73\u001b[0m, in \u001b[0;36mLSTM_VAE.__init__\u001b[0;34m(self, input_size, output_size, hidden_size, latent_size, num_layers, device)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;66;03m# Decoder Part\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minit_hidden_decoder \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLinear(in_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlatent_size, out_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlstm_factor)\n\u001b[0;32m---> 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoder_lstm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLSTM(input_size\u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_size\u001b[49m, hidden_size\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size, batch_first \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m, num_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers)\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLinear(in_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlstm_factor, out_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_size)\n",
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1207\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1207\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1208\u001b[0m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n",
"\u001b[0;31mAttributeError\u001b[0m: 'LSTM_VAE' object has no attribute 'embed_size'"
]
}
],
"source": [
"# import torch\n",
"\n",
"class VAE_Loss(torch.nn.Module):\n",
" \"\"\"\n",
" Adapted from https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n",
" \"\"\"\n",
" def __init__(self):\n",
" super(VAE_Loss, self).__init__()\n",
" self.nlloss = torch.nn.NLLLoss()\n",
" \n",
" def KL_loss (self, mu, log_var, z):\n",
" kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n",
" kl = kl.sum(-1) # to go from multi-dimensional z to single dimensional z : (batch_size x latent_size) ---> (batch_size) \n",
" # i.e Z = [ [z1_1, z1_2 , ...., z1_lt] ] ------> z = [ z1] \n",
" # [ [z2_1, z2_2, ....., z2_lt] ] [ z2]\n",
" # . [ . ]\n",
" # . [ . ]\n",
" # [[zn_1, zn_2, ....., zn_lt] ] [ zn]\n",
" \n",
" # lt=latent_size \n",
" kl = kl.mean()\n",
" \n",
" return kl\n",
"\n",
" def reconstruction_loss(self, x_hat_param, x):\n",
"\n",
" x = x.view(-1).contiguous()\n",
" x_hat_param = x_hat_param.view(-1, x_hat_param.size(2))\n",
"\n",
" recon = self.nlloss(x_hat_param, x)\n",
"\n",
" return recon\n",
" \n",
"\n",
" def forward(self, mu, log_var,z, x_hat_param, x):\n",
" kl_loss = self.KL_loss(mu, log_var, z)\n",
" recon_loss = self.reconstruction_loss(x_hat_param, x)\n",
"\n",
"\n",
" elbo = kl_loss + recon_loss # we use + because recon loss is a NLLoss (cross entropy) and it's negative in its own, and in the ELBO equation we have\n",
" # elbo = KL_loss - recon_loss, therefore, ELBO = KL_loss - (NLLoss) = KL_loss + NLLoss\n",
"\n",
" return elbo, kl_loss, recon_loss\n",
" \n",
"class LSTM_VAE(torch.nn.Module):\n",
" \"\"\"\n",
" Adapted from https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n",
" \"\"\"\n",
" def __init__(self, input_size, output_size, hidden_size, latent_size, num_layers=1, device=\"cuda\"):\n",
" super(LSTM_VAE, self).__init__()\n",
"\n",
" self.device = device\n",
" \n",
" # Variables\n",
" self.num_layers = num_layers\n",
" self.lstm_factor = num_layers\n",
" self.input_size = input_size\n",
" self.hidden_size = hidden_size\n",
" self.latent_size = latent_size\n",
" self.output_size = output_size\n",
"\n",
" # X: bsz * seq_len * vocab_size \n",
" # X: bsz * seq_len * embed_size\n",
"\n",
" # Encoder Part\n",
" self.encoder_lstm = torch.nn.LSTM(input_size= input_size,hidden_size= self.hidden_size, batch_first=True, num_layers= self.num_layers)\n",
" self.mean = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)\n",
" self.log_variance = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)\n",
"\n",
" # Decoder Part\n",
" \n",
" self.hidden_decoder_linear = torch.nn.Linear(in_features= self.latent_size, out_features= self.hidden_size * self.lstm_factor)\n",
" self.decoder_lstm = torch.nn.LSTM(input_size= self.embed_size, hidden_size= self.hidden_size, batch_first = True, num_layers = self.num_layers)\n",
" self.output = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.output_size)\n",
" # self.log_softmax = torch.nn.LogSoftmax(dim=2)\n",
"\n",
" def get_hidden_state(self, batch_size):\n",
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device)\n",
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device)\n",
" return (h, c)\n",
"\n",
"\n",
" def encoder(self, x, hidden_state):\n",
"\n",
" # pad the packed input.\n",
"\n",
" out, (h,c) = self.encoder_lstm(x, hidden_state)\n",
" # output_encoder, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_encoder, batch_first=True, total_length= total_padding_length)\n",
"\n",
" # Extimate the mean and the variance of q(z|x)\n",
" mean = self.mean(h)\n",
" log_var = self.log_variance(h)\n",
" std = torch.exp(0.5 * log_var) # e^(0.5 log_var) = var^0.5\n",
" \n",
" # Generate a unit gaussian noise.\n",
" # batch_size = output_encoder.size(0)\n",
" # seq_len = output_encoder.size(1)\n",
" # noise = torch.randn(batch_size, self.latent_size).to(self.device)\n",
" noise = torch.randn(self.latent_size).to(self.device)\n",
" \n",
" z = noise * std + mean\n",
"\n",
" return z, mean, log_var, (h,c)\n",
"\n",
"\n",
" def decoder(self, z, x):\n",
"\n",
" hidden_decoder = self.hidden_decoder_linear(z)\n",
" hidden_decoder = (hidden_decoder, hidden_decoder)\n",
"\n",
" # pad the packed input.\n",
" packed_output_decoder, hidden_decoder = self.decoder_lstm(packed_x_embed,hidden_decoder) \n",
" output_decoder, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_decoder, batch_first=True, total_length= total_padding_length)\n",
"\n",
"\n",
" x_hat = self.output(output_decoder)\n",
" \n",
" # x_hat = self.log_softmax(x_hat)\n",
"\n",
"\n",
" return x_hat\n",
"\n",
" \n",
"\n",
" def forward(self, x, hidden_state):\n",
" \n",
" \"\"\"\n",
" x : bsz * seq_len\n",
" \n",
" hidden_encoder: ( num_lstm_layers * bsz * hidden_size, num_lstm_layers * bsz * hidden_size)\n",
"\n",
" \"\"\"\n",
" # Get Embeddings\n",
" # x_embed, maximum_padding_length = self.get_embedding(x)\n",
"\n",
" # Packing the input\n",
" # packed_x_embed = torch.nn.utils.rnn.pack_padded_sequence(input= x_embed, lengths= sentences_length, batch_first=True, enforce_sorted=False)\n",
"\n",
"\n",
" # Encoder\n",
" z, mean, log_var, hidden_encoder = self.encoder(x, maximum_padding_length, hidden_encoder)\n",
"\n",
" # Decoder\n",
" x_hat = self.decoder(z, packed_x_embed, maximum_padding_length)\n",
" \n",
" return x_hat, mean, log_var, z, hidden_encoder\n",
"\n",
" \n",
"\n",
" def inference(self, n_samples, x, z):\n",
"\n",
" # generate random z \n",
" batch_size = 1\n",
" seq_len = 1\n",
" idx_sample = []\n",
"\n",
"\n",
" hidden = self.hidden_decoder_linear(z)\n",
" hidden = (hidden, hidden)\n",
" \n",
" for i in range(n_samples):\n",
" \n",
" output,hidden = self.decoder_lstm(x, hidden)\n",
" output = self.output(output)\n",
" # output = self.log_softmax(output)\n",
" # output = output.exp()\n",
" _, s = torch.topk(output, 1)\n",
" idx_sample.append(s.item())\n",
" x = s.squeeze(0)\n",
"\n",
" w_sample = [self.dictionary.get_i2w()[str(idx)] for idx in idx_sample]\n",
" w_sample = \" \".join(w_sample)\n",
"\n",
" return w_sample\n",
"\n",
"\n",
"\n",
"def get_batch(batch):\n",
" sentences = batch[\"input\"]\n",
" target = batch[\"target\"]\n",
" sentences_length = batch[\"length\"]\n",
"\n",
" return sentences, target, sentences_length\n",
"\n",
"class Trainer:\n",
"\n",
" def __init__(self, train_loader, test_loader, model, loss, optimizer) -> None:\n",
" self.train_loader = train_loader\n",
" self.test_loader = test_loader\n",
" self.model = model\n",
" self.loss = loss\n",
" self.optimizer = optimizer\n",
" self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
" self.interval = 200\n",
"\n",
"\n",
" def train(self, train_losses, epoch, batch_size, clip) -> list: \n",
" # Initialization of RNN hidden, and cell states.\n",
" states = self.model.init_hidden(batch_size) \n",
"\n",
" for batch_idx, (x, targets) in enumerate(self.train_loader):\n",
" # Get x to cuda if possible\n",
" x = x.to(device=device).squeeze(1)\n",
" targets = targets.to(device=device)\n",
"\n",
" # for batch_num, batch in enumerate(self.train_loader): # loop over the data, and jump with step = bptt.\n",
" # get the labels\n",
" source, target, source_lengths = get_batch(batch)\n",
" source = source.to(self.device)\n",
" target = target.to(self.device)\n",
"\n",
"\n",
" x_hat_param, mu, log_var, z, states = self.model(source,source_lengths, states)\n",
"\n",
" # detach hidden states\n",
" states = states[0].detach(), states[1].detach()\n",
"\n",
" # compute the loss\n",
" mloss, KL_loss, recon_loss = self.loss(mu = mu, log_var = log_var, z = z, x_hat_param = x_hat_param , x = target)\n",
"\n",
" train_losses.append((mloss , KL_loss.item(), recon_loss.item()))\n",
"\n",
" mloss.backward()\n",
"\n",
" torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)\n",
"\n",
" self.optimizer.step()\n",
"\n",
" self.optimizer.zero_grad()\n",
"\n",
"\n",
" if batch_num % self.interval == 0 and batch_num > 0:\n",
" \n",
" print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\n",
" epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\n",
"\n",
" return train_losses\n",
"\n",
" def test(self, test_losses, epoch, batch_size) -> list:\n",
"\n",
" with torch.no_grad():\n",
"\n",
" states = self.model.init_hidden(batch_size) \n",
"\n",
" for batch_num, batch in enumerate(self.test_loader): # loop over the data, and jump with step = bptt.\n",
" # get the labels\n",
" source, target, source_lengths = get_batch(batch)\n",
" source = source.to(self.device)\n",
" target = target.to(self.device)\n",
"\n",
"\n",
" x_hat_param, mu, log_var, z, states = self.model(source,source_lengths, states)\n",
"\n",
" # detach hidden states\n",
" states = states[0].detach(), states[1].detach()\n",
"\n",
" # compute the loss\n",
" mloss, KL_loss, recon_loss = self.loss(mu = mu, log_var = log_var, z = z, x_hat_param = x_hat_param , x = target)\n",
"\n",
" test_losses.append((mloss , KL_loss.item(), recon_loss.item()))\n",
"\n",
" # Statistics.\n",
" # if batch_num % 20 ==0:\n",
" # print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\n",
" # epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\n",
"\n",
" return test_losses\n",
"\n",
"vae = LSTM_VAE(input_size, output_size, hidden_size, 16, 1, device).to(device)\n",
"\n",
"vae_loss = VAE_Loss()\n",
"optimizer = torch.optim.Adam(vae.parameters(), lr= learning_rate)\n",
"\n",
"trainer = Trainer(loader_train, loader_test, vae, vae_loss, optimizer)\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}