{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Goal of this notebook: implement some basic RNN/LSTM/GRU to _forecast_ trajectories based on VIRAT and/or the custom _hof_ dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt # Visualization \n", "import torch.nn as nn\n", "import pandas_helper_calc # noqa # provides df.calc.derivative()\n", "import pandas as pd\n", "import cv2\n", "import pathlib\n", "from tqdm.autonotebook import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "FPS = 12\n", "# SRC_CSV = \"EXPERIMENTS/hofext-maskrcnn/all.txt\"\n", "# SRC_CSV = \"EXPERIMENTS/raw/generated/train/tracks.txt\"\n", "SRC_CSV = \"EXPERIMENTS/raw/hof-meter-maskrcnn2/train/tracks.txt\"\n", "SRC_CSV = \"EXPERIMENTS/20240426-hof-yolo/train/tracked.txt\"\n", "SRC_CSV = \"EXPERIMENTS/raw/hof2/train/tracked.txt\"\n", "# SRC_H = \"../DATASETS/hof/webcam20231103-2-homography.txt\"\n", "SRC_H = None\n", "CACHE_DIR = \"EXPERIMENTS/cache/hof2/\"\n", "# SMOOTHING = True # hof-yolo is already smoothed, hof2 isn't\n", "# SMOOTHING_WINDOW=3 #2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "in_fields = ['x', 'y', 'vx', 'vy', 'ax', 'ay'] #, 'dt'] (WARNING: dt column contains NaN)\n", "# out_fields = ['v', 'heading']\n", "# velocity cannot be negative, and heading is circular (modulo), this makes it harder to optimise than a linear space, so try to use components\n", "# an we can use simple MSE loss (I guess?)\n", "out_fields = ['dx', 'dy']\n", "SAMPLE_STEP = 5 # 1/5, for 12fps leads to effectively 12/5=2.4fps\n", "GRID_SIZE = 2 # round items on a grid of 2 points per meter (None to disable rounding)\n", "window = 8 #int(FPS*1.5 / SAMPLE_STEP)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "# Set device\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(device)\n", "\n", "# Hyperparameters\n", "input_size = len(in_fields) #in_d\n", "hidden_size = 64 # hidden_d\n", "num_layers = 1 # num_hidden\n", "output_size = len(out_fields) # out_d\n", "learning_rate = 0.005 #0.01 #0.005\n", "batch_size = 512\n", "num_epochs = 1000\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "cache_path = pathlib.Path(CACHE_DIR)\n", "cache_path.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "337124it [07:27, 753.54it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running smoother\n", "smooth x\n", "smooth y\n", "Samping 1/5, of 423077 items\n", "Done sampling kept 85961 items\n" ] } ], "source": [ "from pathlib import Path\n", "from trap.tools import load_tracks_from_csv\n", "from trap.tools import filter_short_tracks, normalise_position\n", "\n", "data= load_tracks_from_csv(Path(SRC_CSV), FPS, GRID_SIZE, SAMPLE_STEP )" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# create x-norm, y_norm columns\n", "data, mu, std = normalise_position(data)\n", "data = filter_short_tracks(data, window+1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset is a bit crappy because it has different frame step: ranging from predominantly 1 and 2 to sometimes have 3 and 4 as well. This inevitabily leads to difference in speed caluclations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1660 training tracks, 415 test tracks\n" ] } ], "source": [ "track_ids = data.index.unique('track_id').to_numpy()\n", "np.random.shuffle(track_ids)\n", "test_offset_idx = int(len(track_ids) * .8)\n", "training_ids, test_ids = track_ids[:test_offset_idx], track_ids[test_offset_idx:]\n", "print(f\"{len(training_ids)} training tracks, {len(test_ids)} test tracks\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "here, draw out a sample track to see if it looks alright. **unfortunately the imate isn't mapped properly**." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2254\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import random\n", "# if H:\n", "# img_src = \"../DATASETS/hof/webcam20231103-2.png\"\n", "# # dst = cv2.warpPerspective(img_src,H,(2500,1920))\n", "# src_img = cv2.imread(img_src)\n", "# print(src_img.shape)\n", "# h1,w1 = src_img.shape[:2]\n", "# corners = np.float32([[0,0], [w1, 0], [0, h1], [w1, h1]])\n", "\n", "# print(corners)\n", "# corners_projected = cv2.perspectiveTransform(corners.reshape((-1,4,2)), H)[0]\n", "# print(corners_projected)\n", "# [xmin, ymin] = np.int32(corners_projected.min(axis=0).ravel() - 0.5)\n", "# [xmax, ymax] = np.int32(corners_projected.max(axis=0).ravel() + 0.5)\n", "# print(xmin, xmax, ymin, ymax)\n", "\n", "# dst = cv2.warpPerspective(src_img,H, (xmax, ymax))\n", "# def plot_track(track_id: int):\n", "# plt.gca().invert_yaxis()\n", "\n", "# plt.imshow(dst, origin='lower', extent=[xmin/100-mean_x, xmax/100-mean_x, ymin/100-mean_y, ymax/100-mean_y])\n", "# # plot scatter plot with x and y data \n", " \n", "# ax = plt.scatter(\n", "# filtered_data.loc[track_id,:]['proj_x'],\n", "# filtered_data.loc[track_id,:]['proj_y'],\n", "# marker=\"*\") \n", "# ax.axes.invert_yaxis()\n", "# plt.plot(\n", "# filtered_data.loc[track_id,:]['proj_x'],\n", "# filtered_data.loc[track_id,:]['proj_y']\n", "# )\n", "# else:\n", "def plot_track(track_id: int):\n", " ax = plt.scatter(\n", " data.loc[track_id,:]['x_norm'],\n", " data.loc[track_id,:]['y_norm'],\n", " marker=\"*\") \n", " plt.plot(\n", " data.loc[track_id,:]['x_norm'],\n", " data.loc[track_id,:]['y_norm']\n", " )\n", "\n", "# print(filtered_data.loc[track_id,:]['proj_x'])\n", "# _track_id = 2188\n", "_track_id = random.choice(track_ids)\n", "print(_track_id)\n", "plot_track(_track_id)\n", "\n", "for track_id in random.choices(track_ids, k=100):\n", " plot_track(track_id)\n", " \n", "# print(mean_x, mean_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now make the dataset:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ltwhxystatediffx_rawy_raw...vxvyaxayvaheadingd_headingx_normy_norm
track_idframe_id
1342.01393.7365720.00000067.613647121.39115113.2444082.4143392.0NaN13.52.5...6.143418e-011.389160-1.422342e+000.4532131.5189410.14209766.1430885.536579e+010.353449-1.768217
347.01393.84484912.69123886.482910156.26478613.5003842.9931562.05.013.53.0...6.143418e-011.389160-1.422342e+000.4532131.5189410.14209766.1430885.536579e+010.414443-1.574517
352.01405.27343836.67590390.329956176.46197513.5094253.6506562.05.013.53.5...2.169933e-021.577999-1.422342e+000.4532131.5781490.14209789.2121665.536579e+010.416598-1.354485
357.01421.21569876.26125391.465088181.13368213.5002214.2822792.05.013.54.5...-2.209058e-021.515896-1.050958e-01-0.1490491.516057-0.14902090.8348913.894540e+000.414404-1.143113
362.01438.374268115.36254984.298584172.14361613.4996584.7437872.05.013.54.5...-1.349331e-031.1076184.977900e-02-0.9798661.107619-0.98025090.069799-1.836220e+000.414270-0.988670
.....................................................................
503032702.01705.054635749.467887132.149004182.10504214.00000010.4952611.05.014.010.5...1.654143e-12-0.0291454.967546e-110.6571710.029145-0.657171270.0000001.644812e-080.5334920.936057
32707.01703.756025749.703112131.216670181.96191414.00000010.4996091.05.014.010.5...7.418066e-130.010435-2.189608e-120.0949920.010435-0.04490590.000000-4.320000e+020.5334920.937512
32712.01702.457415749.938337130.284337181.81878714.00000010.5001651.05.014.010.5...-4.263256e-140.001334-1.882654e-12-0.0218410.001334-0.02184190.0000001.416898e-080.5334920.937698
32717.01701.158805750.173562129.352003181.67565914.00000010.5000191.05.014.010.5...-2.984279e-14-0.0003503.069545e-14-0.0040420.000350-0.002362270.0000004.320000e+020.5334920.937649
32722.01702.384766750.754517123.435425180.94561814.00000010.4999852.05.014.010.5...0.000000e+00-0.0000827.162271e-140.0006440.000082-0.000644270.0000001.172430e-080.5334920.937638
\n", "

80035 rows × 24 columns

\n", "
" ], "text/plain": [ " l t w h x \\\n", "track_id frame_id \n", "1 342.0 1393.736572 0.000000 67.613647 121.391151 13.244408 \n", " 347.0 1393.844849 12.691238 86.482910 156.264786 13.500384 \n", " 352.0 1405.273438 36.675903 90.329956 176.461975 13.509425 \n", " 357.0 1421.215698 76.261253 91.465088 181.133682 13.500221 \n", " 362.0 1438.374268 115.362549 84.298584 172.143616 13.499658 \n", "... ... ... ... ... ... \n", "5030 32702.0 1705.054635 749.467887 132.149004 182.105042 14.000000 \n", " 32707.0 1703.756025 749.703112 131.216670 181.961914 14.000000 \n", " 32712.0 1702.457415 749.938337 130.284337 181.818787 14.000000 \n", " 32717.0 1701.158805 750.173562 129.352003 181.675659 14.000000 \n", " 32722.0 1702.384766 750.754517 123.435425 180.945618 14.000000 \n", "\n", " y state diff x_raw y_raw ... vx \\\n", "track_id frame_id ... \n", "1 342.0 2.414339 2.0 NaN 13.5 2.5 ... 6.143418e-01 \n", " 347.0 2.993156 2.0 5.0 13.5 3.0 ... 6.143418e-01 \n", " 352.0 3.650656 2.0 5.0 13.5 3.5 ... 2.169933e-02 \n", " 357.0 4.282279 2.0 5.0 13.5 4.5 ... -2.209058e-02 \n", " 362.0 4.743787 2.0 5.0 13.5 4.5 ... -1.349331e-03 \n", "... ... ... ... ... ... ... ... \n", "5030 32702.0 10.495261 1.0 5.0 14.0 10.5 ... 1.654143e-12 \n", " 32707.0 10.499609 1.0 5.0 14.0 10.5 ... 7.418066e-13 \n", " 32712.0 10.500165 1.0 5.0 14.0 10.5 ... -4.263256e-14 \n", " 32717.0 10.500019 1.0 5.0 14.0 10.5 ... -2.984279e-14 \n", " 32722.0 10.499985 2.0 5.0 14.0 10.5 ... 0.000000e+00 \n", "\n", " vy ax ay v a \\\n", "track_id frame_id \n", "1 342.0 1.389160 -1.422342e+00 0.453213 1.518941 0.142097 \n", " 347.0 1.389160 -1.422342e+00 0.453213 1.518941 0.142097 \n", " 352.0 1.577999 -1.422342e+00 0.453213 1.578149 0.142097 \n", " 357.0 1.515896 -1.050958e-01 -0.149049 1.516057 -0.149020 \n", " 362.0 1.107618 4.977900e-02 -0.979866 1.107619 -0.980250 \n", "... ... ... ... ... ... \n", "5030 32702.0 -0.029145 4.967546e-11 0.657171 0.029145 -0.657171 \n", " 32707.0 0.010435 -2.189608e-12 0.094992 0.010435 -0.044905 \n", " 32712.0 0.001334 -1.882654e-12 -0.021841 0.001334 -0.021841 \n", " 32717.0 -0.000350 3.069545e-14 -0.004042 0.000350 -0.002362 \n", " 32722.0 -0.000082 7.162271e-14 0.000644 0.000082 -0.000644 \n", "\n", " heading d_heading x_norm y_norm \n", "track_id frame_id \n", "1 342.0 66.143088 5.536579e+01 0.353449 -1.768217 \n", " 347.0 66.143088 5.536579e+01 0.414443 -1.574517 \n", " 352.0 89.212166 5.536579e+01 0.416598 -1.354485 \n", " 357.0 90.834891 3.894540e+00 0.414404 -1.143113 \n", " 362.0 90.069799 -1.836220e+00 0.414270 -0.988670 \n", "... ... ... ... ... \n", "5030 32702.0 270.000000 1.644812e-08 0.533492 0.936057 \n", " 32707.0 90.000000 -4.320000e+02 0.533492 0.937512 \n", " 32712.0 90.000000 1.416898e-08 0.533492 0.937698 \n", " 32717.0 270.000000 4.320000e+02 0.533492 0.937649 \n", " 32722.0 270.000000 1.172430e-08 0.533492 0.937638 \n", "\n", "[80035 rows x 24 columns]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# a=filtered_data.loc[1]\n", "# min(a.index.tolist())\n", "data" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x False\n", "y False\n", "vx False\n", "vy False\n", "ax False\n", "ay False\n", "dx False\n", "dy False\n" ] } ], "source": [ "for field in in_fields + out_fields:\n", " print(field, data[field].isnull().values.any())" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/1606 [00:00 The only constraint on the latent vector representation for traditional autoencoders is that latent vectors should be easily decodable back into the original image. As a result, the latent space $Z$ can become disjoint and non-continuous. Variational autoencoders try to solve this problem. [Alexander van de Kleut](https://avandekleut.github.io/vae/)\n", "\n", "For LSTM based generative VAE: https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n", "\n", "http://web.archive.org/web/20210119121802/https://towardsdatascience.com/time-series-generation-with-vae-lstm-5a6426365a1c?gi=29d8b029a386\n", "\n", "https://youtu.be/qJeaCHQ1k2w?si=30aAdqqwvz0DpR-x&t=687 VAE generate mu and sigma of a Normal distribution. Thus, they don't map the input to a single point, but a gausian distribution." ] }, { "cell_type": "code", "execution_count": 328, "metadata": {}, "outputs": [], "source": [ "class LSTMModel(nn.Module):\n", " # input_size : number of features in input at each time step\n", " # hidden_size : Number of LSTM units \n", " # num_layers : number of LSTM layers \n", " def __init__(self, input_size, hidden_size, num_layers): \n", " super(LSTMModel, self).__init__() #initializes the parent class nn.Module\n", " # We _could_ train the h0: https://discuss.pytorch.org/t/learn-initial-hidden-state-h0-for-rnn/10013 \n", " # self.lin1 = nn.Linear(input_size, hidden_size)\n", " self.num_layers = num_layers\n", " self.hidden_size = hidden_size\n", " self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n", " self.linear = nn.Linear(hidden_size, output_size)\n", " # self.activation_v = nn.LeakyReLU(.01)\n", " # self.activation_heading = torch.remainder()\n", "\n", " \n", " def get_hidden_state(self, batch_size, device):\n", " h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n", " c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n", " return (h, c)\n", "\n", " def forward(self, x, hidden_state): # defines forward pass of the neural network\n", " # out = self.lin1(x)\n", " \n", " out, hidden_state = self.lstm(x, hidden_state)\n", " # extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\n", " # print(out.shape)\n", " # TODO)) Might want to remove this below: as it might improve training\n", " # out = out[:, -1,:]\n", " # print(out.shape)\n", " out = self.linear(out)\n", " \n", " # torch.remainder(out[1], 360)\n", " # print('o',out.shape)\n", " return out, hidden_state\n", "\n", "lstm = LSTMModel(input_size, hidden_size, num_layers).to(device)\n" ] }, { "cell_type": "code", "execution_count": 329, "metadata": {}, "outputs": [], "source": [ "# model = rnn\n", "model = lstm\n" ] }, { "cell_type": "code", "execution_count": 330, "metadata": {}, "outputs": [], "source": [ "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "loss_fn = nn.MSELoss()" ] }, { "cell_type": "code", "execution_count": 331, "metadata": {}, "outputs": [], "source": [ "def evaluate():\n", " # toggle evaluation mode\n", " model.eval()\n", " with torch.no_grad():\n", " batch_size, seq_len, feature_dim = X_train.shape\n", " y_pred, _ = model(\n", " X_train.to(device=device),\n", " model.get_hidden_state(batch_size, device)\n", " )\n", " train_rmse = torch.sqrt(loss_fn(y_pred, y_train))\n", " # print(y_pred)\n", "\n", " batch_size, seq_len, feature_dim = X_test.shape\n", " y_pred, _ = model(\n", " X_test.to(device=device),\n", " model.get_hidden_state(batch_size, device)\n", " )\n", " # print(loss_fn(y_pred, y_test))\n", " test_rmse = torch.sqrt(loss_fn(y_pred, y_test))\n", " print(\"Epoch ??: train RMSE %.4f, test RMSE %.4f\" % ( train_rmse, test_rmse))\n", "\n", "def load_most_recent():\n", " paths = list(cache_path.glob(f\"checkpoint-{model._get_name()}_*.pt\"))\n", " if len(paths) < 1:\n", " print('Nothing found to load')\n", " return None, None\n", " paths.sort()\n", "\n", " print(f\"Loading {paths[-1]}\")\n", " return load_cache(path=paths[-1])\n", "\n", "def load_cache(epoch=None, path=None):\n", " if path is None:\n", " if epoch is None:\n", " raise RuntimeError(\"Either path or epoch must be given\")\n", " path = cache_path / f\"checkpoint-{model._get_name()}_{epoch:05d}.pt\"\n", " else:\n", " print (path.stem)\n", " epoch = int(path.stem[-5:])\n", "\n", " cached = torch.load(path)\n", " \n", " optimizer.load_state_dict(cached['optimizer_state_dict'])\n", " model.load_state_dict(cached['model_state_dict'])\n", " return epoch, cached['loss']\n", " \n", "\n", "def cache(epoch, loss):\n", " path = cache_path / f\"checkpoint-{model._get_name()}_{epoch:05d}.pt\"\n", " print(f\"Cache to {path}\")\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'loss': loss,\n", " }, path)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TODO)) See [this notebook](https://www.cs.toronto.edu/~lczhang/aps360_20191/lec/w08/rnn.html) For initialization (with random or not) and the use of GRU" ] }, { "cell_type": "code", "execution_count": 332, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading EXPERIMENTS/cache/hof2/checkpoint-LSTMModel_01000.pt\n", "checkpoint-LSTMModel_01000\n", "starting from epoch 1000 (loss: 0.014368701726198196)\n", "Epoch ??: train RMSE 0.0849, test RMSE 0.0866\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "0it [00:00, ?it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch ??: train RMSE 0.0849, test RMSE 0.0866\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "start_epoch, loss = load_most_recent()\n", "if start_epoch is None:\n", " start_epoch = 0\n", "else:\n", " print(f\"starting from epoch {start_epoch} (loss: {loss})\")\n", " evaluate()\n", "\n", "loss_log = []\n", "# Train Network\n", "for epoch in tqdm(range(start_epoch+1,num_epochs+1)):\n", " # toggle train mode\n", " model.train()\n", " for batch_idx, (x, targets) in enumerate(loader_train):\n", " # Get x to cuda if possible\n", " x = x.to(device=device).squeeze(1)\n", " targets = targets.to(device=device)\n", "\n", " # forward\n", " scores, _ = model(\n", " x,\n", " torch.zeros(num_layers, x.shape[2], hidden_size, dtype=torch.float).to(device=device),\n", " torch.zeros(num_layers, x.shape[2], hidden_size, dtype=torch.float).to(device=device)\n", " )\n", " # print(scores)\n", " loss = loss_fn(scores, targets)\n", "\n", " # backward\n", " optimizer.zero_grad()\n", " loss.backward()\n", "\n", " # gradient descent update step/adam step\n", " optimizer.step()\n", "\n", " loss_log.append(loss.item())\n", "\n", " if epoch % 5 != 0:\n", " continue\n", "\n", " cache(epoch, loss)\n", " evaluate()\n", "\n", "evaluate()" ] }, { "cell_type": "code", "execution_count": 333, "metadata": {}, "outputs": [], "source": [ "# print(loss)\n", "# print(len(loss_log))\n", "# plt.plot(loss_log)\n", "# plt.ylabel('Loss')\n", "# plt.xlabel('iteration')\n", "# plt.show()" ] }, { "cell_type": "code", "execution_count": 335, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([49999, 9, 2]) torch.Size([49999, 9, 2])\n" ] } ], "source": [ "model.eval()\n", "\n", "with torch.no_grad():\n", " y_pred, _ = model(X_train.to(device=device),\n", " model.get_hidden_state(X_train.shape[0], device))\n", " \n", " print(y_pred.shape, y_train.shape)\n", "# y_train, y_pred" ] }, { "cell_type": "code", "execution_count": 336, "metadata": {}, "outputs": [], "source": [ "import scipy\n", "\n", "def ceil_away_from_0(a):\n", " return np.sign(a) * np.ceil(np.abs(a))\n" ] }, { "cell_type": "code", "execution_count": 343, "metadata": {}, "outputs": [], "source": [ "def predict_and_plot(model, feature, steps = 50):\n", " lenght = feature.shape[0]\n", "\n", " dt = (1/ FPS) * SAMPLE_STEP\n", "\n", " trajectory = feature\n", "\n", " # feature = filtered_data.loc[_track_id,:].iloc[:5][in_fields].values\n", " # nxt = filtered_data.loc[_track_id,:].iloc[5][out_fields]\n", " with torch.no_grad():\n", " # h = torch.zeros(num_layers, window+1, hidden_size, dtype=torch.float).to(device=device)\n", " # c = torch.zeros(num_layers, window+1, hidden_size, dtype=torch.float).to(device=device)\n", " h = torch.zeros(num_layers, 1, hidden_size, dtype=torch.float).to(device=device)\n", " c = torch.zeros(num_layers, 1, hidden_size, dtype=torch.float).to(device=device)\n", " hidden_state = (h, c)\n", " # X = torch.tensor([feature], dtype=torch.float).to(device)\n", " # y, (h, c) = model(X, h, c)\n", " for i in range(steps):\n", " # predict_f = scipy.ndimage.uniform_filter(feature)\n", " # predict_f = scipy.interpolate.splrep(feature[:][0], feature[:][1],)\n", " # predict_f = scipy.signal.spline_feature(feature, lmbda=.1)\n", " # bathc size of one, so feature as single item in array\n", " # print(X.shape)\n", " X = torch.tensor([feature], dtype=torch.float).to(device)\n", " # print(type(model))\n", " y, hidden_state, *_ = model(X, hidden_state)\n", " # print(hidden_state.shape)\n", "\n", " s = y[-1][-1].cpu()\n", "\n", " # proj_x proj_y v heading a d_heading\n", " # next_step = feature\n", "\n", " dx, dy = s\n", " \n", " dx = (dx * GRID_SIZE).round() / GRID_SIZE\n", " dy = (dy * GRID_SIZE).round() / GRID_SIZE\n", " vx, vy = dx / dt, dy / dt\n", "\n", " v = np.sqrt(s[0]**2 + s[1]**2)\n", " heading = (np.arctan2(s[1], s[0]) * 180 / np.pi) % 360\n", " # a = (v - feature[-1][2]) / dt\n", " ax = (vx - feature[-1][2]) / dt\n", " ay = (vx - feature[-1][3]) / dt\n", " # d_heading = (heading - feature[-1][5])\n", " # print(s)\n", " # ['x', 'y', 'vx', 'vy', 'ax', 'ay'] \n", " x = feature[-1][0] + dx\n", " y = feature[-1][1] + dy\n", " if GRID_SIZE is not None:\n", " # put points back on grid\n", " x = (x*GRID_SIZE).round() / GRID_SIZE\n", " y = (y*GRID_SIZE).round() / GRID_SIZE\n", "\n", " feature = [[x, y, vx, vy, ax, ay]]\n", " \n", " trajectory = np.append(trajectory, feature, axis=0)\n", " # f = [feature[-1][0] + s[0]*dt, feature[-1][1] + s[1]*dt, v, heading, a, d_heading ]\n", " # feature = np.append(feature, [feature], axis=0)\n", " \n", " # print(next_step, nxt)\n", " # print(trajectory)\n", " plt.plot(trajectory[:lenght,0], trajectory[:lenght,1], c='orange')\n", " plt.plot(trajectory[lenght-1:,0], trajectory[lenght-1:,1], c='red')\n", " plt.scatter(trajectory[lenght:,0], trajectory[lenght:,1], c='red', marker='x')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1301\n", "(10, 6) (10, 6)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "# print(filtered_data.loc[track_id,:]['proj_x'])\n", "_track_id = 8701 # random.choice(track_ids)\n", "_track_id = 3880 # random.choice(track_ids)\n", "\n", "# _track_id = 2780\n", "\n", "for batch_idx in range(100):\n", " _track_id = random.choice(track_ids)\n", " plt.plot(\n", " data.loc[_track_id,:]['x'],\n", " data.loc[_track_id,:]['y'],\n", " c='grey', alpha=.2\n", " )\n", "\n", "_track_id = random.choice(track_ids)\n", "# _track_id = 1096\n", "_track_id = 1301\n", "print(_track_id)\n", "ax = plt.scatter(\n", " data.loc[_track_id,:]['x'],\n", " data.loc[_track_id,:]['y'],\n", " marker=\"*\") \n", "plt.plot(\n", " data.loc[_track_id,:]['x'],\n", " data.loc[_track_id,:]['y']\n", ")\n", "\n", "X = data.loc[_track_id,:].iloc[:][in_fields].values\n", "# Adding randomness might be a cheat to get multiple features from current position\n", "rnd = np.random.random_sample(X.shape) / 10\n", "# print(rnd)\n", "\n", "print(X[:10].shape, (X[:10] + rnd[:10]).shape)\n", "\n", "# predict_and_plot(data.loc[_track_id,:].iloc[:5][in_fields].values)\n", "# predict_and_plot(model, data.loc[_track_id,:].iloc[:5][in_fields].values, 50)\n", "# predict_and_plot(model, data.loc[_track_id,:].iloc[:10][in_fields].values, 50)\n", "# predict_and_plot(model, data.loc[_track_id,:].iloc[:20][in_fields].values)\n", "# predict_and_plot(model, data.loc[_track_id,:].iloc[:30][in_fields].values)\n", "predict_and_plot(model, X[:12])\n", "predict_and_plot(model, X[:12] + rnd[:12])\n", "predict_and_plot(model, data.loc[_track_id,:].iloc[:][in_fields].values)\n", "# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n", "# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## VAE" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# From https://github.com/CUN-bjy/lstm-vae-torch/blob/main/src/models.py (MIT)\n", "\n", "from typing import Optional\n", "\n", "\n", "class Encoder(nn.Module):\n", " def __init__(self, input_size=4096, hidden_size=1024, num_layers=2):\n", " super(Encoder, self).__init__()\n", " self.hidden_size = hidden_size\n", " self.num_layers = num_layers\n", " self.lstm = nn.LSTM(\n", " input_size,\n", " hidden_size,\n", " num_layers,\n", " batch_first=True,\n", " bidirectional=False,\n", " )\n", "\n", " def get_hidden_state(self, batch_size, device):\n", " h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n", " c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n", " return (h, c)\n", "\n", " def forward(self, x, hidden_state):\n", " # x: tensor of shape (batch_size, seq_length, hidden_size)\n", " outputs, (hidden, cell) = self.lstm(x, hidden_state)\n", " return outputs, (hidden, cell)\n", "\n", "\n", "class Decoder(nn.Module):\n", " def __init__(\n", " self, input_size=4096, hidden_size=1024, output_size=4096, num_layers=2\n", " ):\n", " super(Decoder, self).__init__()\n", " self.hidden_size = hidden_size\n", " self.output_size = output_size\n", " self.num_layers = num_layers\n", " self.lstm = nn.LSTM(\n", " input_size,\n", " hidden_size,\n", " num_layers,\n", " batch_first=True,\n", " bidirectional=False,\n", " )\n", " self.fc = nn.Linear(hidden_size, output_size)\n", " \n", " def get_hidden_state(self, batch_size, device):\n", " h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n", " c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n", " return (h, c)\n", "\n", " def forward(self, x, hidden):\n", " # x: tensor of shape (batch_size, seq_length, hidden_size)\n", " output, (hidden, cell) = self.lstm(x, hidden)\n", " prediction = self.fc(output)\n", " return prediction, (hidden, cell)\n", "\n", "\n", "class LSTMVAE(nn.Module):\n", " \"\"\"LSTM-based Variational Auto Encoder\"\"\"\n", "\n", " def __init__(\n", " self, input_size, output_size, hidden_size, latent_size, device=torch.device(\"cuda\")\n", " ):\n", " \"\"\"\n", " input_size: int, batch_size x sequence_length x input_dim\n", " hidden_size: int, output size of LSTM AE\n", " latent_size: int, latent z-layer size\n", " num_lstm_layer: int, number of layers in LSTM\n", " \"\"\"\n", " super(LSTMVAE, self).__init__()\n", " self.device = device\n", "\n", " # dimensions\n", " self.input_size = input_size\n", " self.hidden_size = hidden_size\n", " self.latent_size = latent_size\n", " self.num_layers = 1\n", "\n", " # lstm ae\n", " self.lstm_enc = Encoder(\n", " input_size=input_size, hidden_size=hidden_size, num_layers=self.num_layers\n", " )\n", " self.lstm_dec = Decoder(\n", " input_size=latent_size,\n", " output_size=output_size,\n", " hidden_size=hidden_size,\n", " num_layers=self.num_layers,\n", " )\n", "\n", " self.fc21 = nn.Linear(self.hidden_size, self.latent_size)\n", " self.fc22 = nn.Linear(self.hidden_size, self.latent_size)\n", " self.fc3 = nn.Linear(self.latent_size, self.hidden_size)\n", "\n", " def reparametize(self, mu, logvar):\n", " std = torch.exp(0.5 * logvar)\n", " noise = torch.randn_like(std).to(self.device)\n", "\n", " z = mu + noise * std\n", " return z\n", "\n", " def forward(self, x, hidden_state_encoder: Optional[tuple]=None):\n", " batch_size, seq_len, feature_dim = x.shape\n", "\n", " if hidden_state_encoder is None:\n", " hidden_state_encoder = self.lstm_enc.get_hidden_state(batch_size, self.device)\n", "\n", " # encode input space to hidden space\n", " prediction, hidden_state = self.lstm_enc(x, hidden_state_encoder)\n", " enc_h = hidden_state[0].view(batch_size, self.hidden_size).to(self.device)\n", "\n", " # extract latent variable z(hidden space to latent space)\n", " mean = self.fc21(enc_h)\n", " logvar = self.fc22(enc_h)\n", " z = self.reparametize(mean, logvar) # batch_size x latent_size\n", "\n", " # initialize hidden state as inputs\n", " h_ = self.fc3(z)\n", " \n", " # decode latent space to input space\n", " z = z.repeat(1, seq_len, 1)\n", " z = z.view(batch_size, seq_len, self.latent_size).to(self.device)\n", "\n", " # initialize hidden state\n", " hidden = (h_.contiguous(), h_.contiguous())\n", " # TODO)) the above is not the right dimensions, but this changes architecture\n", " hidden = self.lstm_dec.get_hidden_state(batch_size, self.device)\n", " reconstruct_output, hidden = self.lstm_dec(z, hidden)\n", "\n", " x_hat = reconstruct_output\n", " \n", " return x_hat, hidden_state, mean, logvar\n", "\n", " def loss_function(self, *args, **kwargs) -> dict:\n", " \"\"\"\n", " Computes the VAE loss function.\n", " KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n", " :param args:\n", " :param kwargs:\n", " :return:\n", " \"\"\"\n", " predicted = args[0]\n", " target = args[1]\n", " mu = args[2]\n", " log_var = args[3]\n", "\n", " kld_weight = 0.00025 # Account for the minibatch samples from the dataset\n", " recons_loss = torch.nn.functional.mse_loss(predicted, target=target)\n", "\n", " kld_loss = torch.mean(\n", " -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0\n", " )\n", "\n", " loss = recons_loss + kld_weight * kld_loss\n", " return {\n", " \"loss\": loss,\n", " \"Reconstruction_Loss\": recons_loss.detach(),\n", " \"KLD\": -kld_loss.detach(),\n", " }\n" ] }, { "cell_type": "code", "execution_count": 303, "metadata": {}, "outputs": [], "source": [ "vae = LSTMVAE(input_size, output_size, hidden_size, 1024, device=device)\n", "vae.to(device)\n", "vae_optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)" ] }, { "cell_type": "code", "execution_count": 304, "metadata": {}, "outputs": [], "source": [ "def train_vae(model, train_loader, test_loader, max_iter, learning_rate):\n", " # optimizer\n", " optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "\n", " ## interation setup\n", " epochs = tqdm(range(max_iter // len(train_loader) + 1))\n", " epochs = tqdm(range(max_iter))\n", "\n", " ## training\n", " count = 0\n", " for epoch in epochs:\n", " model.train()\n", " optimizer.zero_grad()\n", " train_iterator = tqdm(\n", " enumerate(train_loader), total=len(train_loader), desc=\"training\"\n", " )\n", "\n", " # if count > max_iter:\n", " # print(\"max!\")\n", " # return model\n", "\n", " for batch_idx, (x, targets) in train_iterator:\n", " count += 1\n", "\n", " # future_data, past_data = batch_data\n", "\n", " ## reshape\n", " batch_size = x.size(0)\n", " example_size = x.size(1)\n", " # image_size = past_data.size(2), past_data.size(3)\n", " # past_data = (\n", " # past_data.view(batch_size, example_size, -1).float().to(args.device) # flattens image, we don't need this\n", " # )\n", " # future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\n", "\n", " y, hidden_state, mean, logvar = model(x)\n", "\n", " # calculate vae loss\n", " # print(y.shape, targets.shape)\n", " losses = model.loss_function(y, targets, mean, logvar)\n", " mloss, recon_loss, kld_loss = (\n", " losses[\"loss\"],\n", " losses[\"Reconstruction_Loss\"],\n", " losses[\"KLD\"],\n", " )\n", "\n", " # Backward and optimize\n", " optimizer.zero_grad()\n", " mloss.mean().backward()\n", " optimizer.step()\n", "\n", " train_iterator.set_postfix({\"train_loss\": float(mloss.mean())})\n", " print(\"train_loss\", float(mloss.mean()), epoch)\n", "\n", " model.eval()\n", " eval_loss = 0\n", " test_iterator = tqdm(\n", " enumerate(test_loader), total=len(test_loader), desc=\"testing\"\n", " )\n", "\n", " with torch.no_grad():\n", " for batch_idx, (x, targets) in test_iterator:\n", " # future_data, past_data = batch_data\n", "\n", " ## reshape\n", " batch_size = x.size(0)\n", " example_size = x.size(1)\n", " # past_data = (\n", " # past_data.view(batch_size, example_size, -1).float().to(args.device)\n", " # )\n", " # future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\n", "\n", " y, hidden_state, mean, logvar = model(x)\n", "\n", " # calculate vae loss\n", " losses = model.loss_function(y, targets, mean, logvar)\n", " mloss, recon_loss, kld_loss = (\n", " losses[\"loss\"],\n", " losses[\"Reconstruction_Loss\"],\n", " losses[\"KLD\"],\n", " )\n", "\n", " eval_loss += mloss.mean().item()\n", "\n", " test_iterator.set_postfix({\"eval_loss\": float(mloss.mean())})\n", "\n", " # if batch_idx == 0:\n", " # nhw_orig = past_data[0].view(example_size, image_size[0], -1)\n", " # nhw_recon = recon_x[0].view(example_size, image_size[0], -1)\n", " # imshow(nhw_orig.cpu(), f\"orig{epoch}\")\n", " # imshow(nhw_recon.cpu(), f\"recon{epoch}\")\n", " # writer.add_images(f\"original{i}\", nchw_orig, epoch)\n", " # writer.add_images(f\"reconstructed{i}\", nchw_recon, epoch)\n", "\n", " eval_loss = eval_loss / len(test_loader)\n", " # writer.add_scalar(\"eval_loss\", float(eval_loss), epoch)\n", " print(\"Evaluation Score : [{}]\".format(eval_loss))\n", "\n", " print(\"Done :-)\")\n", " return model" ] }, { "cell_type": "code", "execution_count": 305, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/11 [00:00 1\u001b[0m \u001b[43mtrain_vae\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloader_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloader_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[304], line 36\u001b[0m, in \u001b[0;36mtrain_vae\u001b[0;34m(model, train_loader, test_loader, max_iter, learning_rate)\u001b[0m\n\u001b[1;32m 29\u001b[0m example_size \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# image_size = past_data.size(2), past_data.size(3)\u001b[39;00m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;66;03m# past_data = (\u001b[39;00m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# past_data.view(batch_size, example_size, -1).float().to(args.device) # flattens image, we don't need this\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m# future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\u001b[39;00m\n\u001b[0;32m---> 36\u001b[0m y, hidden_state, mean, logvar \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# calculate vae loss\u001b[39;00m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# print(y.shape, targets.shape)\u001b[39;00m\n\u001b[1;32m 40\u001b[0m losses \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mloss_function(y, targets, mean, logvar)\n", "File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "Cell \u001b[0;32mIn[302], line 128\u001b[0m, in \u001b[0;36mLSTMVAE.forward\u001b[0;34m(self, x, hidden_state_encoder)\u001b[0m\n\u001b[1;32m 125\u001b[0m hidden \u001b[38;5;241m=\u001b[39m (h_\u001b[38;5;241m.\u001b[39mcontiguous(), h_\u001b[38;5;241m.\u001b[39mcontiguous())\n\u001b[1;32m 126\u001b[0m \u001b[38;5;66;03m# TODO)) the above is not the right dimensions, but this changes architecture\u001b[39;00m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;66;03m# hidden = self.lstm_dec.get_hidden_state(batch_size, self.device)\u001b[39;00m\n\u001b[0;32m--> 128\u001b[0m reconstruct_output, hidden \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm_dec\u001b[49m\u001b[43m(\u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 130\u001b[0m x_hat \u001b[38;5;241m=\u001b[39m reconstruct_output\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x_hat, hidden_state, mean, logvar\n", "File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "Cell \u001b[0;32mIn[302], line 54\u001b[0m, in \u001b[0;36mDecoder.forward\u001b[0;34m(self, x, hidden)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, hidden):\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# x: tensor of shape (batch_size, seq_length, hidden_size)\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m output, (hidden, cell) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 55\u001b[0m prediction \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc(output)\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m prediction, (hidden, cell)\n", "File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/rnn.py:755\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 752\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (hx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m hx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m3\u001b[39m):\n\u001b[1;32m 753\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFor batched 3-D input, hx and cx should \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 754\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124malso be 3-D but got (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-D, \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-D) tensors\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 755\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg)\n\u001b[1;32m 756\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 757\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m hx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n", "\u001b[0;31mRuntimeError\u001b[0m: For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors" ] } ], "source": [ "train_vae(vae, loader_train, loader_test, num_epochs, learning_rate*10)" ] }, { "cell_type": "code", "execution_count": 300, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1301\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# VAE Predict and plot\n", "_track_id = 8701 # random.choice(track_ids)\n", "_track_id = 3880 # random.choice(track_ids)\n", "\n", "# _track_id = 2780\n", "\n", "for batch_idx in range(100):\n", " _track_id = random.choice(track_ids)\n", " plt.plot(\n", " data.loc[_track_id,:]['x'],\n", " data.loc[_track_id,:]['y'],\n", " c='grey', alpha=.2\n", " )\n", "\n", "_track_id = random.choice(track_ids)\n", "# _track_id = 1096\n", "_track_id = 1301\n", "print(_track_id)\n", "ax = plt.scatter(\n", " data.loc[_track_id,:]['x'],\n", " data.loc[_track_id,:]['y'],\n", " marker=\"*\") \n", "plt.plot(\n", " data.loc[_track_id,:]['x'],\n", " data.loc[_track_id,:]['y']\n", ")\n", "\n", "# predict_and_plot(data.loc[_track_id,:].iloc[:5][in_fields].values)\n", "predict_and_plot(vae, data.loc[_track_id,:].iloc[:5][in_fields].values, 50)\n", "# predict_and_plot(vae, data.loc[_track_id,:].iloc[:10][in_fields].values, 50)\n", "# predict_and_plot(vae, data.loc[_track_id,:].iloc[:20][in_fields].values)\n", "# predict_and_plot(vae, data.loc[_track_id,:].iloc[:30][in_fields].values)\n", "# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n", "# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "AttributeError", "evalue": "'LSTM_VAE' object has no attribute 'embed_size'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[103], line 268\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# Statistics.\u001b[39;00m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;66;03m# if batch_num % 20 ==0:\u001b[39;00m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;66;03m# print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\u001b[39;00m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m test_losses\n\u001b[0;32m--> 268\u001b[0m vae \u001b[38;5;241m=\u001b[39m \u001b[43mLSTM_VAE\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m16\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 270\u001b[0m vae_loss \u001b[38;5;241m=\u001b[39m VAE_Loss()\n\u001b[1;32m 271\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39moptim\u001b[38;5;241m.\u001b[39mAdam(vae\u001b[38;5;241m.\u001b[39mparameters(), lr\u001b[38;5;241m=\u001b[39m learning_rate)\n", "Cell \u001b[0;32mIn[103], line 73\u001b[0m, in \u001b[0;36mLSTM_VAE.__init__\u001b[0;34m(self, input_size, output_size, hidden_size, latent_size, num_layers, device)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;66;03m# Decoder Part\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minit_hidden_decoder \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLinear(in_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlatent_size, out_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlstm_factor)\n\u001b[0;32m---> 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoder_lstm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLSTM(input_size\u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_size\u001b[49m, hidden_size\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size, batch_first \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m, num_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers)\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLinear(in_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlstm_factor, out_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_size)\n", "File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1207\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1207\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1208\u001b[0m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n", "\u001b[0;31mAttributeError\u001b[0m: 'LSTM_VAE' object has no attribute 'embed_size'" ] } ], "source": [ "# import torch\n", "\n", "class VAE_Loss(torch.nn.Module):\n", " \"\"\"\n", " Adapted from https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n", " \"\"\"\n", " def __init__(self):\n", " super(VAE_Loss, self).__init__()\n", " self.nlloss = torch.nn.NLLLoss()\n", " \n", " def KL_loss (self, mu, log_var, z):\n", " kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n", " kl = kl.sum(-1) # to go from multi-dimensional z to single dimensional z : (batch_size x latent_size) ---> (batch_size) \n", " # i.e Z = [ [z1_1, z1_2 , ...., z1_lt] ] ------> z = [ z1] \n", " # [ [z2_1, z2_2, ....., z2_lt] ] [ z2]\n", " # . [ . ]\n", " # . [ . ]\n", " # [[zn_1, zn_2, ....., zn_lt] ] [ zn]\n", " \n", " # lt=latent_size \n", " kl = kl.mean()\n", " \n", " return kl\n", "\n", " def reconstruction_loss(self, x_hat_param, x):\n", "\n", " x = x.view(-1).contiguous()\n", " x_hat_param = x_hat_param.view(-1, x_hat_param.size(2))\n", "\n", " recon = self.nlloss(x_hat_param, x)\n", "\n", " return recon\n", " \n", "\n", " def forward(self, mu, log_var,z, x_hat_param, x):\n", " kl_loss = self.KL_loss(mu, log_var, z)\n", " recon_loss = self.reconstruction_loss(x_hat_param, x)\n", "\n", "\n", " elbo = kl_loss + recon_loss # we use + because recon loss is a NLLoss (cross entropy) and it's negative in its own, and in the ELBO equation we have\n", " # elbo = KL_loss - recon_loss, therefore, ELBO = KL_loss - (NLLoss) = KL_loss + NLLoss\n", "\n", " return elbo, kl_loss, recon_loss\n", " \n", "class LSTM_VAE(torch.nn.Module):\n", " \"\"\"\n", " Adapted from https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n", " \"\"\"\n", " def __init__(self, input_size, output_size, hidden_size, latent_size, num_layers=1, device=\"cuda\"):\n", " super(LSTM_VAE, self).__init__()\n", "\n", " self.device = device\n", " \n", " # Variables\n", " self.num_layers = num_layers\n", " self.lstm_factor = num_layers\n", " self.input_size = input_size\n", " self.hidden_size = hidden_size\n", " self.latent_size = latent_size\n", " self.output_size = output_size\n", "\n", " # X: bsz * seq_len * vocab_size \n", " # X: bsz * seq_len * embed_size\n", "\n", " # Encoder Part\n", " self.encoder_lstm = torch.nn.LSTM(input_size= input_size,hidden_size= self.hidden_size, batch_first=True, num_layers= self.num_layers)\n", " self.mean = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)\n", " self.log_variance = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)\n", "\n", " # Decoder Part\n", " \n", " self.hidden_decoder_linear = torch.nn.Linear(in_features= self.latent_size, out_features= self.hidden_size * self.lstm_factor)\n", " self.decoder_lstm = torch.nn.LSTM(input_size= self.embed_size, hidden_size= self.hidden_size, batch_first = True, num_layers = self.num_layers)\n", " self.output = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.output_size)\n", " # self.log_softmax = torch.nn.LogSoftmax(dim=2)\n", "\n", " def get_hidden_state(self, batch_size):\n", " h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device)\n", " c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device)\n", " return (h, c)\n", "\n", "\n", " def encoder(self, x, hidden_state):\n", "\n", " # pad the packed input.\n", "\n", " out, (h,c) = self.encoder_lstm(x, hidden_state)\n", " # output_encoder, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_encoder, batch_first=True, total_length= total_padding_length)\n", "\n", " # Extimate the mean and the variance of q(z|x)\n", " mean = self.mean(h)\n", " log_var = self.log_variance(h)\n", " std = torch.exp(0.5 * log_var) # e^(0.5 log_var) = var^0.5\n", " \n", " # Generate a unit gaussian noise.\n", " # batch_size = output_encoder.size(0)\n", " # seq_len = output_encoder.size(1)\n", " # noise = torch.randn(batch_size, self.latent_size).to(self.device)\n", " noise = torch.randn(self.latent_size).to(self.device)\n", " \n", " z = noise * std + mean\n", "\n", " return z, mean, log_var, (h,c)\n", "\n", "\n", " def decoder(self, z, x):\n", "\n", " hidden_decoder = self.hidden_decoder_linear(z)\n", " hidden_decoder = (hidden_decoder, hidden_decoder)\n", "\n", " # pad the packed input.\n", " packed_output_decoder, hidden_decoder = self.decoder_lstm(packed_x_embed,hidden_decoder) \n", " output_decoder, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_decoder, batch_first=True, total_length= total_padding_length)\n", "\n", "\n", " x_hat = self.output(output_decoder)\n", " \n", " # x_hat = self.log_softmax(x_hat)\n", "\n", "\n", " return x_hat\n", "\n", " \n", "\n", " def forward(self, x, hidden_state):\n", " \n", " \"\"\"\n", " x : bsz * seq_len\n", " \n", " hidden_encoder: ( num_lstm_layers * bsz * hidden_size, num_lstm_layers * bsz * hidden_size)\n", "\n", " \"\"\"\n", " # Get Embeddings\n", " # x_embed, maximum_padding_length = self.get_embedding(x)\n", "\n", " # Packing the input\n", " # packed_x_embed = torch.nn.utils.rnn.pack_padded_sequence(input= x_embed, lengths= sentences_length, batch_first=True, enforce_sorted=False)\n", "\n", "\n", " # Encoder\n", " z, mean, log_var, hidden_encoder = self.encoder(x, maximum_padding_length, hidden_encoder)\n", "\n", " # Decoder\n", " x_hat = self.decoder(z, packed_x_embed, maximum_padding_length)\n", " \n", " return x_hat, mean, log_var, z, hidden_encoder\n", "\n", " \n", "\n", " def inference(self, n_samples, x, z):\n", "\n", " # generate random z \n", " batch_size = 1\n", " seq_len = 1\n", " idx_sample = []\n", "\n", "\n", " hidden = self.hidden_decoder_linear(z)\n", " hidden = (hidden, hidden)\n", " \n", " for i in range(n_samples):\n", " \n", " output,hidden = self.decoder_lstm(x, hidden)\n", " output = self.output(output)\n", " # output = self.log_softmax(output)\n", " # output = output.exp()\n", " _, s = torch.topk(output, 1)\n", " idx_sample.append(s.item())\n", " x = s.squeeze(0)\n", "\n", " w_sample = [self.dictionary.get_i2w()[str(idx)] for idx in idx_sample]\n", " w_sample = \" \".join(w_sample)\n", "\n", " return w_sample\n", "\n", "\n", "\n", "def get_batch(batch):\n", " sentences = batch[\"input\"]\n", " target = batch[\"target\"]\n", " sentences_length = batch[\"length\"]\n", "\n", " return sentences, target, sentences_length\n", "\n", "class Trainer:\n", "\n", " def __init__(self, train_loader, test_loader, model, loss, optimizer) -> None:\n", " self.train_loader = train_loader\n", " self.test_loader = test_loader\n", " self.model = model\n", " self.loss = loss\n", " self.optimizer = optimizer\n", " self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " self.interval = 200\n", "\n", "\n", " def train(self, train_losses, epoch, batch_size, clip) -> list: \n", " # Initialization of RNN hidden, and cell states.\n", " states = self.model.init_hidden(batch_size) \n", "\n", " for batch_idx, (x, targets) in enumerate(self.train_loader):\n", " # Get x to cuda if possible\n", " x = x.to(device=device).squeeze(1)\n", " targets = targets.to(device=device)\n", "\n", " # for batch_num, batch in enumerate(self.train_loader): # loop over the data, and jump with step = bptt.\n", " # get the labels\n", " source, target, source_lengths = get_batch(batch)\n", " source = source.to(self.device)\n", " target = target.to(self.device)\n", "\n", "\n", " x_hat_param, mu, log_var, z, states = self.model(source,source_lengths, states)\n", "\n", " # detach hidden states\n", " states = states[0].detach(), states[1].detach()\n", "\n", " # compute the loss\n", " mloss, KL_loss, recon_loss = self.loss(mu = mu, log_var = log_var, z = z, x_hat_param = x_hat_param , x = target)\n", "\n", " train_losses.append((mloss , KL_loss.item(), recon_loss.item()))\n", "\n", " mloss.backward()\n", "\n", " torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)\n", "\n", " self.optimizer.step()\n", "\n", " self.optimizer.zero_grad()\n", "\n", "\n", " if batch_num % self.interval == 0 and batch_num > 0:\n", " \n", " print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\n", " epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\n", "\n", " return train_losses\n", "\n", " def test(self, test_losses, epoch, batch_size) -> list:\n", "\n", " with torch.no_grad():\n", "\n", " states = self.model.init_hidden(batch_size) \n", "\n", " for batch_num, batch in enumerate(self.test_loader): # loop over the data, and jump with step = bptt.\n", " # get the labels\n", " source, target, source_lengths = get_batch(batch)\n", " source = source.to(self.device)\n", " target = target.to(self.device)\n", "\n", "\n", " x_hat_param, mu, log_var, z, states = self.model(source,source_lengths, states)\n", "\n", " # detach hidden states\n", " states = states[0].detach(), states[1].detach()\n", "\n", " # compute the loss\n", " mloss, KL_loss, recon_loss = self.loss(mu = mu, log_var = log_var, z = z, x_hat_param = x_hat_param , x = target)\n", "\n", " test_losses.append((mloss , KL_loss.item(), recon_loss.item()))\n", "\n", " # Statistics.\n", " # if batch_num % 20 ==0:\n", " # print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\n", " # epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\n", "\n", " return test_losses\n", "\n", "vae = LSTM_VAE(input_size, output_size, hidden_size, 16, 1, device).to(device)\n", "\n", "vae_loss = VAE_Loss()\n", "optimizer = torch.optim.Adam(vae.parameters(), lr= learning_rate)\n", "\n", "trainer = Trainer(loader_train, loader_test, vae, vae_loss, optimizer)\n", "\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 2 }