Some comments on parameters
This commit is contained in:
parent
531d61b69a
commit
53c18d9a7b
1 changed files with 8 additions and 3 deletions
|
@ -326,13 +326,18 @@ class PredictionServer:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
|
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
|
||||||
|
|
||||||
|
# in the OnlineMultimodalGenerativeCVAE (see trajectron.model.online_mgcvae.py) each node's distribution
|
||||||
|
# is put stored in self.latent.p_dist by OnlineMultimodalGenerativeCVAE.p_z_x(). Type: torch.distributions.OneHotCategorical
|
||||||
|
# Later sampling in discrete_latent.py: DiscreteLatent.sample_p()
|
||||||
dists, preds = trajectron.incremental_forward(input_dict,
|
dists, preds = trajectron.incremental_forward(input_dict,
|
||||||
maps,
|
maps,
|
||||||
prediction_horizon=self.config.prediction_horizon, # TODO: make variable
|
prediction_horizon=self.config.prediction_horizon, # TODO: make variable
|
||||||
num_samples=self.config.num_samples, # TODO: make variable
|
num_samples=self.config.num_samples, # TODO: make variable
|
||||||
full_dist=self.config.full_dist,
|
full_dist=self.config.full_dist, # "The model’s full sampled output, where z and y are sampled sequentially"
|
||||||
gmm_mode=self.config.gmm_mode,
|
gmm_mode=self.config.gmm_mode, # "If True: The mode of the Gaussian Mixture Model (GMM) is sampled (see trajectron.model.mgcvae.py)"
|
||||||
z_mode=self.config.z_mode)
|
z_mode=self.config.z_mode # "Predictions from the model’s most-likely high-level latent behavior mode" (see trajecton.models.components.discrete_latent:sample_p(most_likely_z=z_mode))
|
||||||
|
)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
|
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
|
||||||
1. / (end - start), len(trajectron.nodes),
|
1. / (end - start), len(trajectron.nodes),
|
||||||
|
|
Loading…
Reference in a new issue