Refactor Trajectron++ to work as module and expose training command
This commit is contained in:
parent
4fa3ce95ee
commit
51d6157af9
5 changed files with 574 additions and 393 deletions
931
poetry.lock
generated
931
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,9 +1,17 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "Trajectron-plus-plus"
|
name = "Trajectron-plus-plus"
|
||||||
version = "0.1.0"
|
version = "0.1.1"
|
||||||
description = "Predict trajectories for anomaly detection"
|
description = "This repository contains the code for Trajectron++: Dynamically-Feasible Trajectory Forecasting With Heterogeneous Data by Tim Salzmann*, Boris Ivanovic*, Punarjay Chakravarty, and Marco Pavone (* denotes equal contribution)."
|
||||||
authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
|
authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
license = "MIT"
|
||||||
|
|
||||||
|
packages = [
|
||||||
|
{ include = "trajectron" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.poetry.scripts]
|
||||||
|
trajectron_train = "trajectron.train:main"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9,<3.12"
|
python = "^3.9,<3.12"
|
||||||
|
@ -28,8 +36,6 @@ matplotlib = "^3.5"
|
||||||
#scikit-learn = "0.22.1"
|
#scikit-learn = "0.22.1"
|
||||||
#filterpy = "^1.4.5"
|
#filterpy = "^1.4.5"
|
||||||
tqdm = "^4.65.0"
|
tqdm = "^4.65.0"
|
||||||
#ipywidgets = "^8.0.6"
|
|
||||||
#deep-sort-realtime = "^1.3.2"
|
|
||||||
scipy = "^1.11.3"
|
scipy = "^1.11.3"
|
||||||
pandas = "^2.1.1"
|
pandas = "^2.1.1"
|
||||||
orjson = "^3.9.7"
|
orjson = "^3.9.7"
|
||||||
|
@ -41,7 +47,7 @@ notebook = "^7.0.4"
|
||||||
scikit-learn = "^1.3.1"
|
scikit-learn = "^1.3.1"
|
||||||
seaborn = "^0.13.0"
|
seaborn = "^0.13.0"
|
||||||
setuptools = "^68.2.2"
|
setuptools = "^68.2.2"
|
||||||
tensorboard = "1.14"
|
tensorboard = "^2.11"
|
||||||
tensorboardx = "^2.6.2.2"
|
tensorboardx = "^2.6.2.2"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--conf",
|
parser.add_argument("--conf",
|
||||||
help="path to json config file for hyperparameters",
|
help="path to json config file for hyperparameters",
|
||||||
type=str,
|
type=str,
|
||||||
default='../config/config.json')
|
default='./config/config.json')
|
||||||
|
|
||||||
parser.add_argument("--debug",
|
parser.add_argument("--debug",
|
||||||
help="disable all disk writing processes.",
|
help="disable all disk writing processes.",
|
||||||
|
@ -97,7 +97,7 @@ parser.add_argument('--no_edge_encoding',
|
||||||
parser.add_argument("--data_dir",
|
parser.add_argument("--data_dir",
|
||||||
help="what dir to look in for data",
|
help="what dir to look in for data",
|
||||||
type=str,
|
type=str,
|
||||||
default='../experiments/processed')
|
default='./experiments/processed')
|
||||||
|
|
||||||
parser.add_argument("--train_data_dict",
|
parser.add_argument("--train_data_dict",
|
||||||
help="what file to load for training data",
|
help="what file to load for training data",
|
||||||
|
|
|
@ -3,7 +3,7 @@ from scipy.interpolate import RectBivariateSpline
|
||||||
from scipy.ndimage import binary_dilation
|
from scipy.ndimage import binary_dilation
|
||||||
from scipy.stats import gaussian_kde
|
from scipy.stats import gaussian_kde
|
||||||
from trajectron.utils import prediction_output_to_trajectories
|
from trajectron.utils import prediction_output_to_trajectories
|
||||||
import trajectron.visualization
|
import trajectron.visualization as visualization
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,14 +9,14 @@ import random
|
||||||
import pathlib
|
import pathlib
|
||||||
import warnings
|
import warnings
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import visualization
|
import trajectron.visualization as visualization
|
||||||
import evaluation
|
import trajectron.evaluation as evaluation
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from argument_parser import args
|
from trajectron.argument_parser import args
|
||||||
from model.trajectron import Trajectron
|
from trajectron.model.trajectron import Trajectron
|
||||||
from model.model_registrar import ModelRegistrar
|
from trajectron.model.model_registrar import ModelRegistrar
|
||||||
from model.model_utils import cyclical_lr
|
from trajectron.model.model_utils import cyclical_lr
|
||||||
from model.dataset import EnvironmentDataset, collate
|
from trajectron.model.dataset import EnvironmentDataset, collate
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
# torch.autograd.set_detect_anomaly(True)
|
# torch.autograd.set_detect_anomaly(True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue