diff --git a/trap/prediction_server.py b/trap/prediction_server.py index 275a451..b89bd20 100644 --- a/trap/prediction_server.py +++ b/trap/prediction_server.py @@ -326,13 +326,18 @@ class PredictionServer: start = time.time() with warnings.catch_warnings(): warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py + + # in the OnlineMultimodalGenerativeCVAE (see trajectron.model.online_mgcvae.py) each node's distribution + # is put stored in self.latent.p_dist by OnlineMultimodalGenerativeCVAE.p_z_x(). Type: torch.distributions.OneHotCategorical + # Later sampling in discrete_latent.py: DiscreteLatent.sample_p() dists, preds = trajectron.incremental_forward(input_dict, maps, prediction_horizon=self.config.prediction_horizon, # TODO: make variable num_samples=self.config.num_samples, # TODO: make variable - full_dist=self.config.full_dist, - gmm_mode=self.config.gmm_mode, - z_mode=self.config.z_mode) + full_dist=self.config.full_dist, # "The model’s full sampled output, where z and y are sampled sequentially" + gmm_mode=self.config.gmm_mode, # "If True: The mode of the Gaussian Mixture Model (GMM) is sampled (see trajectron.model.mgcvae.py)" + z_mode=self.config.z_mode # "Predictions from the model’s most-likely high-level latent behavior mode" (see trajecton.models.components.discrete_latent:sample_p(most_likely_z=z_mode)) + ) end = time.time() logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start, 1. / (end - start), len(trajectron.nodes),