Work around bug with map encoding
This commit is contained in:
parent
e5151e1a32
commit
dd64170ac3
1 changed files with 1 additions and 1 deletions
|
@ -209,7 +209,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
|||
if self.node not in maps:
|
||||
# This means the node was removed (it is only being kept around because of the edge removal filter).
|
||||
me_params = self.hyperparams['map_encoder'][self.node_type]
|
||||
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size']))
|
||||
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])).to(self.TD['node_history_encoded'].device)
|
||||
else:
|
||||
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
|
||||
(mode == ModeKeys.TRAIN))
|
||||
|
|
Loading…
Reference in a new issue