From 4fa3ce95ee6b6ec2ec616133f81bd18cd9322de7 Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Mon, 9 Oct 2023 21:06:36 +0200 Subject: [PATCH] Trajectron as a module while supporting old models --- trajectron/test_online.py => test_online.py | 11 +++++------ trajectron/model/model_registrar.py | 16 ++++++++++++++-- trajectron/model/online/online_mgcvae.py | 14 +++++++------- trajectron/model/online/online_trajectron.py | 8 ++++---- 4 files changed, 30 insertions(+), 19 deletions(-) rename trajectron/test_online.py => test_online.py (97%) diff --git a/trajectron/test_online.py b/test_online.py similarity index 97% rename from trajectron/test_online.py rename to test_online.py index 3e6cae7..caebdba 100644 --- a/trajectron/test_online.py +++ b/test_online.py @@ -5,13 +5,12 @@ import torch import dill import random import pathlib -import evaluation import numpy as np -import visualization as vis -from argument_parser import args -from model.online.online_trajectron import OnlineTrajectron -from model.model_registrar import ModelRegistrar -from environment import Environment, Scene +import trajectron.visualization as vis +from trajectron.argument_parser import args +from trajectron.model.online.online_trajectron import OnlineTrajectron +from trajectron.model.model_registrar import ModelRegistrar +from trajectron.environment import Environment, Scene import matplotlib.pyplot as plt if not torch.cuda.is_available() or args.device == 'cpu': diff --git a/trajectron/model/model_registrar.py b/trajectron/model/model_registrar.py index 111a8ab..24d2567 100644 --- a/trajectron/model/model_registrar.py +++ b/trajectron/model/model_registrar.py @@ -1,11 +1,23 @@ import os import torch import torch.nn as nn - +import pickle def get_model_device(model): return next(model.parameters()).device +class PickleModuleCompatibility: + ''' + Migrating Trajectron++ to a module structure + while maintaining compatibility with models generated + before the migration + ''' + class Unpickler(pickle.Unpickler): + def find_class(self, module, name): + if module == 'model' or module[:6] == 'model.': + module = 'trajectron.' + module + return super().find_class(module, name) + class ModelRegistrar(nn.Module): def __init__(self, model_dir, device): @@ -65,7 +77,7 @@ class ModelRegistrar(nn.Module): print('') print('Loading from ' + save_path) - self.model_dict = torch.load(save_path, map_location=self.device) + self.model_dict = torch.load(save_path, map_location=self.device, pickle_module=PickleModuleCompatibility) print('Loaded!') print('') diff --git a/trajectron/model/online/online_mgcvae.py b/trajectron/model/online/online_mgcvae.py index c614c37..1617115 100644 --- a/trajectron/model/online/online_mgcvae.py +++ b/trajectron/model/online/online_mgcvae.py @@ -4,13 +4,13 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np from collections import defaultdict, Counter -from model.components import * -from model.model_utils import * -from model.dataset import get_relative_robot_traj -import model.dynamics as dynamic_module -from model.mgcvae import MultimodalGenerativeCVAE -from environment.scene_graph import DirectedEdge -from environment.node_type import NodeType +from trajectron.model.components import * +from trajectron.model.model_utils import * +from trajectron.model.dataset import get_relative_robot_traj +import trajectron.model.dynamics as dynamic_module +from trajectron.model.mgcvae import MultimodalGenerativeCVAE +from trajectron.environment.scene_graph import DirectedEdge +from trajectron.environment.node_type import NodeType class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): diff --git a/trajectron/model/online/online_trajectron.py b/trajectron/model/online/online_trajectron.py index f1c5063..846228b 100644 --- a/trajectron/model/online/online_trajectron.py +++ b/trajectron/model/online/online_trajectron.py @@ -1,10 +1,10 @@ import torch import numpy as np from collections import Counter -from model.trajectron import Trajectron -from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE -from model.model_utils import ModeKeys -from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of +from trajectron.model.trajectron import Trajectron +from trajectron.model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE +from trajectron.model.model_utils import ModeKeys +from trajectron.environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of class OnlineTrajectron(Trajectron):