Change imports to support usage as module

Note that this does require rerunning the `process_data.py` scripts in the example folders so that the dill-files are updated.
This commit is contained in:
Ruben van de Ven 2023-10-09 20:27:29 +02:00
parent f4d860907c
commit dfa1d43f2e
13 changed files with 28 additions and 28 deletions

View file

@ -7,11 +7,11 @@ import torch
import numpy as np
import pandas as pd
sys.path.append("../../trajectron")
sys.path.append("../../")
from tqdm import tqdm
from model.model_registrar import ModelRegistrar
from model.trajectron import Trajectron
import evaluation
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.model.trajectron import Trajectron
import trajectron.evaluation as evaluation
seed = 0
np.random.seed(seed)

View file

@ -4,10 +4,10 @@ import numpy as np
import pandas as pd
import dill
sys.path.append("../../trajectron")
from environment import Environment, Scene, Node
from utils import maybe_makedirs
from environment import derivative_of
sys.path.append("../../")
from trajectron.environment import Environment, Scene, Node
from trajectron.utils import maybe_makedirs
from trajectron.environment import derivative_of
desired_max_time = 100
pred_indices = [2, 3]

View file

@ -1,6 +1,6 @@
import torch
import numpy as np
from model.dataset.homography_warper import get_rotation_matrix2d, warp_affine_crop
from trajectron.model.dataset.homography_warper import get_rotation_matrix2d, warp_affine_crop
class Map(object):

View file

@ -1,7 +1,7 @@
import random
import numpy as np
import pandas as pd
from environment import DoubleHeaderNumpyArray
from trajectron.environment import DoubleHeaderNumpyArray
from ncls import NCLS

View file

@ -2,8 +2,8 @@ import numpy as np
from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import binary_dilation
from scipy.stats import gaussian_kde
from utils import prediction_output_to_trajectories
import visualization
from trajectron.utils import prediction_output_to_trajectories
import trajectron.visualization
from matplotlib import pyplot as plt

View file

@ -1,7 +1,7 @@
import torch
import torch.distributions as td
import numpy as np
from model.model_utils import ModeKeys
from trajectron.model.model_utils import ModeKeys
class DiscreteLatent(object):

View file

@ -1,7 +1,7 @@
import torch
import torch.distributions as td
import numpy as np
from model.model_utils import to_one_hot
from trajectron.model.model_utils import to_one_hot
class GMM2D(td.Distribution):

View file

@ -1,4 +1,4 @@
from model.dynamics import Dynamic
from trajectron.model.dynamics import Dynamic
class Linear(Dynamic):

View file

@ -1,7 +1,7 @@
import torch
from model.dynamics import Dynamic
from utils import block_diag
from model.components import GMM2D
from trajectron.model.dynamics import Dynamic
from trajectron.utils import block_diag
from trajectron.model.components import GMM2D
class SingleIntegrator(Dynamic):

View file

@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from model.dynamics import Dynamic
from utils import block_diag
from model.components import GMM2D
from trajectron.model.dynamics import Dynamic
from trajectron.utils import block_diag
from trajectron.model.components import GMM2D
class Unicycle(Dynamic):

View file

@ -2,10 +2,10 @@ import warnings
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from model.components import *
from model.model_utils import *
import model.dynamics as dynamic_module
from environment.scene_graph import DirectedEdge
from trajectron.model.components import *
from trajectron.model.model_utils import *
import trajectron.model.dynamics as dynamic_module
from trajectron.environment.scene_graph import DirectedEdge
class MultimodalGenerativeCVAE(object):

View file

@ -1,7 +1,7 @@
import torch
import numpy as np
from model.mgcvae import MultimodalGenerativeCVAE
from model.dataset import get_timesteps_data, restore
from trajectron.model.mgcvae import MultimodalGenerativeCVAE
from trajectron.model.dataset import get_timesteps_data, restore
class Trajectron(object):

View file

@ -1,4 +1,4 @@
from utils import prediction_output_to_trajectories
from trajectron.utils import prediction_output_to_trajectories
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib.patches as patches