diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
index b352f5608479..4bff7ba13119 100644
--- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py
+++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -39,7 +39,8 @@ def __init__(self,
                  clip_param=0.1,
                  vf_clip_param=0.1,
                  vf_loss_coeff=1.0,
-                 use_gae=True):
+                 use_gae=True,
+                 model_config={}):
         """Constructs the loss for Proximal Policy Objective.
 
         Arguments:
@@ -70,7 +71,7 @@ def __init__(self,
         def reduce_mean_valid(t):
             return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
 
-        dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
+        dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
         prev_dist = dist_cls(logits)
         # Make loss functions.
         logp_ratio = tf.exp(
@@ -284,7 +285,9 @@ def __init__(self,
             clip_param=self.config["clip_param"],
             vf_clip_param=self.config["vf_clip_param"],
             vf_loss_coeff=self.config["vf_loss_coeff"],
-            use_gae=self.config["use_gae"])
+            use_gae=self.config["use_gae"],
+            model_config=self.config["model"]
+        )
 
         LearningRateSchedule.__init__(self, self.config["lr"],
                                       self.config["lr_schedule"])
diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py
index 026a6c493e5c..790d96ad58c4 100644
--- a/python/ray/rllib/models/action_dist.py
+++ b/python/ray/rllib/models/action_dist.py
@@ -5,6 +5,7 @@
 from collections import namedtuple
 import distutils.version
 import tensorflow as tf
+import tensorflow_probability as tfp
 import numpy as np
 
 from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -139,6 +140,40 @@ def _build_sample_op(self):
         return tf.stack([cat.sample() for cat in self.cats], axis=1)
 
 
+class MultiVariateDiagGaussian(ActionDistribution):
+    """
+    Action distribution where each vector element is an independent
+    gaussian with its own mean and standard deviation.
+    """
+    def __init__(self, inputs):
+        mean, log_std = tf.split(inputs, 2, axis=1)
+        std = tf.exp(log_std)
+        self.distribution = tfp.distributions.MultivariateNormalDiag(
+            loc=mean, scale_diag=std)
+        ActionDistribution.__init__(self, inputs)
+
+    @override(ActionDistribution)
+    def logp(self, x):
+        return self.distribution.log_prob(x)
+
+    @override(ActionDistribution)
+    def kl(self, other):
+        if not isinstance(other, MultiVariateDiagGaussian):
+            raise TypeError(
+                "Argument other expected type MultiVariateDiagGaussian. "
+                "Received type {}.".format(type(other))
+            )
+        return self.distribution.kl_divergence(other.distribution)
+
+    @override(ActionDistribution)
+    def entropy(self):
+        return self.distribution.entropy()
+
+    @override(ActionDistribution)
+    def _build_sample_op(self):
+        return self.distribution.sample()
+
+
 class DiagGaussian(ActionDistribution):
     """Action distribution where each vector element is a gaussian.
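Note (not part of the patch): a minimal standalone sketch of what the new MultiVariateDiagGaussian computes, assuming TF 1.x graph mode (as RLlib used here) and tensorflow_probability installed; the tensor shapes below are illustrative. Unlike DiagGaussian, whose logp sums the per-dimension log-probabilities itself, MultivariateNormalDiag.log_prob already returns the joint log-probability per batch row, so both classes produce [batch]-shaped logp, entropy, and kl tensors and the PPO loss code needs no further changes.

    # Sketch only, not from the PR.
    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp

    # Two flat model outputs for a 3-dim continuous action space:
    # first half of each row = means, second half = log standard deviations.
    inputs = tf.constant(np.random.randn(2, 6), dtype=tf.float32)
    mean, log_std = tf.split(inputs, 2, axis=1)
    dist = tfp.distributions.MultivariateNormalDiag(
        loc=mean, scale_diag=tf.exp(log_std))

    actions = dist.sample()        # shape [2, 3]
    logp = dist.log_prob(actions)  # shape [2]: joint log-prob per batch row
    entropy = dist.entropy()       # shape [2]

    with tf.Session() as sess:
        print(sess.run([actions, logp, entropy]))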
diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py
index 776773552df1..667d336db486 100644
--- a/python/ray/rllib/models/catalog.py
+++ b/python/ray/rllib/models/catalog.py
@@ -14,7 +14,8 @@
 from ray.rllib.models.extra_spaces import Simplex
 from ray.rllib.models.action_dist import (Categorical, MultiCategorical,
                                           Deterministic, DiagGaussian,
-                                          MultiActionDistribution, Dirichlet)
+                                          MultiActionDistribution, Dirichlet,
+                                          MultiVariateDiagGaussian)
 from ray.rllib.models.torch_action_dist import (TorchCategorical,
                                                 TorchDiagGaussian)
 from ray.rllib.models.preprocessors import get_preprocessor
@@ -114,7 +115,17 @@ def get_action_dist(action_space, config, dist_type=None, torch=False):
                     "Consider reshaping this into a single dimension, "
                     "using a Tuple action space, or the multi-agent API.")
             if dist_type is None:
-                dist = TorchDiagGaussian if torch else DiagGaussian
+                if torch:
+                    dist = TorchDiagGaussian
+                else:
+                    custom_options = config.get("custom_options")
+                    if custom_options is None:
+                        dist = DiagGaussian
+                    else:
+                        if custom_options.get("use_multi_variate_normal_diag") is None:
+                            dist = DiagGaussian
+                        else:
+                            dist = MultiVariateDiagGaussian
                 if config.get("squash_to_range"):
                     raise ValueError(
                         "The squash_to_range option is deprecated. See the "
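Note (not part of the patch): a sketch of how the forwarded model config would select the new distribution, assuming the patched catalog above; the Box space and the config dict are illustrative stand-ins for self.config["model"]. As written, the catalog only checks that the use_multi_variate_normal_diag key is present (get(...) is not None), so any value, even False, selects MultiVariateDiagGaussian; omitting the key keeps the default DiagGaussian.

    # Sketch only, not from the PR.
    import numpy as np
    from gym.spaces import Box
    from ray.rllib.models import ModelCatalog

    # Stand-in for the model config that PPOPolicyGraph now forwards.
    model_config = {
        "custom_options": {
            "use_multi_variate_normal_diag": True,
        },
    }
    action_space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

    dist_cls, dist_dim = ModelCatalog.get_action_dist(action_space, model_config)
    # dist_cls is MultiVariateDiagGaussian; dist_dim is 6, i.e. 2 * 3
    # (one mean and one log std per action dimension).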