Trajectron as a module while supporting old models

This commit is contained in:
Ruben van de Ven 2023-10-09 21:06:36 +02:00
parent dfa1d43f2e
commit 4fa3ce95ee
4 changed files with 30 additions and 19 deletions

View file

@ -5,13 +5,12 @@ import torch
import dill import dill
import random import random
import pathlib import pathlib
import evaluation
import numpy as np import numpy as np
import visualization as vis import trajectron.visualization as vis
from argument_parser import args from trajectron.argument_parser import args
from model.online.online_trajectron import OnlineTrajectron from trajectron.model.online.online_trajectron import OnlineTrajectron
from model.model_registrar import ModelRegistrar from trajectron.model.model_registrar import ModelRegistrar
from environment import Environment, Scene from trajectron.environment import Environment, Scene
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
if not torch.cuda.is_available() or args.device == 'cpu': if not torch.cuda.is_available() or args.device == 'cpu':

View file

@ -1,11 +1,23 @@
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import pickle
def get_model_device(model): def get_model_device(model):
return next(model.parameters()).device return next(model.parameters()).device
class PickleModuleCompatibility:
'''
Migrating Trajectron++ to a module structure
while maintaining compatibility with models generated
before the migration
'''
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == 'model' or module[:6] == 'model.':
module = 'trajectron.' + module
return super().find_class(module, name)
class ModelRegistrar(nn.Module): class ModelRegistrar(nn.Module):
def __init__(self, model_dir, device): def __init__(self, model_dir, device):
@ -65,7 +77,7 @@ class ModelRegistrar(nn.Module):
print('') print('')
print('Loading from ' + save_path) print('Loading from ' + save_path)
self.model_dict = torch.load(save_path, map_location=self.device) self.model_dict = torch.load(save_path, map_location=self.device, pickle_module=PickleModuleCompatibility)
print('Loaded!') print('Loaded!')
print('') print('')

View file

@ -4,13 +4,13 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from collections import defaultdict, Counter from collections import defaultdict, Counter
from model.components import * from trajectron.model.components import *
from model.model_utils import * from trajectron.model.model_utils import *
from model.dataset import get_relative_robot_traj from trajectron.model.dataset import get_relative_robot_traj
import model.dynamics as dynamic_module import trajectron.model.dynamics as dynamic_module
from model.mgcvae import MultimodalGenerativeCVAE from trajectron.model.mgcvae import MultimodalGenerativeCVAE
from environment.scene_graph import DirectedEdge from trajectron.environment.scene_graph import DirectedEdge
from environment.node_type import NodeType from trajectron.environment.node_type import NodeType
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):

View file

@ -1,10 +1,10 @@
import torch import torch
import numpy as np import numpy as np
from collections import Counter from collections import Counter
from model.trajectron import Trajectron from trajectron.model.trajectron import Trajectron
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE from trajectron.model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
from model.model_utils import ModeKeys from trajectron.model.model_utils import ModeKeys
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of from trajectron.environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
class OnlineTrajectron(Trajectron): class OnlineTrajectron(Trajectron):