2301 lines
316 KiB
Text
2301 lines
316 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Goal of this notebook: implement some basic RNN/LSTM/GRU to _forecast_ trajectories based on VIRAT and/or the custom _hof_ dataset.\n",
|
|
"\n",
|
|
"Somewhat based on [test_custom_rnn](test_custom_rnn.ipynb) for the network and [test_trajectron_maps](test_trajectron_maps.ipynb) for the dataloading. And many thanks to [seq2seq-time-series-forecasting-fully-recurrent](https://github.com/maxbrenner-ai/seq2seq-time-series-forecasting-fully-recurrent/blob/main/notebook.ipynb) by maxbrenner-ai.\n",
|
|
"\n",
|
|
"TODO: Look into [TimeSeriesTransformerForPrediction](https://huggingface.co/docs/transformers/main/model_doc/time_series_transformer#transformers.TimeSeriesTransformerForPrediction) from huggingface"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"from pathlib import Path\n",
|
|
"from trap.frame_emitter import Camera\n",
|
|
"from trap.utils import ImageMap\n",
|
|
"import cv2\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"import pandas as pd\n",
|
|
"import torch.nn as nn\n",
|
|
"\n",
|
|
"from torch import optim\n",
|
|
"import torch.nn.functional as F"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Configuration"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Track dataset options"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"path = Path(\"EXPERIMENTS/raw/hof3/\")\n",
|
|
"calibration_path = Path(\"../DATASETS/hof3/calibration.json\")\n",
|
|
"homography_path = Path(\"../DATASETS/hof3/homography.json\")\n",
|
|
"device = device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
"camera = Camera.from_paths(calibration_path, homography_path, 12)\n",
|
|
"\n",
|
|
"# when using a map encoder:\n",
|
|
"image_path = Path(\"../DATASETS/hof3/map-undistorted-H-2.png\")\n",
|
|
"assert image_path.exists()\n",
|
|
"\n",
|
|
"CACHE_DIR = Path(\"/tmp/cache-custom-rnn\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Network and training parameters"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"input_seq_length = 36\n",
|
|
"output_seq_length = 36\n",
|
|
"\n",
|
|
"lr = 0.00005\n",
|
|
"num_epochs = 100\n",
|
|
"batch_size = 512\n",
|
|
"hidden_size = 32\n",
|
|
"num_gru_layers = 1\n",
|
|
"grad_clip = 1.0\n",
|
|
"scheduled_sampling_decay = 10\n",
|
|
"dropout = 0.\n",
|
|
"\n",
|
|
"# As opposed to point-wise (assumes Gaussian)\n",
|
|
"probabilistic = True\n",
|
|
"\n",
|
|
"use_attention = True"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('At epoch 0 teacher force prob will be 0.9090909090909091',\n",
|
|
" 'At epoch 100 teacher force prob will be 0.0005014951969411607')"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# decay is used to determine \"forced teacher\" or 'freerunning' in recurrent learning\n",
|
|
"# inverse sigmoid decay from https://arxiv.org/pdf/1506.03099.pdf\n",
|
|
"import math\n",
|
|
"\n",
|
|
"\n",
|
|
"def inverse_sigmoid_decay(decay):\n",
|
|
" def compute(indx):\n",
|
|
" return decay / (decay + math.exp(indx / decay))\n",
|
|
" return compute\n",
|
|
"calc_teacher_force_prob = inverse_sigmoid_decay(scheduled_sampling_decay)\n",
|
|
"\n",
|
|
"f'At epoch 0 teacher force prob will be {calc_teacher_force_prob(0)}', \\\n",
|
|
"f'At epoch {num_epochs} teacher force prob will be {calc_teacher_force_prob(num_epochs-1)}'\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Map encoding as used in Trajectron++."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(1440, 2560, 3)\n",
|
|
"(72, 128, 3)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.image.AxesImage at 0x7f0b8c9f3520>"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"homography_matrix = np.array([\n",
|
|
" [5, 0,0],\n",
|
|
" [0, 5,0],\n",
|
|
" [0,0,1],\n",
|
|
" ]) # 100 scale\n",
|
|
"img = cv2.imread(image_path)\n",
|
|
"print(img.shape)\n",
|
|
"img = cv2.resize(img, (img.shape[1]//20, img.shape[0]//20))\n",
|
|
"\n",
|
|
"print(img.shape)\n",
|
|
"imgmap = ImageMap(img, homography_matrix, \"hof3-undistorted-H-2\")\n",
|
|
"# img = cv2.imread(image_path)\n",
|
|
"\n",
|
|
"plt.imshow(img)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from trap.tracker import TrackReader\n",
|
|
"\n",
|
|
"\n",
|
|
"reader = TrackReader(path, camera.fps, exclude_whitelisted = False, include_blacklisted=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import List\n",
|
|
"from trap.frame_emitter import Track\n",
|
|
"from trap.tracker import FinalDisplacementFilter\n",
|
|
"\n",
|
|
"# \n",
|
|
"# make sure we have all points for all tracks\n",
|
|
"tracks: List[Track] = [t.get_with_interpolated_history() for t in reader]\n",
|
|
"# t = Smoother().smooth_track(t)\n",
|
|
"track_filter = FinalDisplacementFilter(2)\n",
|
|
"tracks = track_filter.apply(tracks, camera)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def flatten_multicolumn(df : pd.DataFrame):\n",
|
|
" \"Multicolumn index to a flattended index, e.g. velocity_x\"\n",
|
|
" df.columns = ['_'.join(col).strip() for col in df.columns.values]\n",
|
|
"\n",
|
|
"def preprocess_track(track: Track, camera: Camera) -> pd.DataFrame:\n",
|
|
" df = track.to_dataframe(camera)\n",
|
|
"\n",
|
|
" flatten_multicolumn(df)\n",
|
|
" \n",
|
|
" df['dx'] = df['velocity_x'] / track.fps\n",
|
|
" df['dy'] = df['velocity_y'] / track.fps\n",
|
|
" return df"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# track_dfs = [t.to_dataframe(camera) for t in tracks[:10]]\n",
|
|
"# def flatten_multicolumn(df : pd.DataFrame):\n",
|
|
"# \"Multicolumn index to a flattended index, e.g. velocity_x\"\n",
|
|
"# df.columns = ['_'.join(col).strip() for col in df.columns.values]\n",
|
|
"\n",
|
|
"# for df in track_dfs:\n",
|
|
"# flatten_multicolumn(df)\n",
|
|
" \n",
|
|
"# for df, track in zip(track_dfs, tracks):\n",
|
|
"# df['dx'] = df['velocity_x'] / track.fps\n",
|
|
"# df['dy'] = df['velocity_y'] / track.fps\n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# for track in tracks:\n",
|
|
"# history = track.get_projected_history(None, camera)\n",
|
|
"# points = imgmap.to_map_points(history)\n",
|
|
"# # print(history, points)\n",
|
|
"# break\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"target_indices = [0, 1]\n",
|
|
"in_fields = [ 'dx', 'dy', 'position_x', 'position_y', 'velocity_x', 'velocity_y', 'acceleration_x', 'acceleration_y'] #, 'dt'] (WARNING: dt column contains NaN)\n",
|
|
"# out_fields = ['v', 'heading']\n",
|
|
"# velocity cannot be negative, and heading is circular (modulo), this makes it harder to optimise than a linear space, so try to use components\n",
|
|
"# an we can use simple MSE loss (I guess?)\n",
|
|
"out_fields = ['dx', 'dy']\n",
|
|
"# SAMPLE_STEP = 5 # 1/5, for 12fps leads to effectively 12/5=2.4fps\n",
|
|
"# GRID_SIZE = 2 # round items on a grid of 2 points per meter (None to disable rounding)\n",
|
|
"# window = 8 #int(FPS*1.5 / SAMPLE_STEP)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"100"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Set device\n",
|
|
"# device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
|
"# print(device)\n",
|
|
"\n",
|
|
"# Hyperparameters\n",
|
|
"input_size = len(in_fields) #in_d\n",
|
|
"# hidden_size = 64 # hidden_d\n",
|
|
"# num_layers = 1 # num_hidden\n",
|
|
"output_size = len(out_fields) # out_d\n",
|
|
"# learning_rate = 0.005 #0.01 #0.005\n",
|
|
"# batch_size = 512\n",
|
|
"num_epochs\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"cache_path = Path(CACHE_DIR)\n",
|
|
"cache_path.mkdir(parents=True, exist_ok=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# from trap.tools import load_tracks_from_csv\n",
|
|
"from trap.tools import filter_short_tracks, normalise_position\n",
|
|
"\n",
|
|
"# data= load_tracks_from_csv(Path(SRC_CSV), FPS, GRID_SIZE, SAMPLE_STEP )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# create x-norm, y_norm columns\n",
|
|
"# data, mu, std = normalise_position(data)\n",
|
|
"# data = filter_short_tracks(data, window+1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Create splits, 80% of the tracks to testing."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1903 training tracks, 476 test tracks\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"np.random.shuffle(tracks)\n",
|
|
"test_offset_idx = int(len(tracks) * .8)\n",
|
|
"\n",
|
|
"training_tracks, test_tracks = tracks[:test_offset_idx], tracks[test_offset_idx:]\n",
|
|
"# print(len(training_tracks))\n",
|
|
"print(f\"{len(training_tracks)} training tracks, {len(test_tracks)} test tracks\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# # track_ids = data.index.unique('track_id').to_numpy()\n",
|
|
"# track_ids = reader.track_ids()\n",
|
|
"# np.random.shuffle(track_ids)\n",
|
|
"# test_offset_idx = int(len(track_ids) * .8)\n",
|
|
"# training_ids, test_ids = track_ids[:test_offset_idx], track_ids[test_offset_idx:]\n",
|
|
"# print(f\"{len(training_ids)} training tracks, {len(test_ids)} test tracks\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# import random\n",
|
|
"# def plot_track(track_id: int):\n",
|
|
"# ax = plt.scatter(\n",
|
|
"# data.loc[track_id,:]['x_norm'],\n",
|
|
"# data.loc[track_id,:]['y_norm'],\n",
|
|
"# marker=\"*\") \n",
|
|
"# plt.plot(\n",
|
|
"# data.loc[track_id,:]['x_norm'],\n",
|
|
"# data.loc[track_id,:]['y_norm']\n",
|
|
"# )\n",
|
|
"\n",
|
|
"# # print(filtered_data.loc[track_id,:]['proj_x'])\n",
|
|
"# # _track_id = 2188\n",
|
|
"# _track_id = random.choice(track_ids)\n",
|
|
"# print(_track_id)\n",
|
|
"# plot_track(_track_id)\n",
|
|
"\n",
|
|
"# for track_id in random.choices(track_ids, k=100):\n",
|
|
"# plot_track(track_id)\n",
|
|
" \n",
|
|
"# # print(mean_x, mean_y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now make the dataset:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# print(len(reader.get(\"24229\").history))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 0%| | 0/1903 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 1903/1903 [03:08<00:00, 10.09it/s]\n",
|
|
"100%|██████████| 476/476 [00:40<00:00, 11.71it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from copy import deepcopy\n",
|
|
"from pandas import DataFrame\n",
|
|
"from tqdm import tqdm\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"def create_dataset(tracks: list[Track], input_seq_length: int, output_seq_length: int, in_fields: list[str], out_fields: list[str], camera: Camera, only_last=False, device='cpu') -> dict[str, torch.Tensor]:\n",
|
|
" encoder_X, decoder_X, decoder_y, = [], [], []\n",
|
|
" # factor = SAMPLE_STEP if SAMPLE_STEP is not None else 1\n",
|
|
" for track in tqdm(tracks):\n",
|
|
" # track = reader.get(track_id)\n",
|
|
" if len(track.history) < 2:\n",
|
|
" print(track.track_id, \"too short\")\n",
|
|
" df: DataFrame = preprocess_track(track, camera)\n",
|
|
" # df = data.loc[track_id]\n",
|
|
" # print(df)\n",
|
|
" # start_frame = min(df.index.tolist())\n",
|
|
" for timestep in range(df.shape[0] - (input_seq_length + output_seq_length) + 1):\n",
|
|
"\n",
|
|
" # enc_inputs: (input seq len, num features)\n",
|
|
" # print(df[timestep:timestep+input_seq_length][['velocity_x', 'velocity_y']])\n",
|
|
" enc_inputs_at_t = deepcopy(df[timestep : timestep + input_seq_length][in_fields])\n",
|
|
" dec_at_t = deepcopy(df[timestep + input_seq_length - 1 : timestep + input_seq_length + output_seq_length])\n",
|
|
" # dec_targets: (output seq len, num features)\n",
|
|
" dec_inputs_at_t = deepcopy(dec_at_t[:-1][in_fields])\n",
|
|
" # dec_targets: (output seq len, num targets)\n",
|
|
" dec_targets_at_t = deepcopy(dec_at_t[1:][out_fields])\n",
|
|
" \n",
|
|
" # # for step in range(len(df)-window-1):\n",
|
|
" # i = int(start_frame) + (step*factor)\n",
|
|
" # # print(step, int(start_frame), i)\n",
|
|
" # feature = df.loc[i:i+(window*factor)][in_fields]\n",
|
|
" # # target = df.loc[i+1:i+window+1][out_fields]\n",
|
|
" # # print(i, window*factor, factor, i+window*factor+factor, df['idx_in_track'])\n",
|
|
" # # print(i+window*factor+factor)\n",
|
|
" # if only_last:\n",
|
|
" # target = df.loc[i+window*factor+factor][out_fields]\n",
|
|
" # else:\n",
|
|
" # target = df.loc[i+factor:i+window*factor+factor][out_fields]\n",
|
|
"\n",
|
|
" encoder_X.append(enc_inputs_at_t.values)\n",
|
|
" decoder_X.append(dec_inputs_at_t.values)\n",
|
|
" decoder_y.append(dec_targets_at_t.values)\n",
|
|
" \n",
|
|
" return {'enc_inputs': torch.tensor(np.array(encoder_X), device=device, dtype=torch.float), \n",
|
|
" 'dec_inputs': torch.tensor(np.array(decoder_X), device=device, dtype=torch.float), \n",
|
|
" 'dec_outputs': torch.tensor(np.array(decoder_y), device=device, dtype=torch.float)}\n",
|
|
"\n",
|
|
"train_data = create_dataset(training_tracks, input_seq_length, output_seq_length, in_fields, out_fields, camera, False, device)\n",
|
|
"test_data = create_dataset(test_tracks, input_seq_length, output_seq_length, in_fields, out_fields, camera, False, device)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# X_train, y_train = X_train.to(device=device), y_train.to(device=device)\n",
|
|
"# X_test, y_test = X_test.to(device=device), y_test.to(device=device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# from torch.utils.data import TensorDataset, DataLoader\n",
|
|
"# dataset_train = TensorDataset(X_train, y_train)\n",
|
|
"# loader_train = DataLoader(dataset_train, shuffle=True, batch_size=batch_size)\n",
|
|
"# dataset_test = TensorDataset(X_test, y_test)\n",
|
|
"# loader_test = DataLoader(dataset_test, shuffle=False, batch_size=batch_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Model give output for all timesteps, this should improve training. But we use only the last timestep for the prediction process"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Seq2seq encoder/decoder with attention"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Encoder(nn.Module):\n",
|
|
" def __init__(self, input_size, hidden_size, num_gru_layers, dropout_p=0.1):\n",
|
|
" super().__init__()\n",
|
|
" self.hidden_size = hidden_size\n",
|
|
"\n",
|
|
" # self.embedding = nn.Embedding(input_size, hidden_size)\n",
|
|
" self.gru = nn.GRU(input_size, hidden_size, num_gru_layers, batch_first=True)\n",
|
|
" # self.dropout = nn.Dropout(dropout_p) # TODO)) How to bring this back, see \n",
|
|
"\n",
|
|
" def forward(self, inputs):\n",
|
|
" # inputs: (batch size, input seq len, num enc features)\n",
|
|
" # embedded = self.dropout(self.embedding(input))\n",
|
|
"\n",
|
|
" # output: (batch size, input seq len, hidden size)\n",
|
|
" # hidden: (num gru layers, batch size, hidden size)\n",
|
|
" output, hidden = self.gru(inputs)\n",
|
|
" return output, hidden\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Decoder superclass whose forward is called by Seq2Seq but other methods filled out by subclasses\n",
|
|
"import random\n",
|
|
"\n",
|
|
"\n",
|
|
"class DecoderBase(nn.Module):\n",
|
|
" def __init__(self, device, dec_target_size, target_indices, dist_size, probabilistic):\n",
|
|
" super().__init__()\n",
|
|
" self.device = device\n",
|
|
" self.target_indices = target_indices # indices of desired output in the training input vector, used for force teaching\n",
|
|
" self.target_size = dec_target_size\n",
|
|
" self.dist_size = dist_size\n",
|
|
" self.probabilistic = probabilistic\n",
|
|
" \n",
|
|
" # Have to run one step at a time unlike with the encoder since sometimes not teacher forcing\n",
|
|
" def run_single_recurrent_step(self, inputs, hidden, enc_outputs):\n",
|
|
" raise NotImplementedError()\n",
|
|
" \n",
|
|
" def forward(self, inputs, hidden, enc_outputs, teacher_force_prob=None):\n",
|
|
" # inputs: (batch size, output seq length, num dec features)\n",
|
|
" # hidden: (num gru layers, batch size, hidden dim), ie the last hidden state\n",
|
|
" # enc_outputs: (batch size, input seq len, hidden size)\n",
|
|
" \n",
|
|
" batch_size, dec_output_seq_length, _ = inputs.shape\n",
|
|
" \n",
|
|
" # Store decoder outputs\n",
|
|
" # outputs: (batch size, output seq len, num targets, num dist params)\n",
|
|
" outputs = torch.zeros(batch_size, dec_output_seq_length, self.target_size, self.dist_size, dtype=torch.float).to(self.device)\n",
|
|
"\n",
|
|
" # curr_input: (batch size, 1, num dec features)\n",
|
|
" curr_input = inputs[:, 0:1, :]\n",
|
|
" \n",
|
|
" for t in range(dec_output_seq_length):\n",
|
|
" # dec_output: (batch size, 1, num targets, num dist params)\n",
|
|
" # hidden: (num gru layers, batch size, hidden size)\n",
|
|
" dec_output, hidden = self.run_single_recurrent_step(curr_input, hidden, enc_outputs)\n",
|
|
" # Save prediction\n",
|
|
" outputs[:, t:t+1, :, :] = dec_output\n",
|
|
" # dec_output: (batch size, 1, num targets)\n",
|
|
" dec_output = Seq2Seq.sample_from_output(dec_output)\n",
|
|
" \n",
|
|
" # If teacher forcing, use target from this timestep as next input o.w. use prediction\n",
|
|
" teacher_force: bool = random.random() < teacher_force_prob if teacher_force_prob is not None else False\n",
|
|
" \n",
|
|
" curr_input = inputs[:, t:t+1, :].clone()\n",
|
|
" if not teacher_force:\n",
|
|
" curr_input[:, :, self.target_indices] = dec_output\n",
|
|
" return outputs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def layer_init(layer, w_scale=1.0):\n",
|
|
" nn.init.kaiming_uniform_(layer.weight.data)\n",
|
|
" layer.weight.data.mul_(w_scale)\n",
|
|
" nn.init.constant_(layer.bias.data, 0.)\n",
|
|
" return layer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"class DecoderVanilla(DecoderBase):\n",
|
|
" def __init__(self, dec_feature_size, dec_target_size, hidden_size, \n",
|
|
" num_gru_layers, target_indices, dropout, dist_size,\n",
|
|
" probabilistic, device):\n",
|
|
" super().__init__(device, dec_target_size, target_indices, dist_size, probabilistic)\n",
|
|
" self.gru = nn.GRU(dec_feature_size, hidden_size, num_gru_layers, batch_first=True, dropout=dropout)\n",
|
|
" self.out = layer_init(nn.Linear(hidden_size + dec_feature_size, dec_target_size * dist_size))\n",
|
|
" \n",
|
|
" def run_single_recurrent_step(self, inputs, hidden, enc_outputs):\n",
|
|
" # inputs: (batch size, 1, num dec features)\n",
|
|
" # hidden: (num gru layers, batch size, hidden size)\n",
|
|
" \n",
|
|
" output, hidden = self.gru(inputs, hidden)\n",
|
|
" output = self.out(torch.cat((output, inputs), dim=2))\n",
|
|
" output = output.reshape(output.shape[0], output.shape[1], self.target_size, self.dist_size)\n",
|
|
" \n",
|
|
" # output: (batch size, 1, num targets, num dist params)\n",
|
|
" # hidden: (num gru layers, batch size, hidden size)\n",
|
|
" return output, hidden\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Attention(nn.Module):\n",
|
|
" def __init__(self, hidden_size, num_gru_layers):\n",
|
|
" super().__init__()\n",
|
|
" # NOTE: the hidden size for the output of attn (and input of v) can actually be any number \n",
|
|
" # Also, using two layers allows for a non-linear act func inbetween\n",
|
|
" self.attn = nn.Linear(2 * hidden_size, hidden_size)\n",
|
|
" self.v = nn.Linear(hidden_size, 1, bias=False)\n",
|
|
" \n",
|
|
" def forward(self, decoder_hidden_final_layer, encoder_outputs):\n",
|
|
" # decoder_hidden_final_layer: (batch size, hidden size)\n",
|
|
" # encoder_outputs: (batch size, input seq len, hidden size)\n",
|
|
" \n",
|
|
" # Repeat decoder hidden state input seq len times\n",
|
|
" hidden = decoder_hidden_final_layer.unsqueeze(1).repeat(1, encoder_outputs.shape[1], 1)\n",
|
|
" \n",
|
|
" # Compare decoder hidden state with each encoder output using a learnable tanh layer\n",
|
|
" energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))\n",
|
|
" \n",
|
|
" # Then compress into single values for each comparison (energy)\n",
|
|
" attention = self.v(energy).squeeze(2)\n",
|
|
" \n",
|
|
" # Then softmax so the weightings add up to 1\n",
|
|
" weightings = F.softmax(attention, dim=1)\n",
|
|
" \n",
|
|
" # weightings: (batch size, input seq len)\n",
|
|
" return weightings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class DecoderWithAttention(DecoderBase):\n",
|
|
" def __init__(self, dec_feature_size, dec_target_size, hidden_size, \n",
|
|
" num_gru_layers, target_indices, dropout, dist_size,\n",
|
|
" probabilistic, device):\n",
|
|
" super().__init__(device, dec_target_size, target_indices, dist_size, probabilistic)\n",
|
|
" self.attention_model = Attention(hidden_size, num_gru_layers)\n",
|
|
" # GRU takes previous timestep target and weighted sum of encoder hidden states\n",
|
|
" self.gru = nn.GRU(dec_feature_size + hidden_size, hidden_size, num_gru_layers, batch_first=True, dropout=dropout)\n",
|
|
" # Output layer takes decoder hidden state output, weighted sum and decoder input\n",
|
|
" # NOTE: Feeding decoder input into the output layer essentially acts as a skip connection\n",
|
|
" self.out = layer_init(nn.Linear(hidden_size + hidden_size + dec_feature_size, dec_target_size * dist_size))\n",
|
|
"\n",
|
|
" def run_single_recurrent_step(self, inputs, hidden, enc_outputs):\n",
|
|
" # inputs: (batch size, 1, num dec features)\n",
|
|
" # hidden: (num gru layers, batch size, hidden size)\n",
|
|
" # enc_outputs: (batch size, input seq len, hidden size)\n",
|
|
" \n",
|
|
" # Get attention weightings\n",
|
|
" # weightings: (batch size, input seq len)\n",
|
|
" weightings = self.attention_model(hidden[-1], enc_outputs)\n",
|
|
" \n",
|
|
" # Then compute weighted sum\n",
|
|
" # weighted_sum: (batch size, 1, hidden size)\n",
|
|
" weighted_sum = torch.bmm(weightings.unsqueeze(1), enc_outputs)\n",
|
|
" \n",
|
|
" # Then input into GRU\n",
|
|
" # gru inputs: (batch size, 1, num dec features + hidden size)\n",
|
|
" # output: (batch size, 1, hidden size)\n",
|
|
" output, hidden = self.gru(torch.cat((inputs, weighted_sum), dim=2), hidden)\n",
|
|
" \n",
|
|
" # Get prediction\n",
|
|
" # out input: (batch size, 1, hidden size + hidden size + num targets)\n",
|
|
" output = self.out(torch.cat((output, weighted_sum, inputs), dim=2))\n",
|
|
" output = output.reshape(output.shape[0], output.shape[1], self.target_size, self.dist_size)\n",
|
|
" \n",
|
|
" # output: (batch size, 1, num targets, num dist params)\n",
|
|
" # hidden: (num gru layers, batch size, hidden size)\n",
|
|
" return output, hidden\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Seq2Seq(nn.Module):\n",
|
|
" def __init__(self, encoder: Encoder, decoder: DecoderBase, lr, grad_clip, probabilistic):\n",
|
|
" super().__init__()\n",
|
|
" self.encoder = encoder\n",
|
|
" self.decoder = decoder\n",
|
|
"\n",
|
|
" self.opt = torch.optim.Adam(self.parameters(), lr)\n",
|
|
" self.loss_func = nn.GaussianNLLLoss() if probabilistic else nn.L1Loss()\n",
|
|
" self.grad_clip = grad_clip\n",
|
|
"\n",
|
|
" self.probabilistic = probabilistic\n",
|
|
" \n",
|
|
" @staticmethod\n",
|
|
" def compute_smape(prediction, target):\n",
|
|
" return torch.mean(torch.abs(prediction - target) / ((torch.abs(target) + torch.abs(prediction)) / 2. + 1e-8)) * 100.\n",
|
|
" \n",
|
|
" @staticmethod\n",
|
|
" def get_dist_params(output):\n",
|
|
" mu = output[:, :, :, 0]\n",
|
|
" # softplus to constrain to positive\n",
|
|
" sigma = F.softplus(output[:, :, :, 1])\n",
|
|
" return mu, sigma\n",
|
|
" \n",
|
|
" @staticmethod\n",
|
|
" def sample_from_output(output):\n",
|
|
" # in - output: (batch size, dec seq len, num targets, num dist params)\n",
|
|
" # out - output: (batch size, dec seq len, num targets)\n",
|
|
" if output.shape[-1] > 1: # probabilistic can be assumed\n",
|
|
" mu, sigma = Seq2Seq.get_dist_params(output)\n",
|
|
" # sigma = torch.tensor([0,0], device='cuda')\n",
|
|
" return torch.normal(mu, sigma)\n",
|
|
" # No sample just reshape if pointwise\n",
|
|
" return output.squeeze(-1)\n",
|
|
" \n",
|
|
" def forward(self, enc_inputs, dec_inputs, teacher_force_prob=None):\n",
|
|
" # enc_inputs: (batch size, input seq length, num enc features)\n",
|
|
" # dec_inputs: (batch size, output seq length, num dec features)\n",
|
|
" \n",
|
|
" # enc_outputs: (batch size, input seq len, hidden size)\n",
|
|
" # hidden: (num gru layers, batch size, hidden dim), ie the last hidden state\n",
|
|
" enc_outputs, hidden = self.encoder(enc_inputs)\n",
|
|
" \n",
|
|
" # outputs: (batch size, output seq len, num targets, num dist params)\n",
|
|
" outputs = self.decoder(dec_inputs, hidden, enc_outputs, teacher_force_prob)\n",
|
|
" \n",
|
|
" return outputs\n",
|
|
"\n",
|
|
" def compute_loss(self, prediction, target, override_func=None):\n",
|
|
" # prediction: (batch size, dec seq len, num targets, num dist params)\n",
|
|
" # target: (batch size, dec seq len, num targets)\n",
|
|
" if self.probabilistic:\n",
|
|
" mu, sigma = Seq2Seq.get_dist_params(prediction)\n",
|
|
" var = sigma ** 2\n",
|
|
" loss = self.loss_func(mu, target, var)\n",
|
|
" else:\n",
|
|
" loss = self.loss_func(prediction.squeeze(-1), target)\n",
|
|
" return loss if self.training else loss.item()\n",
|
|
" \n",
|
|
" def optimize(self, prediction, target):\n",
|
|
" # prediction & target: (batch size, seq len, output dim)\n",
|
|
" self.opt.zero_grad()\n",
|
|
" loss = self.compute_loss(prediction, target)\n",
|
|
" loss.backward()\n",
|
|
" if self.grad_clip is not None:\n",
|
|
" torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)\n",
|
|
" self.opt.step()\n",
|
|
" return loss.item()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## run training"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# New generator every epoch\n",
|
|
"def batch_generator(data, batch_size):\n",
|
|
" enc_inputs, dec_inputs, dec_targets, scalers = \\\n",
|
|
" data['enc_inputs'], data['dec_inputs'], data['dec_outputs'], None #, data['scalers']\n",
|
|
" indices = torch.randperm(enc_inputs.shape[0])\n",
|
|
" for i in range(0, len(indices), batch_size):\n",
|
|
" batch_indices = indices[i : i + batch_size]\n",
|
|
" batch_enc_inputs = enc_inputs[batch_indices]\n",
|
|
" batch_dec_inputs = dec_inputs[batch_indices]\n",
|
|
" batch_dec_targets = dec_targets[batch_indices]\n",
|
|
" batch_scalers = None\n",
|
|
" \n",
|
|
" # No remainder\n",
|
|
" if batch_enc_inputs.shape[0] < batch_size:\n",
|
|
" break\n",
|
|
" yield batch_enc_inputs, batch_dec_inputs, batch_dec_targets, batch_scalers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def train(model, train_data, batch_size, teacher_force_prob):\n",
|
|
" model.train()\n",
|
|
" \n",
|
|
" epoch_loss = 0.\n",
|
|
" num_batches = 0\n",
|
|
" \n",
|
|
" for batch_enc_inputs, batch_dec_inputs, batch_dec_targets, _ in batch_generator(train_data, batch_size):\n",
|
|
" output = model(batch_enc_inputs, batch_dec_inputs, teacher_force_prob)\n",
|
|
" loss = model.optimize(output, batch_dec_targets)\n",
|
|
" \n",
|
|
" epoch_loss += loss\n",
|
|
" num_batches += 1\n",
|
|
" return epoch_loss / num_batches"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def evaluate(model, val_data, batch_size):\n",
|
|
" model.eval()\n",
|
|
" \n",
|
|
" epoch_loss = 0.\n",
|
|
" num_batches = 0\n",
|
|
" \n",
|
|
" with torch.no_grad():\n",
|
|
" for batch_enc_inputs, batch_dec_inputs, batch_dec_targets, _ in batch_generator(val_data, batch_size):\n",
|
|
" output = model(batch_enc_inputs, batch_dec_inputs)\n",
|
|
" loss = model.compute_loss(output, batch_dec_targets)\n",
|
|
"\n",
|
|
" epoch_loss += loss\n",
|
|
" num_batches += 1\n",
|
|
" \n",
|
|
" return epoch_loss / num_batches\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dist_size = 2 if probabilistic else 1\n",
|
|
"enc_feature_size = train_data['enc_inputs'].shape[-1]\n",
|
|
"dec_feature_size = train_data['dec_inputs'].shape[-1]\n",
|
|
"dec_target_size = train_data['dec_outputs'].shape[-1]\n",
|
|
"\n",
|
|
"encoder = Encoder(enc_feature_size, hidden_size, num_gru_layers, dropout)\n",
|
|
"decoder_args = (dec_feature_size, dec_target_size, hidden_size, num_gru_layers, target_indices, dropout, dist_size, probabilistic, device)\n",
|
|
"decoder = DecoderWithAttention(*decoder_args) if use_attention else DecoderVanilla(*decoder_args)\n",
|
|
"seq2seq = Seq2Seq(encoder, decoder, lr, grad_clip, probabilistic).to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_cache(model: Seq2Seq, epoch=None, path=None):\n",
|
|
" if path is None:\n",
|
|
" if epoch is None:\n",
|
|
" raise RuntimeError(\"Either path or epoch must be given\")\n",
|
|
" path = cache_path / f\"checkpoint-{model.__class__.__name__}_{epoch:05d}.pt\"\n",
|
|
" else:\n",
|
|
" print (path.stem)\n",
|
|
" epoch = int(path.stem[-5:])\n",
|
|
"\n",
|
|
" cached = torch.load(path)\n",
|
|
" \n",
|
|
" # optimizer.load_state_dict(cached['optimizer_state_dict'])\n",
|
|
" model.load_state_dict(cached['model_state_dict'])\n",
|
|
" return epoch, cached['train_loss']\n",
|
|
" \n",
|
|
"\n",
|
|
"def cache(model: Seq2Seq, epoch, loss):\n",
|
|
" path = cache_path / f\"checkpoint-{model.__class__.__name__}_{epoch:05d}.pt\"\n",
|
|
" print(f\"Cache to {path}\")\n",
|
|
" torch.save({\n",
|
|
" 'epoch': epoch,\n",
|
|
" 'model_state_dict': model.state_dict(),\n",
|
|
" 'train_loss': loss,\n",
|
|
" }, path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"load /tmp/cache-custom-rnn/checkpoint-Seq2Seq_00100.pt\n",
|
|
"checkpoint-Seq2Seq_00100\n",
|
|
"Loaded epoch=100 with train_loss=-3.9480560053240117\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import time\n",
|
|
"\n",
|
|
"CACHE_FILE = Path(\"/tmp/cache-custom-rnn/checkpoint-Seq2Seq_00100.pt\")\n",
|
|
"if CACHE_FILE.exists():\n",
|
|
" print(f\"load {CACHE_FILE}\")\n",
|
|
" best_model = deepcopy(seq2seq)\n",
|
|
" epoch, train_loss = load_cache(best_model, path=CACHE_FILE)\n",
|
|
" print(f\"Loaded {epoch=} with {train_loss=}\")\n",
|
|
"else:\n",
|
|
"\n",
|
|
" val_data = test_data\n",
|
|
"\n",
|
|
" best_val, best_model = float('inf'), None\n",
|
|
" for epoch in range(num_epochs):\n",
|
|
" start_t = time.time()\n",
|
|
" teacher_force_prob = calc_teacher_force_prob(epoch)\n",
|
|
" train_loss = train(seq2seq, train_data, batch_size, teacher_force_prob)\n",
|
|
" val_loss = evaluate(seq2seq, val_data, batch_size)\n",
|
|
"\n",
|
|
" new_best_val = False\n",
|
|
" if val_loss < best_val:\n",
|
|
" new_best_val = True\n",
|
|
" best_val = val_loss\n",
|
|
" best_model = deepcopy(seq2seq)\n",
|
|
" print(f'Epoch {epoch+1} => Train loss: {train_loss:.5f},',\n",
|
|
" f'Val: {val_loss:.5f},',\n",
|
|
" f'Teach: {teacher_force_prob:.2f},',\n",
|
|
" f'Took {(time.time() - start_t):.1f} s{\" (NEW BEST)\" if new_best_val else \"\"}')\n",
|
|
" \n",
|
|
" \n",
|
|
" cache(best_model, 100, train_loss)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 49,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"data_to_eval = test_data\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/ruben/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/rnn.py:950: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at ../aten/src/ATen/native/cudnn/RNN.cpp:968.)\n",
|
|
" result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([1, 50, 36, 2])\n",
|
|
"[36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 1000x500 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Visualize\n",
|
|
"target_to_vis = 0\n",
|
|
"num_vis = 1\n",
|
|
"num_rollouts = 50 if probabilistic else 1\n",
|
|
"\n",
|
|
"best_model.eval()\n",
|
|
"\n",
|
|
"with torch.no_grad():\n",
|
|
" \n",
|
|
" batch_enc_inputs, batch_dec_inputs, batch_dec_targets, scalers = next(batch_generator(data_to_eval, num_vis))\n",
|
|
" \n",
|
|
" outputs = []\n",
|
|
" for r in range(num_rollouts):\n",
|
|
" mo = best_model(batch_enc_inputs, batch_dec_inputs)\n",
|
|
" output = Seq2Seq.sample_from_output(mo)\n",
|
|
" outputs.append(output)\n",
|
|
" outputs = torch.stack(outputs, dim=1)\n",
|
|
" print(outputs.shape)\n",
|
|
"\n",
|
|
"for indx in range(batch_enc_inputs.shape[0]):\n",
|
|
" # scaler = scalers[indx]\n",
|
|
" sample_enc_inputs, sample_dec_inputs, sample_dec_targets = \\\n",
|
|
" (batch_enc_inputs[indx])[:, target_to_vis].cpu().numpy().tolist(),\\\n",
|
|
" (batch_dec_inputs[indx])[:, target_to_vis].cpu().numpy().tolist(), \\\n",
|
|
" (batch_dec_targets[indx])[:, target_to_vis].cpu().numpy().tolist()\n",
|
|
" output_rollouts = []\n",
|
|
" for output_rollout in outputs[indx]:\n",
|
|
" output_rollouts.append((output_rollout)[:, target_to_vis].cpu().numpy().tolist())\n",
|
|
" output_rollouts = np.array(output_rollouts)\n",
|
|
"\n",
|
|
" plt.figure(figsize=(10,5))\n",
|
|
" x = list(range(len(sample_enc_inputs) + len(sample_dec_targets)))\n",
|
|
" # Plot inputs\n",
|
|
" plt.plot(x, sample_enc_inputs + sample_dec_targets)\n",
|
|
" # Plot median\n",
|
|
" output_x = list(range(len(sample_enc_inputs), len(x)))\n",
|
|
" print(output_x)\n",
|
|
" plt.plot(output_x, np.median(output_rollouts, axis=0))\n",
|
|
" # Plot quantiles\n",
|
|
" plt.fill_between(\n",
|
|
" output_x,\n",
|
|
" np.quantile(output_rollouts, 0.05, axis=0), \n",
|
|
" np.quantile(output_rollouts, 0.95, axis=0), \n",
|
|
" alpha=0.3, \n",
|
|
" interpolate=True\n",
|
|
" )\n",
|
|
" plt.gca().set_axis_off()\n",
|
|
" plt.show()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## RNN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class SimpleRnn(nn.Module):\n",
|
|
" def __init__(self, in_d=2, out_d=2, hidden_d=4, num_hidden=1):\n",
|
|
" super(SimpleRnn, self).__init__()\n",
|
|
" self.rnn = nn.RNN(input_size=in_d, hidden_size=hidden_d, num_layers=num_hidden)\n",
|
|
" self.fc = nn.Linear(hidden_d, out_d)\n",
|
|
"\n",
|
|
" def forward(self, x, h0):\n",
|
|
" r, h = self.rnn(x, h0)\n",
|
|
" # r = r[:, -1,:]\n",
|
|
" y = self.fc(r) # no activation on the output\n",
|
|
" return y, h\n",
|
|
"rnn = SimpleRnn(input_size, output_size, hidden_size, num_layers).to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## LSTM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"For optional LSTM-GAN, see https://discuss.pytorch.org/t/how-to-use-lstm-to-construct-gan/12419\n",
|
|
"\n",
|
|
"Or a VAE (variational autoencoder):\n",
|
|
"\n",
|
|
"> The only constraint on the latent vector representation for traditional autoencoders is that latent vectors should be easily decodable back into the original image. As a result, the latent space $Z$ can become disjoint and non-continuous. Variational autoencoders try to solve this problem. [Alexander van de Kleut](https://avandekleut.github.io/vae/)\n",
|
|
"\n",
|
|
"For LSTM based generative VAE: https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n",
|
|
"\n",
|
|
"http://web.archive.org/web/20210119121802/https://towardsdatascience.com/time-series-generation-with-vae-lstm-5a6426365a1c?gi=29d8b029a386\n",
|
|
"\n",
|
|
"https://youtu.be/qJeaCHQ1k2w?si=30aAdqqwvz0DpR-x&t=687 VAEs generate the mu and sigma of a normal distribution. Thus, they don't map the input to a single point, but to a Gaussian distribution."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 328,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LSTMModel(nn.Module):\n",
|
|
" # input_size : number of features in input at each time step\n",
|
|
" # hidden_size : Number of LSTM units \n",
|
|
" # num_layers : number of LSTM layers \n",
|
|
" def __init__(self, input_size, hidden_size, num_layers): \n",
|
|
" super(LSTMModel, self).__init__() #initializes the parent class nn.Module\n",
|
|
" # We _could_ train the h0: https://discuss.pytorch.org/t/learn-initial-hidden-state-h0-for-rnn/10013 \n",
|
|
" # self.lin1 = nn.Linear(input_size, hidden_size)\n",
|
|
" self.num_layers = num_layers\n",
|
|
" self.hidden_size = hidden_size\n",
|
|
" self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n",
|
|
" self.linear = nn.Linear(hidden_size, output_size)\n",
|
|
" # self.activation_v = nn.LeakyReLU(.01)\n",
|
|
" # self.activation_heading = torch.remainder()\n",
|
|
"\n",
|
|
" \n",
|
|
" def get_hidden_state(self, batch_size, device):\n",
|
|
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
|
|
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
|
|
" return (h, c)\n",
|
|
"\n",
|
|
" def forward(self, x, hidden_state): # defines forward pass of the neural network\n",
|
|
" # out = self.lin1(x)\n",
|
|
" \n",
|
|
" out, hidden_state = self.lstm(x, hidden_state)\n",
|
|
" # extract only the last time step, see https://machinelearningmastery.com/lstm-for-time-series-prediction-in-pytorch/\n",
|
|
" # print(out.shape)\n",
|
|
" # TODO)) Might want to remove this below: as it might improve training\n",
|
|
" # out = out[:, -1,:]\n",
|
|
" # print(out.shape)\n",
|
|
" out = self.linear(out)\n",
|
|
" \n",
|
|
" # torch.remainder(out[1], 360)\n",
|
|
" # print('o',out.shape)\n",
|
|
" return out, hidden_state\n",
|
|
"\n",
|
|
"lstm = LSTMModel(input_size, hidden_size, num_layers).to(device)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 329,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# model = rnn\n",
|
|
"model = lstm\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 330,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
|
"loss_fn = nn.MSELoss()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 331,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def evaluate():\n",
|
|
" # toggle evaluation mode\n",
|
|
" model.eval()\n",
|
|
" with torch.no_grad():\n",
|
|
" batch_size, seq_len, feature_dim = X_train.shape\n",
|
|
" y_pred, _ = model(\n",
|
|
" X_train.to(device=device),\n",
|
|
" model.get_hidden_state(batch_size, device)\n",
|
|
" )\n",
|
|
" train_rmse = torch.sqrt(loss_fn(y_pred, y_train))\n",
|
|
" # print(y_pred)\n",
|
|
"\n",
|
|
" batch_size, seq_len, feature_dim = X_test.shape\n",
|
|
" y_pred, _ = model(\n",
|
|
" X_test.to(device=device),\n",
|
|
" model.get_hidden_state(batch_size, device)\n",
|
|
" )\n",
|
|
" # print(loss_fn(y_pred, y_test))\n",
|
|
" test_rmse = torch.sqrt(loss_fn(y_pred, y_test))\n",
|
|
" print(\"Epoch ??: train RMSE %.4f, test RMSE %.4f\" % ( train_rmse, test_rmse))\n",
|
|
"\n",
|
|
"def load_most_recent():\n",
|
|
" paths = list(cache_path.glob(f\"checkpoint-{model._get_name()}_*.pt\"))\n",
|
|
" if len(paths) < 1:\n",
|
|
" print('Nothing found to load')\n",
|
|
" return None, None\n",
|
|
" paths.sort()\n",
|
|
"\n",
|
|
" print(f\"Loading {paths[-1]}\")\n",
|
|
" return load_cache(path=paths[-1])\n",
|
|
"\n",
|
|
"def load_cache(epoch=None, path=None):\n",
|
|
" if path is None:\n",
|
|
" if epoch is None:\n",
|
|
" raise RuntimeError(\"Either path or epoch must be given\")\n",
|
|
" path = cache_path / f\"checkpoint-{model._get_name()}_{epoch:05d}.pt\"\n",
|
|
" else:\n",
|
|
" print (path.stem)\n",
|
|
" epoch = int(path.stem[-5:])\n",
|
|
"\n",
|
|
" cached = torch.load(path)\n",
|
|
" \n",
|
|
" optimizer.load_state_dict(cached['optimizer_state_dict'])\n",
|
|
" model.load_state_dict(cached['model_state_dict'])\n",
|
|
" return epoch, cached['loss']\n",
|
|
" \n",
|
|
"\n",
|
|
"def cache(epoch, loss):\n",
|
|
" path = cache_path / f\"checkpoint-{model._get_name()}_{epoch:05d}.pt\"\n",
|
|
" print(f\"Cache to {path}\")\n",
|
|
" torch.save({\n",
|
|
" 'epoch': epoch,\n",
|
|
" 'model_state_dict': model.state_dict(),\n",
|
|
" 'optimizer_state_dict': optimizer.state_dict(),\n",
|
|
" 'loss': loss,\n",
|
|
" }, path)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"TODO)) See [this notebook](https://www.cs.toronto.edu/~lczhang/aps360_20191/lec/w08/rnn.html) for initialization (random or not) and the use of a GRU"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 332,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading EXPERIMENTS/cache/hof2/checkpoint-LSTMModel_01000.pt\n",
|
|
"checkpoint-LSTMModel_01000\n",
|
|
"starting from epoch 1000 (loss: 0.014368701726198196)\n",
|
|
"Epoch ??: train RMSE 0.0849, test RMSE 0.0866\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0it [00:00, ?it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch ??: train RMSE 0.0849, test RMSE 0.0866\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"start_epoch, loss = load_most_recent()\n",
|
|
"if start_epoch is None:\n",
|
|
" start_epoch = 0\n",
|
|
"else:\n",
|
|
" print(f\"starting from epoch {start_epoch} (loss: {loss})\")\n",
|
|
" evaluate()\n",
|
|
"\n",
|
|
"loss_log = []\n",
|
|
"# Train Network\n",
|
|
"for epoch in tqdm(range(start_epoch+1,num_epochs+1)):\n",
|
|
" # toggle train mode\n",
|
|
" model.train()\n",
|
|
" for batch_idx, (x, targets) in enumerate(loader_train):\n",
|
|
" # Get x to cuda if possible\n",
|
|
" x = x.to(device=device).squeeze(1)\n",
|
|
" targets = targets.to(device=device)\n",
|
|
"\n",
|
|
" # forward\n",
|
|
" scores, _ = model(\n",
|
|
" x,\n",
|
|
" model.get_hidden_state(x.shape[0], device)\n",
|
|
" )\n",
|
|
" # print(scores)\n",
|
|
" loss = loss_fn(scores, targets)\n",
|
|
"\n",
|
|
" # backward\n",
|
|
" optimizer.zero_grad()\n",
|
|
" loss.backward()\n",
|
|
"\n",
|
|
" # gradient descent update step/adam step\n",
|
|
" optimizer.step()\n",
|
|
"\n",
|
|
" loss_log.append(loss.item())\n",
|
|
"\n",
|
|
" if epoch % 5 != 0:\n",
|
|
" continue\n",
|
|
"\n",
|
|
" cache(epoch, loss)\n",
|
|
" evaluate()\n",
|
|
"\n",
|
|
"evaluate()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 333,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# print(loss)\n",
|
|
"# print(len(loss_log))\n",
|
|
"# plt.plot(loss_log)\n",
|
|
"# plt.ylabel('Loss')\n",
|
|
"# plt.xlabel('iteration')\n",
|
|
"# plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 335,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([49999, 9, 2]) torch.Size([49999, 9, 2])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.eval()\n",
|
|
"\n",
|
|
"with torch.no_grad():\n",
|
|
" y_pred, _ = model(X_train.to(device=device),\n",
|
|
" model.get_hidden_state(X_train.shape[0], device))\n",
|
|
" \n",
|
|
" print(y_pred.shape, y_train.shape)\n",
|
|
"# y_train, y_pred"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 336,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import scipy\n",
|
|
"\n",
|
|
"def ceil_away_from_0(a):\n",
|
|
" return np.sign(a) * np.ceil(np.abs(a))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 343,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def predict_and_plot(model, feature, steps = 50):\n",
|
|
" lenght = feature.shape[0]\n",
|
|
"\n",
|
|
" dt = (1/ FPS) * SAMPLE_STEP\n",
|
|
"\n",
|
|
" trajectory = feature\n",
|
|
"\n",
|
|
" # feature = filtered_data.loc[_track_id,:].iloc[:5][in_fields].values\n",
|
|
" # nxt = filtered_data.loc[_track_id,:].iloc[5][out_fields]\n",
|
|
" with torch.no_grad():\n",
|
|
" # h = torch.zeros(num_layers, window+1, hidden_size, dtype=torch.float).to(device=device)\n",
|
|
" # c = torch.zeros(num_layers, window+1, hidden_size, dtype=torch.float).to(device=device)\n",
|
|
" h = torch.zeros(num_layers, 1, hidden_size, dtype=torch.float).to(device=device)\n",
|
|
" c = torch.zeros(num_layers, 1, hidden_size, dtype=torch.float).to(device=device)\n",
|
|
" hidden_state = (h, c)\n",
|
|
" # X = torch.tensor([feature], dtype=torch.float).to(device)\n",
|
|
" # y, (h, c) = model(X, h, c)\n",
|
|
" for i in range(steps):\n",
|
|
" # predict_f = scipy.ndimage.uniform_filter(feature)\n",
|
|
" # predict_f = scipy.interpolate.splrep(feature[:][0], feature[:][1],)\n",
|
|
" # predict_f = scipy.signal.spline_feature(feature, lmbda=.1)\n",
|
|
" # batch size of one, so wrap feature as a single item in a list\n",
|
|
" # print(X.shape)\n",
|
|
" X = torch.tensor([feature], dtype=torch.float).to(device)\n",
|
|
" # print(type(model))\n",
|
|
" y, hidden_state, *_ = model(X, hidden_state)\n",
|
|
" # print(hidden_state.shape)\n",
|
|
"\n",
|
|
" s = y[-1][-1].cpu()\n",
|
|
"\n",
|
|
" # proj_x proj_y v heading a d_heading\n",
|
|
" # next_step = feature\n",
|
|
"\n",
|
|
" dx, dy = s\n",
|
|
" \n",
|
|
" dx = (dx * GRID_SIZE).round() / GRID_SIZE\n",
|
|
" dy = (dy * GRID_SIZE).round() / GRID_SIZE\n",
|
|
" vx, vy = dx / dt, dy / dt\n",
|
|
"\n",
|
|
" v = np.sqrt(s[0]**2 + s[1]**2)\n",
|
|
" heading = (np.arctan2(s[1], s[0]) * 180 / np.pi) % 360\n",
|
|
" # a = (v - feature[-1][2]) / dt\n",
|
|
" ax = (vx - feature[-1][2]) / dt\n",
|
|
" ay = (vy - feature[-1][3]) / dt\n",
|
|
" # d_heading = (heading - feature[-1][5])\n",
|
|
" # print(s)\n",
|
|
" # ['x', 'y', 'vx', 'vy', 'ax', 'ay'] \n",
|
|
" x = feature[-1][0] + dx\n",
|
|
" y = feature[-1][1] + dy\n",
|
|
" if GRID_SIZE is not None:\n",
|
|
" # put points back on grid\n",
|
|
" x = (x*GRID_SIZE).round() / GRID_SIZE\n",
|
|
" y = (y*GRID_SIZE).round() / GRID_SIZE\n",
|
|
"\n",
|
|
" feature = [[x, y, vx, vy, ax, ay]]\n",
|
|
" \n",
|
|
" trajectory = np.append(trajectory, feature, axis=0)\n",
|
|
" # f = [feature[-1][0] + s[0]*dt, feature[-1][1] + s[1]*dt, v, heading, a, d_heading ]\n",
|
|
" # feature = np.append(feature, [feature], axis=0)\n",
|
|
" \n",
|
|
" # print(next_step, nxt)\n",
|
|
" # print(trajectory)\n",
|
|
" plt.plot(trajectory[:lenght,0], trajectory[:lenght,1], c='orange')\n",
|
|
" plt.plot(trajectory[lenght-1:,0], trajectory[lenght-1:,1], c='red')\n",
|
|
" plt.scatter(trajectory[lenght:,0], trajectory[lenght:,1], c='red', marker='x')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1301\n",
|
|
"(10, 6) (10, 6)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"# print(filtered_data.loc[track_id,:]['proj_x'])\n",
|
|
"_track_id = 8701 # random.choice(track_ids)\n",
|
|
"_track_id = 3880 # random.choice(track_ids)\n",
|
|
"\n",
|
|
"# _track_id = 2780\n",
|
|
"\n",
|
|
"for batch_idx in range(100):\n",
|
|
" _track_id = random.choice(track_ids)\n",
|
|
" plt.plot(\n",
|
|
" data.loc[_track_id,:]['x'],\n",
|
|
" data.loc[_track_id,:]['y'],\n",
|
|
" c='grey', alpha=.2\n",
|
|
" )\n",
|
|
"\n",
|
|
"_track_id = random.choice(track_ids)\n",
|
|
"# _track_id = 1096\n",
|
|
"_track_id = 1301\n",
|
|
"print(_track_id)\n",
|
|
"ax = plt.scatter(\n",
|
|
" data.loc[_track_id,:]['x'],\n",
|
|
" data.loc[_track_id,:]['y'],\n",
|
|
" marker=\"*\") \n",
|
|
"plt.plot(\n",
|
|
" data.loc[_track_id,:]['x'],\n",
|
|
" data.loc[_track_id,:]['y']\n",
|
|
")\n",
|
|
"\n",
|
|
"X = data.loc[_track_id,:].iloc[:][in_fields].values\n",
|
|
"# Adding randomness might be a cheat to get multiple features from current position\n",
|
|
"rnd = np.random.random_sample(X.shape) / 10\n",
|
|
"# print(rnd)\n",
|
|
"\n",
|
|
"print(X[:10].shape, (X[:10] + rnd[:10]).shape)\n",
|
|
"\n",
|
|
"# predict_and_plot(data.loc[_track_id,:].iloc[:5][in_fields].values)\n",
|
|
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:5][in_fields].values, 50)\n",
|
|
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:10][in_fields].values, 50)\n",
|
|
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:20][in_fields].values)\n",
|
|
"# predict_and_plot(model, data.loc[_track_id,:].iloc[:30][in_fields].values)\n",
|
|
"predict_and_plot(model, X[:12])\n",
|
|
"predict_and_plot(model, X[:12] + rnd[:12])\n",
|
|
"predict_and_plot(model, data.loc[_track_id,:].iloc[:][in_fields].values)\n",
|
|
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n",
|
|
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## VAE"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# From https://github.com/CUN-bjy/lstm-vae-torch/blob/main/src/models.py (MIT)\n",
|
|
"\n",
|
|
"from typing import Optional\n",
|
|
"\n",
|
|
"\n",
|
|
"class Encoder(nn.Module):\n",
|
|
" def __init__(self, input_size=4096, hidden_size=1024, num_layers=2):\n",
|
|
" super(Encoder, self).__init__()\n",
|
|
" self.hidden_size = hidden_size\n",
|
|
" self.num_layers = num_layers\n",
|
|
" self.lstm = nn.LSTM(\n",
|
|
" input_size,\n",
|
|
" hidden_size,\n",
|
|
" num_layers,\n",
|
|
" batch_first=True,\n",
|
|
" bidirectional=False,\n",
|
|
" )\n",
|
|
"\n",
|
|
" def get_hidden_state(self, batch_size, device):\n",
|
|
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
|
|
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
|
|
" return (h, c)\n",
|
|
"\n",
|
|
" def forward(self, x, hidden_state):\n",
|
|
" # x: tensor of shape (batch_size, seq_length, input_size)\n",
|
|
" outputs, (hidden, cell) = self.lstm(x, hidden_state)\n",
|
|
" return outputs, (hidden, cell)\n",
|
|
"\n",
|
|
"\n",
|
|
"class Decoder(nn.Module):\n",
|
|
" def __init__(\n",
|
|
" self, input_size=4096, hidden_size=1024, output_size=4096, num_layers=2\n",
|
|
" ):\n",
|
|
" super(Decoder, self).__init__()\n",
|
|
" self.hidden_size = hidden_size\n",
|
|
" self.output_size = output_size\n",
|
|
" self.num_layers = num_layers\n",
|
|
" self.lstm = nn.LSTM(\n",
|
|
" input_size,\n",
|
|
" hidden_size,\n",
|
|
" num_layers,\n",
|
|
" batch_first=True,\n",
|
|
" bidirectional=False,\n",
|
|
" )\n",
|
|
" self.fc = nn.Linear(hidden_size, output_size)\n",
|
|
" \n",
|
|
" def get_hidden_state(self, batch_size, device):\n",
|
|
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
|
|
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)\n",
|
|
" return (h, c)\n",
|
|
"\n",
|
|
" def forward(self, x, hidden):\n",
|
|
" # x: tensor of shape (batch_size, seq_length, input_size)\n",
|
|
" output, (hidden, cell) = self.lstm(x, hidden)\n",
|
|
" prediction = self.fc(output)\n",
|
|
" return prediction, (hidden, cell)\n",
|
|
"\n",
|
|
"\n",
|
|
"class LSTMVAE(nn.Module):\n",
|
|
" \"\"\"LSTM-based Variational Auto Encoder\"\"\"\n",
|
|
"\n",
|
|
" def __init__(\n",
|
|
" self, input_size, output_size, hidden_size, latent_size, device=torch.device(\"cuda\")\n",
|
|
" ):\n",
|
|
" \"\"\"\n",
|
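|
" input_size: int, number of features per time step; x is batch_size x sequence_length x input_size\n",
|
|
" output_size: int, number of features the decoder emits per time step\n",
|
|
" hidden_size: int, output size of LSTM AE\n",
|
|
" latent_size: int, latent z-layer size\n",
|
|
" (the number of LSTM layers is fixed at 1)\n",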
|
|
" \"\"\"\n",
|
|
" super(LSTMVAE, self).__init__()\n",
|
|
" self.device = device\n",
|
|
"\n",
|
|
" # dimensions\n",
|
|
" self.input_size = input_size\n",
|
|
" self.hidden_size = hidden_size\n",
|
|
" self.latent_size = latent_size\n",
|
|
" self.num_layers = 1\n",
|
|
"\n",
|
|
" # lstm ae\n",
|
|
" self.lstm_enc = Encoder(\n",
|
|
" input_size=input_size, hidden_size=hidden_size, num_layers=self.num_layers\n",
|
|
" )\n",
|
|
" self.lstm_dec = Decoder(\n",
|
|
" input_size=latent_size,\n",
|
|
" output_size=output_size,\n",
|
|
" hidden_size=hidden_size,\n",
|
|
" num_layers=self.num_layers,\n",
|
|
" )\n",
|
|
"\n",
|
|
" self.fc21 = nn.Linear(self.hidden_size, self.latent_size)\n",
|
|
" self.fc22 = nn.Linear(self.hidden_size, self.latent_size)\n",
|
|
" self.fc3 = nn.Linear(self.latent_size, self.hidden_size)\n",
|
|
"\n",
|
|
" def reparametize(self, mu, logvar):\n",
|
|
" std = torch.exp(0.5 * logvar)\n",
|
|
" noise = torch.randn_like(std).to(self.device)\n",
|
|
"\n",
|
|
" z = mu + noise * std\n",
|
|
" return z\n",
|
|
"\n",
|
|
" def forward(self, x, hidden_state_encoder: Optional[tuple]=None):\n",
|
|
" batch_size, seq_len, feature_dim = x.shape\n",
|
|
"\n",
|
|
" if hidden_state_encoder is None:\n",
|
|
" hidden_state_encoder = self.lstm_enc.get_hidden_state(batch_size, self.device)\n",
|
|
"\n",
|
|
" # encode input space to hidden space\n",
|
|
" prediction, hidden_state = self.lstm_enc(x, hidden_state_encoder)\n",
|
|
" enc_h = hidden_state[0].view(batch_size, self.hidden_size).to(self.device)\n",
|
|
"\n",
|
|
" # extract latent variable z(hidden space to latent space)\n",
|
|
" mean = self.fc21(enc_h)\n",
|
|
" logvar = self.fc22(enc_h)\n",
|
|
" z = self.reparametize(mean, logvar) # batch_size x latent_size\n",
|
|
"\n",
|
|
" # initialize hidden state as inputs\n",
|
|
" h_ = self.fc3(z)\n",
|
|
" \n",
|
|
" # decode latent space to input space\n",
|
|
" z = z.unsqueeze(1).repeat(1, seq_len, 1)\n",
|
|
" z = z.view(batch_size, seq_len, self.latent_size).to(self.device)\n",
|
|
"\n",
|
|
" # initialize hidden state\n",
|
|
" hidden = (h_.contiguous(), h_.contiguous())\n",
|
|
" # TODO)) the above is not the right dimensions, but this changes architecture\n",
|
|
" hidden = self.lstm_dec.get_hidden_state(batch_size, self.device)\n",
|
|
" reconstruct_output, hidden = self.lstm_dec(z, hidden)\n",
|
|
"\n",
|
|
" x_hat = reconstruct_output\n",
|
|
" \n",
|
|
" return x_hat, hidden_state, mean, logvar\n",
|
|
"\n",
|
|
" def loss_function(self, *args, **kwargs) -> dict:\n",
|
|
" r\"\"\"\n",
|
|
" Computes the VAE loss function.\n",
|
|
" KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n",
|
|
" :param args:\n",
|
|
" :param kwargs:\n",
|
|
" :return:\n",
|
|
" \"\"\"\n",
|
|
" predicted = args[0]\n",
|
|
" target = args[1]\n",
|
|
" mu = args[2]\n",
|
|
" log_var = args[3]\n",
|
|
"\n",
|
|
" kld_weight = 0.00025 # Account for the minibatch samples from the dataset\n",
|
|
" recons_loss = torch.nn.functional.mse_loss(predicted, target=target)\n",
|
|
"\n",
|
|
" kld_loss = torch.mean(\n",
|
|
" -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0\n",
|
|
" )\n",
|
|
"\n",
|
|
" loss = recons_loss + kld_weight * kld_loss\n",
|
|
" return {\n",
|
|
" \"loss\": loss,\n",
|
|
" \"Reconstruction_Loss\": recons_loss.detach(),\n",
|
|
" \"KLD\": -kld_loss.detach(),\n",
|
|
" }\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 303,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"vae = LSTMVAE(input_size, output_size, hidden_size, 1024, device=device)\n",
|
|
"vae.to(device)\n",
|
|
"vae_optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 304,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def train_vae(model, train_loader, test_loader, max_iter, learning_rate):\n",
|
|
" # optimizer\n",
|
|
" optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
|
"\n",
|
|
" ## iteration setup\n",
|
|
" epochs = tqdm(range(max_iter // len(train_loader) + 1))\n",
|
|
" epochs = tqdm(range(max_iter))\n",
|
|
"\n",
|
|
" ## training\n",
|
|
" count = 0\n",
|
|
" for epoch in epochs:\n",
|
|
" model.train()\n",
|
|
" optimizer.zero_grad()\n",
|
|
" train_iterator = tqdm(\n",
|
|
" enumerate(train_loader), total=len(train_loader), desc=\"training\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" # if count > max_iter:\n",
|
|
" # print(\"max!\")\n",
|
|
" # return model\n",
|
|
"\n",
|
|
" for batch_idx, (x, targets) in train_iterator:\n",
|
|
" count += 1\n",
|
|
"\n",
|
|
" # future_data, past_data = batch_data\n",
|
|
"\n",
|
|
" ## reshape\n",
|
|
" batch_size = x.size(0)\n",
|
|
" example_size = x.size(1)\n",
|
|
" # image_size = past_data.size(2), past_data.size(3)\n",
|
|
" # past_data = (\n",
|
|
" # past_data.view(batch_size, example_size, -1).float().to(args.device) # flattens image, we don't need this\n",
|
|
" # )\n",
|
|
" # future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\n",
|
|
"\n",
|
|
" y, hidden_state, mean, logvar = model(x)\n",
|
|
"\n",
|
|
" # calculate vae loss\n",
|
|
" # print(y.shape, targets.shape)\n",
|
|
" losses = model.loss_function(y, targets, mean, logvar)\n",
|
|
" mloss, recon_loss, kld_loss = (\n",
|
|
" losses[\"loss\"],\n",
|
|
" losses[\"Reconstruction_Loss\"],\n",
|
|
" losses[\"KLD\"],\n",
|
|
" )\n",
|
|
"\n",
|
|
" # Backward and optimize\n",
|
|
" optimizer.zero_grad()\n",
|
|
" mloss.mean().backward()\n",
|
|
" optimizer.step()\n",
|
|
"\n",
|
|
" train_iterator.set_postfix({\"train_loss\": float(mloss.mean())})\n",
|
|
" print(\"train_loss\", float(mloss.mean()), epoch)\n",
|
|
"\n",
|
|
" model.eval()\n",
|
|
" eval_loss = 0\n",
|
|
" test_iterator = tqdm(\n",
|
|
" enumerate(test_loader), total=len(test_loader), desc=\"testing\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" with torch.no_grad():\n",
|
|
" for batch_idx, (x, targets) in test_iterator:\n",
|
|
" # future_data, past_data = batch_data\n",
|
|
"\n",
|
|
" ## reshape\n",
|
|
" batch_size = x.size(0)\n",
|
|
" example_size = x.size(1)\n",
|
|
" # past_data = (\n",
|
|
" # past_data.view(batch_size, example_size, -1).float().to(args.device)\n",
|
|
" # )\n",
|
|
" # future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\n",
|
|
"\n",
|
|
" y, hidden_state, mean, logvar = model(x)\n",
|
|
"\n",
|
|
" # calculate vae loss\n",
|
|
" losses = model.loss_function(y, targets, mean, logvar)\n",
|
|
" mloss, recon_loss, kld_loss = (\n",
|
|
" losses[\"loss\"],\n",
|
|
" losses[\"Reconstruction_Loss\"],\n",
|
|
" losses[\"KLD\"],\n",
|
|
" )\n",
|
|
"\n",
|
|
" eval_loss += mloss.mean().item()\n",
|
|
"\n",
|
|
" test_iterator.set_postfix({\"eval_loss\": float(mloss.mean())})\n",
|
|
"\n",
|
|
" # if batch_idx == 0:\n",
|
|
" # nhw_orig = past_data[0].view(example_size, image_size[0], -1)\n",
|
|
" # nhw_recon = recon_x[0].view(example_size, image_size[0], -1)\n",
|
|
" # imshow(nhw_orig.cpu(), f\"orig{epoch}\")\n",
|
|
" # imshow(nhw_recon.cpu(), f\"recon{epoch}\")\n",
|
|
" # writer.add_images(f\"original{i}\", nchw_orig, epoch)\n",
|
|
" # writer.add_images(f\"reconstructed{i}\", nchw_recon, epoch)\n",
|
|
"\n",
|
|
" eval_loss = eval_loss / len(test_loader)\n",
|
|
" # writer.add_scalar(\"eval_loss\", float(eval_loss), epoch)\n",
|
|
" print(\"Evaluation Score : [{}]\".format(eval_loss))\n",
|
|
"\n",
|
|
" print(\"Done :-)\")\n",
|
|
" return model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 305,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 0%| | 0/11 [00:00<?, ?it/s]\n",
|
|
"training: 0%| | 0/98 [00:00<?, ?it/s]\n",
|
|
" 0%| | 0/1000 [00:00<?, ?it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "RuntimeError",
|
|
"evalue": "For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[305], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrain_vae\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloader_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloader_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
|
"Cell \u001b[0;32mIn[304], line 36\u001b[0m, in \u001b[0;36mtrain_vae\u001b[0;34m(model, train_loader, test_loader, max_iter, learning_rate)\u001b[0m\n\u001b[1;32m 29\u001b[0m example_size \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# image_size = past_data.size(2), past_data.size(3)\u001b[39;00m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;66;03m# past_data = (\u001b[39;00m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# past_data.view(batch_size, example_size, -1).float().to(args.device) # flattens image, we don't need this\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m# future_data = future_data.view(batch_size, example_size, -1).float().to(args.device)\u001b[39;00m\n\u001b[0;32m---> 36\u001b[0m y, hidden_state, mean, logvar \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# calculate vae loss\u001b[39;00m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# print(y.shape, targets.shape)\u001b[39;00m\n\u001b[1;32m 40\u001b[0m losses \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mloss_function(y, targets, mean, logvar)\n",
|
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
|
"Cell \u001b[0;32mIn[302], line 128\u001b[0m, in \u001b[0;36mLSTMVAE.forward\u001b[0;34m(self, x, hidden_state_encoder)\u001b[0m\n\u001b[1;32m 125\u001b[0m hidden \u001b[38;5;241m=\u001b[39m (h_\u001b[38;5;241m.\u001b[39mcontiguous(), h_\u001b[38;5;241m.\u001b[39mcontiguous())\n\u001b[1;32m 126\u001b[0m \u001b[38;5;66;03m# TODO)) the above is not the right dimensions, but this changes architecture\u001b[39;00m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;66;03m# hidden = self.lstm_dec.get_hidden_state(batch_size, self.device)\u001b[39;00m\n\u001b[0;32m--> 128\u001b[0m reconstruct_output, hidden \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm_dec\u001b[49m\u001b[43m(\u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 130\u001b[0m x_hat \u001b[38;5;241m=\u001b[39m reconstruct_output\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x_hat, hidden_state, mean, logvar\n",
|
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
|
"Cell \u001b[0;32mIn[302], line 54\u001b[0m, in \u001b[0;36mDecoder.forward\u001b[0;34m(self, x, hidden)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, hidden):\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# x: tensor of shape (batch_size, seq_length, hidden_size)\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m output, (hidden, cell) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 55\u001b[0m prediction \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc(output)\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m prediction, (hidden, cell)\n",
|
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/rnn.py:755\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 752\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (hx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m hx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m3\u001b[39m):\n\u001b[1;32m 753\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFor batched 3-D input, hx and cx should \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 754\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124malso be 3-D but got (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-D, \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-D) tensors\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 755\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg)\n\u001b[1;32m 756\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 757\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hx[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m hx[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n",
|
|
"\u001b[0;31mRuntimeError\u001b[0m: For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"train_vae(vae, loader_train, loader_test, num_epochs, learning_rate*10)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 300,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1301\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# VAE Predict and plot\n",
|
|
"_track_id = 8701 # random.choice(track_ids)\n",
|
|
"_track_id = 3880 # random.choice(track_ids)\n",
|
|
"\n",
|
|
"# _track_id = 2780\n",
|
|
"\n",
|
|
"for batch_idx in range(100):\n",
|
|
" _track_id = random.choice(track_ids)\n",
|
|
" plt.plot(\n",
|
|
" data.loc[_track_id,:]['x'],\n",
|
|
" data.loc[_track_id,:]['y'],\n",
|
|
" c='grey', alpha=.2\n",
|
|
" )\n",
|
|
"\n",
|
|
"_track_id = random.choice(track_ids)\n",
|
|
"# _track_id = 1096\n",
|
|
"_track_id = 1301\n",
|
|
"print(_track_id)\n",
|
|
"ax = plt.scatter(\n",
|
|
" data.loc[_track_id,:]['x'],\n",
|
|
" data.loc[_track_id,:]['y'],\n",
|
|
" marker=\"*\") \n",
|
|
"plt.plot(\n",
|
|
" data.loc[_track_id,:]['x'],\n",
|
|
" data.loc[_track_id,:]['y']\n",
|
|
")\n",
|
|
"\n",
|
|
"# predict_and_plot(data.loc[_track_id,:].iloc[:5][in_fields].values)\n",
|
|
"predict_and_plot(vae, data.loc[_track_id,:].iloc[:5][in_fields].values, 50)\n",
|
|
"# predict_and_plot(vae, data.loc[_track_id,:].iloc[:10][in_fields].values, 50)\n",
|
|
"# predict_and_plot(vae, data.loc[_track_id,:].iloc[:20][in_fields].values)\n",
|
|
"# predict_and_plot(vae, data.loc[_track_id,:].iloc[:30][in_fields].values)\n",
|
|
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:70][in_fields].values)\n",
|
|
"# predict_and_plot(filtered_data.loc[_track_id,:].iloc[:115][in_fields].values)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "AttributeError",
|
|
"evalue": "'LSTM_VAE' object has no attribute 'embed_size'",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[103], line 268\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# Statistics.\u001b[39;00m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;66;03m# if batch_num % 20 ==0:\u001b[39;00m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;66;03m# print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\u001b[39;00m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m test_losses\n\u001b[0;32m--> 268\u001b[0m vae \u001b[38;5;241m=\u001b[39m \u001b[43mLSTM_VAE\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m16\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 270\u001b[0m vae_loss \u001b[38;5;241m=\u001b[39m VAE_Loss()\n\u001b[1;32m 271\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39moptim\u001b[38;5;241m.\u001b[39mAdam(vae\u001b[38;5;241m.\u001b[39mparameters(), lr\u001b[38;5;241m=\u001b[39m learning_rate)\n",
|
|
"Cell \u001b[0;32mIn[103], line 73\u001b[0m, in \u001b[0;36mLSTM_VAE.__init__\u001b[0;34m(self, input_size, output_size, hidden_size, latent_size, num_layers, device)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;66;03m# Decoder Part\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minit_hidden_decoder \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLinear(in_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlatent_size, out_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlstm_factor)\n\u001b[0;32m---> 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoder_lstm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLSTM(input_size\u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_size\u001b[49m, hidden_size\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size, batch_first \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m, num_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers)\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mLinear(in_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlstm_factor, out_features\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_size)\n",
|
|
"File \u001b[0;32m~/suspicion/trap/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1207\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1207\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1208\u001b[0m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n",
|
|
"\u001b[0;31mAttributeError\u001b[0m: 'LSTM_VAE' object has no attribute 'embed_size'"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# import torch\n",
|
|
"\n",
|
|
"class VAE_Loss(torch.nn.Module):\n",
|
|
" \"\"\"\n",
|
|
" Adapted from https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n",
|
|
" \"\"\"\n",
|
|
" def __init__(self):\n",
|
|
" super(VAE_Loss, self).__init__()\n",
|
|
" self.nlloss = torch.nn.NLLLoss()\n",
|
|
" \n",
|
|
" def KL_loss (self, mu, log_var, z):\n",
|
|
" kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n",
|
|
" kl = kl.sum(-1) # to go from multi-dimensional z to single dimensional z : (batch_size x latent_size) ---> (batch_size) \n",
|
|
" # i.e Z = [ [z1_1, z1_2 , ...., z1_lt] ] ------> z = [ z1] \n",
|
|
" # [ [z2_1, z2_2, ....., z2_lt] ] [ z2]\n",
|
|
" # . [ . ]\n",
|
|
" # . [ . ]\n",
|
|
" # [[zn_1, zn_2, ....., zn_lt] ] [ zn]\n",
|
|
" \n",
|
|
" # lt=latent_size \n",
|
|
" kl = kl.mean()\n",
|
|
" \n",
|
|
" return kl\n",
|
|
"\n",
|
|
" def reconstruction_loss(self, x_hat_param, x):\n",
|
|
"\n",
|
|
" x = x.view(-1).contiguous()\n",
|
|
" x_hat_param = x_hat_param.view(-1, x_hat_param.size(2))\n",
|
|
"\n",
|
|
" recon = self.nlloss(x_hat_param, x)\n",
|
|
"\n",
|
|
" return recon\n",
|
|
" \n",
|
|
"\n",
|
|
" def forward(self, mu, log_var,z, x_hat_param, x):\n",
|
|
" kl_loss = self.KL_loss(mu, log_var, z)\n",
|
|
" recon_loss = self.reconstruction_loss(x_hat_param, x)\n",
|
|
"\n",
|
|
"\n",
|
|
" elbo = kl_loss + recon_loss # we use + because recon loss is a NLLoss (cross entropy) and it's negative in its own, and in the ELBO equation we have\n",
|
|
" # elbo = KL_loss - recon_loss, therefore, ELBO = KL_loss - (NLLoss) = KL_loss + NLLoss\n",
|
|
"\n",
|
|
" return elbo, kl_loss, recon_loss\n",
|
|
" \n",
|
|
"class LSTM_VAE(torch.nn.Module):\n",
|
|
" \"\"\"\n",
|
|
" Adapted from https://github.com/Khamies/LSTM-Variational-AutoEncoder/blob/main/model.py\n",
|
|
" \"\"\"\n",
|
|
" def __init__(self, input_size, output_size, hidden_size, latent_size, num_layers=1, device=\"cuda\"):\n",
|
|
" super(LSTM_VAE, self).__init__()\n",
|
|
"\n",
|
|
" self.device = device\n",
|
|
" \n",
|
|
" # Variables\n",
|
|
" self.num_layers = num_layers\n",
|
|
" self.lstm_factor = num_layers\n",
|
|
" self.input_size = input_size\n",
|
|
" self.hidden_size = hidden_size\n",
|
|
" self.latent_size = latent_size\n",
|
|
" self.output_size = output_size\n",
|
|
"\n",
|
|
" # X: bsz * seq_len * vocab_size \n",
|
|
" # X: bsz * seq_len * embed_size\n",
|
|
"\n",
|
|
" # Encoder Part\n",
|
|
" self.encoder_lstm = torch.nn.LSTM(input_size= input_size,hidden_size= self.hidden_size, batch_first=True, num_layers= self.num_layers)\n",
|
|
" self.mean = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)\n",
|
|
" self.log_variance = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)\n",
|
|
"\n",
|
|
" # Decoder Part\n",
|
|
" \n",
|
|
" self.hidden_decoder_linear = torch.nn.Linear(in_features= self.latent_size, out_features= self.hidden_size * self.lstm_factor)\n",
|
|
" self.decoder_lstm = torch.nn.LSTM(input_size= self.embed_size, hidden_size= self.hidden_size, batch_first = True, num_layers = self.num_layers)\n",
|
|
" self.output = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.output_size)\n",
|
|
" # self.log_softmax = torch.nn.LogSoftmax(dim=2)\n",
|
|
"\n",
|
|
" def get_hidden_state(self, batch_size):\n",
|
|
" h = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device)\n",
|
|
" c = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device)\n",
|
|
" return (h, c)\n",
|
|
"\n",
|
|
"\n",
|
|
" def encoder(self, x, hidden_state):\n",
|
|
"\n",
|
|
" # pad the packed input.\n",
|
|
"\n",
|
|
" out, (h,c) = self.encoder_lstm(x, hidden_state)\n",
|
|
" # output_encoder, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_encoder, batch_first=True, total_length= total_padding_length)\n",
|
|
"\n",
|
|
" # Extimate the mean and the variance of q(z|x)\n",
|
|
" mean = self.mean(h)\n",
|
|
" log_var = self.log_variance(h)\n",
|
|
" std = torch.exp(0.5 * log_var) # e^(0.5 log_var) = var^0.5\n",
|
|
" \n",
|
|
" # Generate a unit gaussian noise.\n",
|
|
" # batch_size = output_encoder.size(0)\n",
|
|
" # seq_len = output_encoder.size(1)\n",
|
|
" # noise = torch.randn(batch_size, self.latent_size).to(self.device)\n",
|
|
" noise = torch.randn(self.latent_size).to(self.device)\n",
|
|
" \n",
|
|
" z = noise * std + mean\n",
|
|
"\n",
|
|
" return z, mean, log_var, (h,c)\n",
|
|
"\n",
|
|
"\n",
|
|
" def decoder(self, z, x):\n",
|
|
"\n",
|
|
" hidden_decoder = self.hidden_decoder_linear(z)\n",
|
|
" hidden_decoder = (hidden_decoder, hidden_decoder)\n",
|
|
"\n",
|
|
" # pad the packed input.\n",
|
|
" packed_output_decoder, hidden_decoder = self.decoder_lstm(packed_x_embed,hidden_decoder) \n",
|
|
" output_decoder, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_decoder, batch_first=True, total_length= total_padding_length)\n",
|
|
"\n",
|
|
"\n",
|
|
" x_hat = self.output(output_decoder)\n",
|
|
" \n",
|
|
" # x_hat = self.log_softmax(x_hat)\n",
|
|
"\n",
|
|
"\n",
|
|
" return x_hat\n",
|
|
"\n",
|
|
" \n",
|
|
"\n",
|
|
" def forward(self, x, hidden_state):\n",
|
|
" \n",
|
|
" \"\"\"\n",
|
|
" x : bsz * seq_len\n",
|
|
" \n",
|
|
" hidden_encoder: ( num_lstm_layers * bsz * hidden_size, num_lstm_layers * bsz * hidden_size)\n",
|
|
"\n",
|
|
" \"\"\"\n",
|
|
" # Get Embeddings\n",
|
|
" # x_embed, maximum_padding_length = self.get_embedding(x)\n",
|
|
"\n",
|
|
" # Packing the input\n",
|
|
" # packed_x_embed = torch.nn.utils.rnn.pack_padded_sequence(input= x_embed, lengths= sentences_length, batch_first=True, enforce_sorted=False)\n",
|
|
"\n",
|
|
"\n",
|
|
" # Encoder\n",
|
|
" z, mean, log_var, hidden_encoder = self.encoder(x, maximum_padding_length, hidden_encoder)\n",
|
|
"\n",
|
|
" # Decoder\n",
|
|
" x_hat = self.decoder(z, packed_x_embed, maximum_padding_length)\n",
|
|
" \n",
|
|
" return x_hat, mean, log_var, z, hidden_encoder\n",
|
|
"\n",
|
|
" \n",
|
|
"\n",
|
|
" def inference(self, n_samples, x, z):\n",
|
|
"\n",
|
|
" # generate random z \n",
|
|
" batch_size = 1\n",
|
|
" seq_len = 1\n",
|
|
" idx_sample = []\n",
|
|
"\n",
|
|
"\n",
|
|
" hidden = self.hidden_decoder_linear(z)\n",
|
|
" hidden = (hidden, hidden)\n",
|
|
" \n",
|
|
" for i in range(n_samples):\n",
|
|
" \n",
|
|
" output,hidden = self.decoder_lstm(x, hidden)\n",
|
|
" output = self.output(output)\n",
|
|
" # output = self.log_softmax(output)\n",
|
|
" # output = output.exp()\n",
|
|
" _, s = torch.topk(output, 1)\n",
|
|
" idx_sample.append(s.item())\n",
|
|
" x = s.squeeze(0)\n",
|
|
"\n",
|
|
" w_sample = [self.dictionary.get_i2w()[str(idx)] for idx in idx_sample]\n",
|
|
" w_sample = \" \".join(w_sample)\n",
|
|
"\n",
|
|
" return w_sample\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_batch(batch):\n",
|
|
" sentences = batch[\"input\"]\n",
|
|
" target = batch[\"target\"]\n",
|
|
" sentences_length = batch[\"length\"]\n",
|
|
"\n",
|
|
" return sentences, target, sentences_length\n",
|
|
"\n",
|
|
"class Trainer:\n",
|
|
"\n",
|
|
" def __init__(self, train_loader, test_loader, model, loss, optimizer) -> None:\n",
|
|
" self.train_loader = train_loader\n",
|
|
" self.test_loader = test_loader\n",
|
|
" self.model = model\n",
|
|
" self.loss = loss\n",
|
|
" self.optimizer = optimizer\n",
|
|
" self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
|
" self.interval = 200\n",
|
|
"\n",
|
|
"\n",
|
|
" def train(self, train_losses, epoch, batch_size, clip) -> list: \n",
|
|
" # Initialization of RNN hidden, and cell states.\n",
|
|
" states = self.model.init_hidden(batch_size) \n",
|
|
"\n",
|
|
" for batch_idx, (x, targets) in enumerate(self.train_loader):\n",
|
|
" # Get x to cuda if possible\n",
|
|
" x = x.to(device=device).squeeze(1)\n",
|
|
" targets = targets.to(device=device)\n",
|
|
"\n",
|
|
" # for batch_num, batch in enumerate(self.train_loader): # loop over the data, and jump with step = bptt.\n",
|
|
" # get the labels\n",
|
|
" source, target, source_lengths = get_batch(batch)\n",
|
|
" source = source.to(self.device)\n",
|
|
" target = target.to(self.device)\n",
|
|
"\n",
|
|
"\n",
|
|
" x_hat_param, mu, log_var, z, states = self.model(source,source_lengths, states)\n",
|
|
"\n",
|
|
" # detach hidden states\n",
|
|
" states = states[0].detach(), states[1].detach()\n",
|
|
"\n",
|
|
" # compute the loss\n",
|
|
" mloss, KL_loss, recon_loss = self.loss(mu = mu, log_var = log_var, z = z, x_hat_param = x_hat_param , x = target)\n",
|
|
"\n",
|
|
" train_losses.append((mloss , KL_loss.item(), recon_loss.item()))\n",
|
|
"\n",
|
|
" mloss.backward()\n",
|
|
"\n",
|
|
" torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)\n",
|
|
"\n",
|
|
" self.optimizer.step()\n",
|
|
"\n",
|
|
" self.optimizer.zero_grad()\n",
|
|
"\n",
|
|
"\n",
|
|
" if batch_num % self.interval == 0 and batch_num > 0:\n",
|
|
" \n",
|
|
" print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\n",
|
|
" epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\n",
|
|
"\n",
|
|
" return train_losses\n",
|
|
"\n",
|
|
" def test(self, test_losses, epoch, batch_size) -> list:\n",
|
|
"\n",
|
|
" with torch.no_grad():\n",
|
|
"\n",
|
|
" states = self.model.init_hidden(batch_size) \n",
|
|
"\n",
|
|
" for batch_num, batch in enumerate(self.test_loader): # loop over the data, and jump with step = bptt.\n",
|
|
" # get the labels\n",
|
|
" source, target, source_lengths = get_batch(batch)\n",
|
|
" source = source.to(self.device)\n",
|
|
" target = target.to(self.device)\n",
|
|
"\n",
|
|
"\n",
|
|
" x_hat_param, mu, log_var, z, states = self.model(source,source_lengths, states)\n",
|
|
"\n",
|
|
" # detach hidden states\n",
|
|
" states = states[0].detach(), states[1].detach()\n",
|
|
"\n",
|
|
" # compute the loss\n",
|
|
" mloss, KL_loss, recon_loss = self.loss(mu = mu, log_var = log_var, z = z, x_hat_param = x_hat_param , x = target)\n",
|
|
"\n",
|
|
" test_losses.append((mloss , KL_loss.item(), recon_loss.item()))\n",
|
|
"\n",
|
|
" # Statistics.\n",
|
|
" # if batch_num % 20 ==0:\n",
|
|
" # print('| epoch {:3d} | elbo_loss {:5.6f} | kl_loss {:5.6f} | recons_loss {:5.6f} '.format(\n",
|
|
" # epoch, mloss.item(), KL_loss.item(), recon_loss.item()))\n",
|
|
"\n",
|
|
" return test_losses\n",
|
|
"\n",
|
|
"vae = LSTM_VAE(input_size, output_size, hidden_size, 16, 1, device).to(device)\n",
|
|
"\n",
|
|
"vae_loss = VAE_Loss()\n",
|
|
"optimizer = torch.optim.Adam(vae.parameters(), lr= learning_rate)\n",
|
|
"\n",
|
|
"trainer = Trainer(loader_train, loader_test, vae, vae_loss, optimizer)\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|