From dd64170ac37ac56ce6476667b82355ab2d356a16 Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Sun, 29 Dec 2024 20:37:47 +0100 Subject: [PATCH] Work around bug with map encoding --- trajectron/model/online/online_mgcvae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trajectron/model/online/online_mgcvae.py b/trajectron/model/online/online_mgcvae.py index 1617115..91b13a5 100644 --- a/trajectron/model/online/online_mgcvae.py +++ b/trajectron/model/online/online_mgcvae.py @@ -209,7 +209,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): if self.node not in maps: # This means the node was removed (it is only being kept around because of the edge removal filter). me_params = self.hyperparams['map_encoder'][self.node_type] - self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])) + self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])).to(self.TD['node_history_encoded'].device) else: encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1., (mode == ModeKeys.TRAIN))