diff --git a/trap/prediction_server.py b/trap/prediction_server.py index 84b9178..d2b178f 100644 --- a/trap/prediction_server.py +++ b/trap/prediction_server.py @@ -192,8 +192,12 @@ class PredictionServer: eval_scene = eval_env.scenes[scene_idx] online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep) + # auto-find highest iteration model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device) - model_registrar.load_models(iter_num=100) + model_iterations = pathlib.Path(self.config.model_dir).glob('model_registrar-*.pt') + highest_iter = max([int(p.stem.split('-')[-1]) for p in model_iterations]) + + model_registrar.load_models(iter_num=highest_iter) trajectron = OnlineTrajectron(model_registrar, hyperparams,