diff --git a/rllib/agents/cql/__init__.py b/rllib/agents/cql/__init__.py index 0945f4e37ddc..154804137c51 100644 --- a/rllib/agents/cql/__init__.py +++ b/rllib/agents/cql/__init__.py @@ -1,20 +1,8 @@ -from ray.rllib.agents.cql.cql_apex_sac import CQLApexSACTrainer, CQLAPEXSAC_DEFAULT_CONFIG -from ray.rllib.agents.cql.cql_dqn import CQLDQNTrainer, CQLDQN_DEFAULT_CONFIG -from ray.rllib.agents.cql.cql_sac import CQLSACTrainer, CQLSAC_DEFAULT_CONFIG -from ray.rllib.agents.cql.cql_sac_torch_policy import CQLSACTorchPolicy -from ray.rllib.agents.cql.cql_sac_tf_policy import CQLSACTFPolicy -from ray.rllib.agents.cql.cql_dqn_tf_policy import CQLDQNTFPolicy -from ray.rllib.agents.cql.cql_sac_tf_model import CQLSACTFModel +from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG +from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy __all__ = [ - "CQLAPEXSAC_DEFAULT_CONFIG", - "CQLDQN_DEFAULT_CONFIG", - "CQLSAC_DEFAULT_CONFIG", - "CQLDQNTFPolicy", - "CQLSACTFPolicy", - "CQLSACTFModel", - "CQLSACTorchPolicy", - "CQLApexSACTrainer", - "CQLDQNTrainer", - "CQLSACTrainer", + "CQL_DEFAULT_CONFIG", + "CQLTorchPolicy", + "CQLTrainer", ] diff --git a/rllib/agents/cql/cql.py b/rllib/agents/cql/cql.py new file mode 100644 index 000000000000..562a502aece5 --- /dev/null +++ b/rllib/agents/cql/cql.py @@ -0,0 +1,288 @@ +import logging +import numpy as np +from typing import Type + +from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy +from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy +from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG +from ray.rllib.execution.metric_ops import StandardMetricsReporting +from ray.rllib.execution.replay_ops import Replay +from ray.rllib.execution.train_ops import ( + multi_gpu_train_one_step, + MultiGPUTrainOneStep, + train_one_step, + TrainOneStep, + UpdateTargetNetwork, +) +from ray.rllib.offline.shuffled_input import ShuffledInput +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import merge_dicts +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.metrics import ( + LAST_TARGET_UPDATE_TS, + NUM_AGENT_STEPS_TRAINED, + NUM_ENV_STEPS_TRAINED, + NUM_TARGET_UPDATES, + TARGET_NET_UPDATE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer +from ray.rllib.utils.typing import ResultDict, TrainerConfigDict + +tf1, tf, tfv = try_import_tf() +tfp = try_import_tfp() +logger = logging.getLogger(__name__) + +# fmt: off +# __sphinx_doc_begin__ +CQL_DEFAULT_CONFIG = merge_dicts( + SAC_CONFIG, { + # You should override this to point to an offline dataset. + "input": "sampler", + # Switch off off-policy evaluation. + "input_evaluation": [], + # Number of iterations with Behavior Cloning Pretraining. + "bc_iters": 20000, + # CQL loss temperature. + "temperature": 1.0, + # Number of actions to sample for CQL loss. + "num_actions": 10, + # Whether to use the Lagrangian for Alpha Prime (in CQL loss). + "lagrangian": False, + # Lagrangian threshold. + "lagrangian_thresh": 5.0, + # Min Q weight multiplier. 
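+        # (This scales CQL's conservative penalty: the logsumexp of Q-values
+        # over sampled actions minus the mean Q-value of dataset actions; see
+        # the `cql_loss` functions in the new policy files below.)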
+ "min_q_weight": 5.0, + "replay_buffer_config": { + "_enable_replay_buffer_api": False, + "type": "MultiAgentReplayBuffer", + # Replay buffer should be larger or equal the size of the offline + # dataset. + "capacity": int(1e6), + }, + # Reporting: As CQL is offline (no sampling steps), we need to limit an + # iteration's reporting by the number of steps trained (not sampled). + "min_sample_timesteps_per_reporting": 0, + "min_train_timesteps_per_reporting": 100, + + # Use the Trainer's `training_iteration` function instead of `execution_plan`. + "_disable_execution_plan_api": True, + + # Deprecated keys. + # Use `replay_buffer_config.capacity` instead. + "buffer_size": DEPRECATED_VALUE, + # Use `min_sample_timesteps_per_reporting` and + # `min_train_timesteps_per_reporting` instead. + "timesteps_per_iteration": DEPRECATED_VALUE, + }) +# __sphinx_doc_end__ +# fmt: on + + +class CQLTrainer(SACTrainer): + """CQL (derived from SAC).""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Add the entire dataset to Replay Buffer (global variable) + reader = self.workers.local_worker().input_reader + + # For d4rl, add the D4RLReaders' dataset to the buffer. + if isinstance(self.config["input"], str) and "d4rl" in self.config["input"]: + dataset = reader.dataset + self.local_replay_buffer.add(dataset) + # For a list of files, add each file's entire content to the buffer. + elif isinstance(reader, ShuffledInput): + num_batches = 0 + total_timesteps = 0 + for batch in reader.child.read_all_files(): + num_batches += 1 + total_timesteps += len(batch) + # Add NEXT_OBS if not available. This is slightly hacked + # as for the very last time step, we will use next-obs=zeros + # and therefore force-set DONE=True to avoid this missing + # next-obs to cause learning problems. + if SampleBatch.NEXT_OBS not in batch: + obs = batch[SampleBatch.OBS] + batch[SampleBatch.NEXT_OBS] = np.concatenate( + [obs[1:], np.zeros_like(obs[0:1])] + ) + batch[SampleBatch.DONES][-1] = True + self.local_replay_buffer.add_batch(batch) + print( + f"Loaded {num_batches} batches ({total_timesteps} ts) into the" + " replay buffer, which has capacity " + f"{self.local_replay_buffer.capacity}." + ) + else: + raise ValueError( + "Unknown offline input! config['input'] must either be list of" + " offline files (json) or a D4RL-specific InputReader " + "specifier (e.g. 'd4rl.hopper-medium-v0')." + ) + + @classmethod + @override(SACTrainer) + def get_default_config(cls) -> TrainerConfigDict: + return CQL_DEFAULT_CONFIG + + @override(SACTrainer) + def validate_config(self, config: TrainerConfigDict) -> None: + # First check, whether old `timesteps_per_iteration` is used. If so + # convert right away as for CQL, we must measure in training timesteps, + # never sampling timesteps (CQL does not sample). + if config.get("timesteps_per_iteration", DEPRECATED_VALUE) != DEPRECATED_VALUE: + deprecation_warning( + old="timesteps_per_iteration", + new="min_train_timesteps_per_reporting", + error=False, + ) + config["min_train_timesteps_per_reporting"] = config[ + "timesteps_per_iteration" + ] + config["timesteps_per_iteration"] = DEPRECATED_VALUE + + # Call super's validation method. + super().validate_config(config) + + if config["num_gpus"] > 1: + raise ValueError("`num_gpus` > 1 not yet supported for CQL!") + + # CQL-torch performs the optimizer steps inside the loss function. + # Using the multi-GPU optimizer will therefore not work (see multi-GPU + # check above) and we must use the simple optimizer for now. 
+ if config["simple_optimizer"] is not True and config["framework"] == "torch": + config["simple_optimizer"] = True + + if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None: + logger.warning( + "You need `tensorflow_probability` in order to run CQL! " + "Install it via `pip install tensorflow_probability`. Your " + f"tf.__version__={tf.__version__ if tf else None}." + "Trying to import tfp results in the following error:" + ) + try_import_tfp(error=True) + + @override(SACTrainer) + def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: + if config["framework"] == "torch": + return CQLTorchPolicy + else: + return CQLTFPolicy + + @override(SACTrainer) + def training_iteration(self) -> ResultDict: + + # Sample training batch from replay buffer. + train_batch = self.local_replay_buffer.replay() + + # Old-style replay buffers return None if learning has not started + if not train_batch: + return {} + + # Postprocess batch before we learn on it. + post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b) + train_batch = post_fn(train_batch, self.workers, self.config) + + # Learn on training batch. + # Use simple optimizer (only for multi-agent or tf-eager; all other + # cases should use the multi-GPU optimizer, even if only using 1 GPU) + if self.config.get("simple_optimizer") is True: + train_results = train_one_step(self, train_batch) + else: + train_results = multi_gpu_train_one_step(self, train_batch) + + # Update replay buffer priorities. + update_priorities_in_replay_buffer( + self.local_replay_buffer, + self.config, + train_batch, + train_results, + ) + + # Update target network every target_network_update_freq steps + cur_ts = self._counters[ + NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED + ] + last_update = self._counters[LAST_TARGET_UPDATE_TS] + if cur_ts - last_update >= self.config["target_network_update_freq"]: + with self._timers[TARGET_NET_UPDATE_TIMER]: + to_update = self.workers.local_worker().get_policies_to_train() + self.workers.local_worker().foreach_policy_to_train( + lambda p, pid: pid in to_update and p.update_target() + ) + self._counters[NUM_TARGET_UPDATES] += 1 + self._counters[LAST_TARGET_UPDATE_TS] = cur_ts + + # Update remote workers's weights after learning on local worker + if self.workers.remote_workers(): + with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + self.workers.sync_weights() + + # Return all collected metrics for the iteration. + return train_results + + @staticmethod + @override(SACTrainer) + def execution_plan(workers, config, **kwargs): + assert ( + "local_replay_buffer" in kwargs + ), "CQL execution plan requires a local replay buffer." + + local_replay_buffer = kwargs["local_replay_buffer"] + + def update_prio(item): + samples, info_dict = item + if config.get("prioritized_replay"): + prio_dict = {} + for policy_id, info in info_dict.items(): + # TODO(sven): This is currently structured differently for + # torch/tf. Clean up these results/info dicts across + # policies (note: fixing this in torch_policy.py will + # break e.g. DDPPO!). + td_error = info.get( + "td_error", info[LEARNER_STATS_KEY].get("td_error") + ) + samples.policy_batches[policy_id].set_get_interceptor(None) + prio_dict[policy_id] = ( + samples.policy_batches[policy_id].get("batch_indexes"), + td_error, + ) + local_replay_buffer.update_priorities(prio_dict) + return info_dict + + # (2) Read and train on experiences from the replay buffer. 
Every batch + # returned from the LocalReplay() iterator is passed to TrainOneStep to + # take a SGD step, and then we decide whether to update the target + # network. + post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b) + + if config["simple_optimizer"]: + train_step_op = TrainOneStep(workers) + else: + train_step_op = MultiGPUTrainOneStep( + workers=workers, + sgd_minibatch_size=config["train_batch_size"], + num_sgd_iter=1, + num_gpus=config["num_gpus"], + _fake_gpus=config["_fake_gpus"], + ) + + train_op = ( + Replay(local_buffer=local_replay_buffer) + .for_each(lambda x: post_fn(x, workers, config)) + .for_each(train_step_op) + .for_each(update_prio) + .for_each( + UpdateTargetNetwork(workers, config["target_network_update_freq"]) + ) + ) + + return StandardMetricsReporting( + train_op, workers, config, by_steps_trained=True + ) diff --git a/rllib/agents/cql/cql_apex_sac.py b/rllib/agents/cql/cql_apex_sac.py deleted file mode 100644 index 21831d4d3130..000000000000 --- a/rllib/agents/cql/cql_apex_sac.py +++ /dev/null @@ -1,52 +0,0 @@ -from ray.rllib.agents.dqn.apex import apex_execution_plan -from ray.rllib.agents.cql.cql_sac import CQLSAC_DEFAULT_CONFIG, CQLSACTrainer - -# yapf: disable -# __sphinx_doc_begin__ - -CQLAPEXSAC_DEFAULT_CONFIG = CQLSACTrainer.merge_trainer_configs( - CQLSAC_DEFAULT_CONFIG, # see also the options in sac.py, which are also supported - { - "optimizer": { - "max_weight_sync_delay": 400, - "num_replay_buffer_shards": 4, - "debug": False, - }, - "n_step": 1, - "num_gpus": 0, - "num_workers": 32, - "buffer_size": 200000, - "learning_starts": 5000, - "train_batch_size": 512, - "rollout_fragment_length": 50, - "target_network_update_freq": 0, - "timesteps_per_iteration": 1000, - "exploration_config": {"type": "StochasticSampling"}, - "worker_side_prioritization": True, - "min_iter_time_s": 10, - # We need to implement a version of Prioritized Replay for SAC - # that takes into account the policy entropy term of the loss. - # And for CQL_SAC, we need to also consider the CQL regularizer - "prioritized_replay": False, - # If set, this will fix the ratio of sampled to replayed timesteps. - # Otherwise, replay will proceed as fast as possible. - "training_intensity": None, - # Which mode to use in the ParallelRollouts operator used to collect - # samples. For more details check the operator in rollout_ops module. - "parallel_rollouts_mode": "async", - # This only applies if async mode is used (above config setting). - # Controls the max number of async requests in flight per actor - "parallel_rollouts_num_async": 2, - }, -) - - -# __sphinx_doc_end__ -# yapf: enable - - -CQLApexSACTrainer = CQLSACTrainer.with_updates( - name="CQL_APEX_SAC", - default_config=CQLAPEXSAC_DEFAULT_CONFIG, - execution_plan=apex_execution_plan, -) diff --git a/rllib/agents/cql/cql_dqn.py b/rllib/agents/cql/cql_dqn.py deleted file mode 100644 index 295597112916..000000000000 --- a/rllib/agents/cql/cql_dqn.py +++ /dev/null @@ -1,46 +0,0 @@ -"""CQL (derived from DQN). -""" -from typing import Optional, Type - -from ray.rllib.agents.cql.cql_dqn_tf_policy import CQLDQNTFPolicy -from ray.rllib.agents.dqn.dqn import DQNTrainer, \ - DEFAULT_CONFIG as DQN_CONFIG -from ray.rllib.utils.typing import TrainerConfigDict -from ray.rllib.policy.policy import Policy -from ray.rllib.utils import merge_dicts - -# yapf: disable -# __sphinx_doc_begin__ -CQLDQN_DEFAULT_CONFIG = merge_dicts( - DQN_CONFIG, { - # You should override this to point to an offline dataset. 
- "input": "sampler", - # Offline RL does not need IS estimators - "input_evaluation": [], - # Min Q Weight multiplier - "min_q_weight": 1.0, - # The default value is set as the same of DQN which is good for - # online training. For offline training we could start to optimize - # the models right away. - "learning_starts": 1000, - # Replay Buffer should be size of offline dataset for fastest - # training - "buffer_size": 1000000, - }) -# __sphinx_doc_end__ -# yapf: enable - - -def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: - if config["framework"] == "torch": - raise ValueError("Torch CQL not implemented yet!") - else: - return CQLDQNTFPolicy - - -CQLDQNTrainer = DQNTrainer.with_updates( - name="CQL_DQN", - default_config=CQLDQN_DEFAULT_CONFIG, - default_policy=CQLDQNTFPolicy, - get_policy_class=get_policy_class, -) diff --git a/rllib/agents/cql/cql_dqn_tf_policy.py b/rllib/agents/cql/cql_dqn_tf_policy.py deleted file mode 100644 index 4b17d435d36f..000000000000 --- a/rllib/agents/cql/cql_dqn_tf_policy.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging - -import ray -from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy, QLoss, build_q_losses -from ray.rllib.utils.framework import try_import_tf, try_import_tfp - -logger = logging.getLogger(__name__) - -tf1, tf, tfv = try_import_tf() -tfp = try_import_tfp() - -class CQLQLoss(QLoss): - def __init__(self, - q_t, - q_t_selected, - q_logits_t_selected, - q_tp1_best, - q_dist_tp1_best, - importance_weights, - rewards, - done_mask, - config, - ): - - super().__init__(q_t, q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best, importance_weights, rewards, done_mask, config) - min_q_weight = config["min_q_weight"] - - dataset_expec = tf.reduce_mean(q_t_selected) - negative_sampling = tf.reduce_mean(tf.reduce_logsumexp(q_t, 1)) - - min_q_loss = (negative_sampling - dataset_expec) - - min_q_loss = min_q_loss * min_q_weight - self.loss = self.loss + min_q_loss - self.stats["cql_loss"] = min_q_loss - - -def build_cql_losses(policy, model, dist_class, train_batch): - return build_q_losses(policy, model, dist_class, train_batch, CQLQLoss) - - -# Build a child class of `TFPolicy`, given the custom functions defined -# above. -CQLDQNTFPolicy = DQNTFPolicy.with_updates( - name="CQLDQNTFPolicy", - get_default_config=lambda: ray.rllib.agents.cql.CQLDQN_DEFAULT_CONFIG, - loss_fn=build_cql_losses, -) diff --git a/rllib/agents/cql/cql_sac.py b/rllib/agents/cql/cql_sac.py deleted file mode 100644 index ed1d7021e57c..000000000000 --- a/rllib/agents/cql/cql_sac.py +++ /dev/null @@ -1,63 +0,0 @@ -"""CQL (derived from SAC). -""" -from typing import Optional, Type - -from ray.rllib.agents.cql.cql_sac_tf_policy import CQLSACTFPolicy -from ray.rllib.agents.sac.sac import SACTrainer, \ - DEFAULT_CONFIG as SAC_CONFIG -from ray.rllib.agents.cql.cql_sac_torch_policy import CQLSACTorchPolicy -from ray.rllib.utils.typing import TrainerConfigDict -from ray.rllib.policy.policy import Policy -from ray.rllib.utils import merge_dicts - -# yapf: disable -# __sphinx_doc_begin__ -CQLSAC_DEFAULT_CONFIG = merge_dicts( - SAC_CONFIG, { - # You should override this to point to an offline dataset. 
- "input": "sampler", - # Offline RL does not need IS estimators - "input_evaluation": [], - # Number of iterations with Behavior Cloning Pretraining - "bc_iters": 20000, - # CQL Loss Temperature - "temperature": 1.0, - # Num Actions to sample for CQL Loss - "num_actions": 10, - # Whether to use the Langrangian for Alpha Prime (in CQL Loss) - "lagrangian": False, - # Lagrangian Threshold - "lagrangian_thresh": 5.0, - # Min Q Weight multiplier - "min_q_weight": 5.0, - # Initial value to use for the Alpha Prime (in CQL Loss). - "initial_alpha_prime": 1.0, - # The default value is set as the same of SAC which is good for - # online training. For offline training we could start to optimize - # the models right away. - "learning_starts": 1500, - # Replay Buffer should be size of offline dataset for fastest - # training - "buffer_size": 1000000, - # Upper bound for alpha value during the lagrangian constraint - "alpha_upper_bound": 1.0, - # Lower bound for alpha value during the lagrangian constraint - "alpha_lower_bound": 0.0, - }) -# __sphinx_doc_end__ -# yapf: enable - - -def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: - if config["framework"] == "torch": - return CQLSACTorchPolicy - else: - return CQLSACTFPolicy - - -CQLSACTrainer = SACTrainer.with_updates( - name="CQL_SAC", - default_config=CQLSAC_DEFAULT_CONFIG, - default_policy=CQLSACTFPolicy, - get_policy_class=get_policy_class, -) diff --git a/rllib/agents/cql/cql_sac_tf_model.py b/rllib/agents/cql/cql_sac_tf_model.py deleted file mode 100644 index 8bdef0d3c7b8..000000000000 --- a/rllib/agents/cql/cql_sac_tf_model.py +++ /dev/null @@ -1,78 +0,0 @@ -import gym -import numpy as np -from typing import Optional - -from ray.rllib.agents.sac.sac_tf_model import SACTFModel -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.typing import ModelConfigDict - -tf1, tf, tfv = try_import_tf() - - -class CQLSACTFModel(SACTFModel): - """Extension of SACTFModel for CQL. - - To customize, do one of the following: - - sub-class CQLTFModel and override one or more of its methods. - - Use CQL's `Q_model` and `policy_model` keys to tweak the default model - behaviors (e.g. fcnet_hiddens, conv_filters, etc..). - - Use CQL's `Q_model->custom_model` and `policy_model->custom_model` keys - to specify your own custom Q-model(s) and policy-models, which will be - created within this CQLTFModel (see `build_policy_model` and - `build_q_model`. - - Note: It is not recommended to override the `forward` method for CQL. This - would lead to shared weights (between policy and Q-nets), which will then - not be optimized by either of the critic- or actor-optimizers! - - Data flow: - `obs` -> forward() (should stay a noop method!) -> `model_out` - `model_out` -> get_policy_output() -> pi(actions|obs) - `model_out`, `actions` -> get_q_values() -> Q(s, a) - `model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a) - """ - - def __init__(self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: Optional[int], - model_config: ModelConfigDict, - name: str, - policy_model_config: ModelConfigDict = None, - q_model_config: ModelConfigDict = None, - twin_q: bool = False, - initial_alpha: float = 1.0, - target_entropy: Optional[float] = None, - lagrangian: bool = False, - initial_alpha_prime: float = 1.0): - """Initialize a CQLSACTFModel instance. - - Args: - policy_model_config (ModelConfigDict): The config dict for the - policy network. 
- q_model_config (ModelConfigDict): The config dict for the - Q-network(s) (2 if twin_q=True). - twin_q (bool): Build twin Q networks (Q-net and target) for more - stable Q-learning. - initial_alpha (float): The initial value for the to-be-optimized - alpha parameter (default: 1.0). - target_entropy (Optional[float]): A target entropy value for - the to-be-optimized alpha parameter. If None, will use the - defaults described in the papers for SAC (and discrete SAC). - lagrangian (bool): Whether to automatically adjust value via - Lagrangian dual gradient descent. - initial_alpha_prime (float): The initial value for the to-be-optimized - alpha_prime parameter (default: 1.0). - - Note that the core layers for forward() are not defined here, this - only defines the layers for the output heads. Those layers for - forward() should be defined in subclasses of CQLModel. - """ - super(CQLSACTFModel, self).__init__(obs_space, action_space, num_outputs, - model_config, name, policy_model_config, - q_model_config, twin_q, initial_alpha, - target_entropy) - if lagrangian: - self.log_alpha_prime = tf.Variable( - np.log(initial_alpha_prime), dtype=tf.float32, name="log_alpha_prime") - self.alpha_prime = tf.exp(self.log_alpha_prime) diff --git a/rllib/agents/cql/cql_sac_tf_policy.py b/rllib/agents/cql/cql_sac_tf_policy.py deleted file mode 100644 index 53d21fea1d3b..000000000000 --- a/rllib/agents/cql/cql_sac_tf_policy.py +++ /dev/null @@ -1,387 +0,0 @@ -""" -TF policy class used for CQL. -""" -from functools import partial - -import numpy as np -import gym -import logging -from typing import Dict, Union, Type, List - -import ray -import ray.experimental.tf_utils -from ray.rllib.agents.cql.cql_sac_tf_model import CQLSACTFModel -from ray.rllib.agents.sac.sac_tf_policy import ActorCriticOptimizerMixin, \ - ComputeTDErrorMixin, TargetNetworkMixin, stats, \ - compute_and_clip_gradients, apply_gradients, SACTFPolicy, sac_actor_critic_loss -from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.tf_action_dist import TFActionDistribution -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.framework import try_import_tf, try_import_tfp -from ray.rllib.utils.typing import TensorType, TrainerConfigDict, LocalOptimizer, \ - ModelGradients - -tf1, tf, tfv = try_import_tf() -tfp = try_import_tfp() - -logger = logging.getLogger(__name__) - - -def build_cql_sac_model(policy: Policy, obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict) -> ModelV2: - """Constructs the necessary ModelV2 for the Policy and returns it. - - Args: - policy (Policy): The TFPolicy that will use the models. - obs_space (gym.spaces.Space): The observation space. - action_space (gym.spaces.Space): The action space. - config (TrainerConfigDict): The CQL trainer's config dict. - - Returns: - ModelV2: The ModelV2 to be used by the Policy. Note: An additional - target model will be created in this function and assigned to - `policy.target_model`. - """ - # With separate state-preprocessor (before obs+action concat). - num_outputs = int(np.product(obs_space.shape)) - - # Force-ignore any additionally provided hidden layer sizes. - # Everything should be configured using CQL_SAC's "Q_model" and "policy_model" - # settings. 
- policy_model_config = MODEL_DEFAULTS.copy() - policy_model_config.update(config["policy_model"]) - q_model_config = MODEL_DEFAULTS.copy() - q_model_config.update(config["Q_model"]) - - assert config["framework"] != "torch" - default_model_cls = CQLSACTFModel - - model = ModelCatalog.get_model_v2( - obs_space=obs_space, - action_space=action_space, - num_outputs=num_outputs, - model_config=config["model"], - framework=config["framework"], - default_model=default_model_cls, - name="cql_sac_model", - policy_model_config=policy_model_config, - q_model_config=q_model_config, - twin_q=config["twin_q"], - initial_alpha=config["initial_alpha"], - target_entropy=config["target_entropy"], - lagrangian=config["lagrangian"], - initial_alpha_prime=config["initial_alpha_prime"]) - - assert isinstance(model, default_model_cls) - - # Create an exact copy of the model and store it in `policy.target_model`. - # This will be used for tau-synched Q-target models that run behind the - # actual Q-networks and are used for target q-value calculations in the - # loss terms. - policy.target_model = ModelCatalog.get_model_v2( - obs_space=obs_space, - action_space=action_space, - num_outputs=num_outputs, - model_config=config["model"], - framework=config["framework"], - default_model=default_model_cls, - name="target_cql_sac_model", - policy_model_config=policy_model_config, - q_model_config=q_model_config, - twin_q=config["twin_q"], - initial_alpha=config["initial_alpha"], - target_entropy=config["target_entropy"], - lagrangian=config["lagrangian"], - initial_alpha_prime=config["initial_alpha_prime"]) - - assert isinstance(policy.target_model, default_model_cls) - - return model - - -# Returns policy tiled actions and log probabilities for CQL Loss -def policy_actions_repeat(model, action_dist, obs, num_repeat=1): - obs_temp = tf.tile(obs, [num_repeat, 1]) - policy_dist = action_dist(model.get_policy_output(obs_temp), model) - actions = policy_dist.sample() - log_p = tf.expand_dims(policy_dist.logp(actions), -1) - return actions, tf.squeeze(log_p, axis=len(log_p.shape) - 1) - - -def q_values_repeat(model, obs, actions, twin=False): - action_shape = tf.shape(actions)[0] - obs_shape = tf.shape(obs)[0] - num_repeat = action_shape // obs_shape - obs_temp = tf.tile(obs, [num_repeat, 1]) - if twin: - preds = model.get_q_values(obs_temp, actions) - else: - preds = model.get_twin_q_values(obs_temp, actions) - preds = tf.reshape(preds, [obs_shape, num_repeat, 1]) - return preds - - -def cql_loss(policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], - train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: - """Constructs the loss for the Soft Actor Critic. - - Args: - policy (Policy): The Policy to calculate the loss for. - model (ModelV2): The Model to calculate the loss for. - dist_class (Type[ActionDistribution]: The action distr. class. - train_batch (SampleBatch): The training data. - - Returns: - Union[TensorType, List[TensorType]]: A single loss tensor or a list - of loss tensors. 
- """ - # For best performance, turn deterministic off - deterministic = policy.config["_deterministic_loss"] - twin_q = policy.config["twin_q"] - discount = policy.config["gamma"] - - # CQL Parameters - bc_iters = policy.config["bc_iters"] - cql_temp = policy.config["temperature"] - num_actions = policy.config["num_actions"] - min_q_weight = policy.config["min_q_weight"] - use_lagrange = policy.config["lagrangian"] - target_action_gap = policy.config["lagrangian_thresh"] - - obs = train_batch[SampleBatch.CUR_OBS] - actions = train_batch[SampleBatch.ACTIONS] - rewards = train_batch[SampleBatch.REWARDS] - next_obs = train_batch[SampleBatch.NEXT_OBS] - terminals = train_batch[SampleBatch.DONES] - - # Execute SAC Policy as it is - sac_loss_res = sac_actor_critic_loss(policy, model, dist_class, train_batch) - - # CQL Loss (We are using Entropy version of CQL (the best version)) - rand_actions = policy._unif_dist.sample([tf.shape(actions)[0] * num_actions, - actions.shape[-1]]) - curr_actions, curr_logp = policy_actions_repeat(model, policy.action_dist_class, - obs, num_actions) - next_actions, next_logp = policy_actions_repeat(model, policy.action_dist_class, - next_obs, num_actions) - curr_logp = tf.reshape(curr_logp, [tf.shape(actions)[0], num_actions, 1]) - next_logp = tf.reshape(next_logp, [tf.shape(actions)[0], num_actions, 1]) - - q1_rand = q_values_repeat(model, policy.model_out_t, rand_actions) - q1_curr_actions = q_values_repeat(model, policy.model_out_t, curr_actions) - q1_next_actions = q_values_repeat(model, policy.model_out_t, next_actions) - - if twin_q: - q2_rand = q_values_repeat(model, policy.model_out_t, rand_actions, twin=True) - q2_curr_actions = q_values_repeat( - model, policy.model_out_t, curr_actions, twin=True) - q2_next_actions = q_values_repeat( - model, policy.model_out_t, next_actions, twin=True) - - random_density = np.log(0.5**curr_actions.shape[-1].value) - cat_q1 = tf.concat([ - q1_rand - random_density, q1_next_actions - tf.stop_gradient(next_logp), - q1_curr_actions - tf.stop_gradient(curr_logp) - ], 1) - if twin_q: - cat_q2 = tf.concat([ - q2_rand - random_density, q2_next_actions - tf.stop_gradient(next_logp), - q2_curr_actions - tf.stop_gradient(curr_logp) - ], 1) - - min_qf1_loss = tf.reduce_mean(tf.reduce_logsumexp( - cat_q1 / cql_temp, axis=1)) * min_q_weight * cql_temp - min_qf1_loss = min_qf1_loss - tf.reduce_mean(policy.q_t_selected) * min_q_weight - if twin_q: - min_qf2_loss = tf.reduce_mean(tf.reduce_logsumexp( - cat_q2 / cql_temp, axis=1)) * min_q_weight * cql_temp - min_qf2_loss = min_qf2_loss - tf.reduce_mean(policy.twin_q_t_selected) * min_q_weight - - if use_lagrange: - alpha_upper_bound = policy.config["alpha_upper_bound"] - alpha_lower_bound = policy.config["alpha_lower_bound"] - alpha_prime = tf.clip_by_value( - tf.exp(model.log_alpha_prime), clip_value_min=alpha_lower_bound, clip_value_max=alpha_upper_bound) - min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) - if twin_q: - min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) - alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) - else: - alpha_prime_loss = -min_qf1_loss - - cql_loss = [min_qf1_loss] - if twin_q: - cql_loss.append(min_qf2_loss) - - policy.critic_loss[0] += min_qf1_loss - if twin_q: - policy.critic_loss[1] += min_qf2_loss - - # Save for stats function. 
- # CQL Stats - policy.cql_loss = cql_loss - if use_lagrange: - policy.log_alpha_prime_value = model.log_alpha_prime - policy.alpha_prime_value = model.alpha_prime - policy.alpha_prime_loss = alpha_prime_loss - # In a custom apply op we handle the losses separately, but return them - # combined in one loss here. - return sac_loss_res + alpha_prime_loss - else: - return sac_loss_res - - -def cql_compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer, - loss: TensorType) -> ModelGradients: - """Gradients computing function (from loss tensor, using local optimizer). - - Note: For CQL, optimizer and loss are ignored b/c we have 1 extra - loss and 1 local optimizer (all stored in policy). - `optimizer` will be used, though, in the tf-eager case b/c it is then a - fake optimizer (OptimizerWrapper) object with a `tape` property to - generate a GradientTape object for gradient recording. - - Args: - policy (Policy): The Policy object that generated the loss tensor and - that holds the given local optimizer. - optimizer (LocalOptimizer): The tf (local) optimizer object to - calculate the gradients with. - loss (TensorType): The loss tensor for which gradients should be - calculated. - - Returns: - ModelGradients: List of the possibly clipped gradients- and variable - tuples. - """ - # Eager: Use GradientTape (which is a property of the `optimizer` object - # (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py). - grads_and_vars = compute_and_clip_gradients(policy, optimizer, loss) - if policy.config["lagrangian"]: - if policy.config["framework"] in ["tf2", "tfe"]: - tape = optimizer.tape - alpha_prime_vars = [policy.model.log_alpha_prime] - alpha_prime_grads_and_vars = list( - zip(tape.gradient(policy.alpha_prime_loss, alpha_prime_vars), alpha_prime_vars)) - # Tf1.x: Use optimizer.compute_gradients() - else: - alpha_prime_grads_and_vars = policy._alpha_prime_optimizer.compute_gradients( - policy.alpha_prime_loss, var_list=[policy.model.log_alpha_prime]) - - # Clip if necessary. - if policy.config["grad_clip"]: - clip_func = partial( - tf.clip_by_norm, clip_norm=policy.config["grad_clip"]) - else: - clip_func = tf.identity - - # Save grads and vars for later use in `build_apply_op`. - policy._alpha_prime_grads_and_vars = [(clip_func(g), v) - for (g, v) in alpha_prime_grads_and_vars - if g is not None] - - grads_and_vars = tuple(list(grads_and_vars) + policy._alpha_prime_grads_and_vars) - - return grads_and_vars - - -def cql_apply_gradients( - policy: Policy, optimizer: LocalOptimizer, - grads_and_vars: ModelGradients) -> Union["tf.Operation", None]: - """Gradients applying function (from list of "grad_and_var" tuples). - - Args: - policy (Policy): The Policy object whose Model(s) the given gradients - should be applied to. - optimizer (LocalOptimizer): The tf (local) optimizer object through - which to apply the gradients. - grads_and_vars (ModelGradients): The list of grad_and_var tuples to - apply via the given optimizer. - - Returns: - Union[tf.Operation, None]: The tf op to be used to run the apply - operation. None for eager mode. - """ - grads_group_ops = apply_gradients(policy, optimizer, grads_and_vars) - if policy.config["lagrangian"]: - # Eager mode -> Just apply and return None. - if policy.config["framework"] in ["tf2", "tfe"]: - policy._alpha_prime_optimizer.apply_gradients( - policy._alpha_prime_grads_and_vars) - # Tf static graph -> Return op. 
- else: - alpha_prime_apply_ops = policy._alpha_prime_optimizer.apply_gradients( - policy._alpha_prime_grads_and_vars) - grads_group_ops = tf.group([grads_group_ops, alpha_prime_apply_ops]) - - return grads_group_ops - - -def cql_stats(policy: Policy, - train_batch: SampleBatch) -> Dict[str, TensorType]: - cql_dict = stats(policy, train_batch) - cql_dict["cql_loss"] = tf.reduce_mean(tf.stack(policy.cql_loss)) - if policy.config["lagrangian"]: - cql_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value - cql_dict["alpha_prime_value"] = policy.alpha_prime_value - cql_dict["alpha_prime_loss"] = policy.alpha_prime_loss - return cql_dict - - -class CQLActorCriticOptimizerMixin(ActorCriticOptimizerMixin): - def __init__(self, config): - super().__init__(config) - if config["framework"] in ["tf2", "tfe"]: - if config["lagrangian"]: - self._alpha_prime_optimizer = tf.keras.optimizers.Adam( - learning_rate=config["optimization"]["critic_learning_rate"]) - else: - if config["lagrangian"]: - self._alpha_prime_optimizer = tf1.train.AdamOptimizer( - learning_rate=config["optimization"]["critic_learning_rate"]) - - -def cql_setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict): - """Call mixin classes' constructors before Policy's initialization. - - Adds the necessary optimizers to the given Policy. - - Args: - policy (Policy): The Policy object. - obs_space (gym.spaces.Space): The Policy's observation space. - action_space (gym.spaces.Space): The Policy's action space. - config (TrainerConfigDict): The Policy's config. - """ - CQLActorCriticOptimizerMixin.__init__(policy, config) - - -def cql_setup_mid_mixins(policy: Policy, obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict) -> None: - action_low = policy.model.action_space.low[0] - action_high = policy.model.action_space.high[0] - policy._unif_dist = tfp.distributions.Uniform(action_low, action_high, - name = "uniform_rand_actions") - ComputeTDErrorMixin.__init__(policy, cql_loss) - - -# Build a child class of `TFPolicy`, given the custom functions defined -# above. -CQLSACTFPolicy = SACTFPolicy.with_updates( - name="CQLSACTFPolicy", - get_default_config=lambda: ray.rllib.agents.cql.CQLSAC_DEFAULT_CONFIG, - make_model=build_cql_sac_model, - loss_fn=cql_loss, - stats_fn=cql_stats, - gradients_fn=cql_compute_and_clip_gradients, - apply_gradients_fn=cql_apply_gradients, - mixins=[ - TargetNetworkMixin, CQLActorCriticOptimizerMixin, ComputeTDErrorMixin - ], - before_init=cql_setup_early_mixins, - before_loss_init=cql_setup_mid_mixins, -) diff --git a/rllib/agents/cql/cql_sac_torch_policy.py b/rllib/agents/cql/cql_sac_torch_policy.py deleted file mode 100644 index c9fe4c7dad1b..000000000000 --- a/rllib/agents/cql/cql_sac_torch_policy.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -PyTorch policy class used for CQL. 
-""" -import numpy as np -import gym -import logging -from typing import Dict, List, Tuple, Type, Union - -import ray -import ray.experimental.tf_utils -from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \ - validate_spaces -from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \ - build_sac_model_and_action_dist, optimizer_fn, ComputeTDErrorMixin, \ - TargetNetworkMixin, setup_late_mixins, action_distribution_fn -from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper -from ray.rllib.policy.policy_template import build_policy_class -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ - TrainerConfigDict -from ray.rllib.utils.torch_ops import apply_grad_clipping, \ - convert_to_torch_tensor - -torch, nn = try_import_torch() -F = nn.functional - -logger = logging.getLogger(__name__) - - -# Returns policy tiled actions and log probabilities for CQL Loss -def policy_actions_repeat(model, action_dist, obs, num_repeat=1): - obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view( - obs.shape[0] * num_repeat, obs.shape[1]) - policy_dist = action_dist(model.get_policy_output(obs_temp), model) - actions = policy_dist.sample() - log_p = torch.unsqueeze(policy_dist.logp(actions), -1) - return actions, log_p.squeeze() - - -def q_values_repeat(model, obs, actions, twin=False): - action_shape = actions.shape[0] - obs_shape = obs.shape[0] - num_repeat = int(action_shape / obs_shape) - obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view( - obs.shape[0] * num_repeat, obs.shape[1]) - if not twin: - preds = model.get_q_values(obs_temp, actions) - else: - preds = model.get_twin_q_values(obs_temp, actions) - preds = preds.view(obs.shape[0], num_repeat, 1) - return preds - - -def cql_loss(policy: Policy, model: ModelV2, - dist_class: Type[TorchDistributionWrapper], - train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: - print(policy.cur_iter) - policy.cur_iter += 1 - # For best performance, turn deterministic off - deterministic = policy.config["_deterministic_loss"] - twin_q = policy.config["twin_q"] - discount = policy.config["gamma"] - action_low = model.action_space.low[0] - action_high = model.action_space.high[0] - - # CQL Parameters - bc_iters = policy.config["bc_iters"] - cql_temp = policy.config["temperature"] - num_actions = policy.config["num_actions"] - min_q_weight = policy.config["min_q_weight"] - use_lagrange = policy.config["lagrangian"] - target_action_gap = policy.config["lagrangian_thresh"] - - obs = train_batch[SampleBatch.CUR_OBS] - actions = train_batch[SampleBatch.ACTIONS] - rewards = train_batch[SampleBatch.REWARDS] - next_obs = train_batch[SampleBatch.NEXT_OBS] - terminals = train_batch[SampleBatch.DONES] - - model_out_t, _ = model({ - "obs": obs, - "is_training": True, - }, [], None) - - model_out_tp1, _ = model({ - "obs": next_obs, - "is_training": True, - }, [], None) - - target_model_out_tp1, _ = policy.target_model({ - "obs": next_obs, - "is_training": True, - }, [], None) - - action_dist_class = _get_dist_class(policy.config, policy.action_space) - action_dist_t = action_dist_class( - model.get_policy_output(model_out_t), policy.model) - policy_t = action_dist_t.sample() if not deterministic else \ - action_dist_t.deterministic_sample() - log_pis_t = 
torch.unsqueeze(action_dist_t.logp(policy_t), -1) - - # Unlike original SAC, Alpha and Actor Loss are computed first. - # Alpha Loss - alpha_loss = -(model.log_alpha * - (log_pis_t + model.target_entropy).detach()).mean() - - # Policy Loss (Either Behavior Clone Loss or SAC Loss) - alpha = torch.exp(model.log_alpha) - if policy.cur_iter >= bc_iters: - min_q = model.get_q_values(model_out_t, policy_t) - if twin_q: - twin_q_ = model.get_twin_q_values(model_out_t, policy_t) - min_q = torch.min(min_q, twin_q_) - actor_loss = (alpha.detach() * log_pis_t - min_q).mean() - else: - bc_logp = action_dist_t.logp(actions) - actor_loss = (alpha * log_pis_t - bc_logp).mean() - - # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) - # SAC Loss - action_dist_tp1 = action_dist_class( - model.get_policy_output(model_out_tp1), policy.model) - policy_tp1 = action_dist_tp1.sample() if not deterministic else \ - action_dist_tp1.deterministic_sample() - - # Q-values for the batched actions. - q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) - if twin_q: - twin_q_t = model.get_twin_q_values(model_out_t, - train_batch[SampleBatch.ACTIONS]) - - # Target q network evaluation. - q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, - policy_tp1) - if twin_q: - twin_q_tp1 = policy.target_model.get_twin_q_values( - target_model_out_tp1, policy_tp1) - # Take min over both twin-NNs. - q_tp1 = torch.min(q_tp1, twin_q_tp1) - - q_t = torch.squeeze(q_t, dim=-1) - if twin_q: - twin_q_t = torch.squeeze(twin_q_t, dim=-1) - - q_tp1 = torch.squeeze(input=q_tp1, dim=-1) - q_tp1 = (1.0 - terminals.float()) * q_tp1 - - # compute RHS of bellman equation - q_t_target = ( - rewards + (discount**policy.config["n_step"]) * q_tp1).detach() - - # Compute the TD-error (potentially clipped), for priority replay buffer - base_td_error = torch.abs(q_t - q_t_target) - if twin_q: - twin_td_error = torch.abs(twin_q_t - q_t_target) - td_error = 0.5 * (base_td_error + twin_td_error) - else: - td_error = base_td_error - critic_loss = [nn.MSELoss()(q_t, q_t_target)] - if twin_q: - critic_loss.append(nn.MSELoss()(twin_q_t, q_t_target)) - - # CQL Loss (We are using Entropy version of CQL (the best version)) - rand_actions = convert_to_torch_tensor( - torch.FloatTensor(actions.shape[0] * num_actions, - actions.shape[-1]).uniform_(action_low, action_high), - policy.device) - curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class, - obs, num_actions) - next_actions, next_logp = policy_actions_repeat(model, action_dist_class, - next_obs, num_actions) - - curr_logp = curr_logp.view(actions.shape[0], num_actions, 1) - next_logp = next_logp.view(actions.shape[0], num_actions, 1) - - q1_rand = q_values_repeat(model, model_out_t, rand_actions) - q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions) - q1_next_actions = q_values_repeat(model, model_out_t, next_actions) - - if twin_q: - q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True) - q2_curr_actions = q_values_repeat( - model, model_out_t, curr_actions, twin=True) - q2_next_actions = q_values_repeat( - model, model_out_t, next_actions, twin=True) - - random_density = np.log(0.5**curr_actions.shape[-1]) - cat_q1 = torch.cat([ - q1_rand - random_density, q1_next_actions - next_logp.detach(), - q1_curr_actions - curr_logp.detach() - ], 1) - if twin_q: - cat_q2 = torch.cat([ - q2_rand - random_density, q2_next_actions - next_logp.detach(), - q2_curr_actions - curr_logp.detach() - ], 1) - - min_qf1_loss = 
torch.logsumexp( - cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp - min_qf1_loss = min_qf1_loss - q_t.mean() * min_q_weight - if twin_q: - min_qf2_loss = torch.logsumexp( - cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp - min_qf2_loss = min_qf2_loss - twin_q_t.mean() * min_q_weight - - if use_lagrange: - alpha_prime = torch.clamp( - model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0] - min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) - if twin_q: - min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) - alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) - else: - alpha_prime_loss = -min_qf1_loss - - cql_loss = [min_qf2_loss] - if twin_q: - cql_loss.append(min_qf2_loss) - - critic_loss[0] += min_qf1_loss - if twin_q: - critic_loss[1] += min_qf2_loss - - # Save for stats function. - policy.q_t = q_t - policy.policy_t = policy_t - policy.log_pis_t = log_pis_t - policy.td_error = td_error - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.alpha_loss = alpha_loss - policy.log_alpha_value = model.log_alpha - policy.alpha_value = alpha - policy.target_entropy = model.target_entropy - # CQL Stats - policy.cql_loss = cql_loss - if use_lagrange: - policy.log_alpha_prime_value = model.log_alpha_prime[0] - policy.alpha_prime_value = alpha_prime - policy.alpha_prime_loss = alpha_prime_loss - - # Return all loss terms corresponding to our optimizers. - if use_lagrange: - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss] + [policy.alpha_prime_loss]) - return tuple([policy.actor_loss] + policy.critic_loss + - [policy.alpha_loss]) - - -def cql_stats(policy: Policy, - train_batch: SampleBatch) -> Dict[str, TensorType]: - sac_dict = stats(policy, train_batch) - sac_dict["cql_loss"] = torch.mean(torch.stack(policy.cql_loss)) - if policy.config["lagrangian"]: - sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value - sac_dict["alpha_prime_value"] = policy.alpha_prime_value - sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss - return sac_dict - - -def cql_optimizer_fn(policy: Policy, config: TrainerConfigDict) -> \ - Tuple[LocalOptimizer]: - policy.cur_iter = 0 - opt_list = optimizer_fn(policy, config) - if config["lagrangian"]: - log_alpha_prime = nn.Parameter( - torch.zeros(1, requires_grad=True).float()) - policy.model.register_parameter("log_alpha_prime", log_alpha_prime) - policy.alpha_prime_optim = torch.optim.Adam( - params=[policy.model.log_alpha_prime], - lr=config["optimization"]["critic_learning_rate"], - eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default - ) - return tuple([policy.actor_optim] + policy.critic_optims + - [policy.alpha_optim] + [policy.alpha_prime_optim]) - return opt_list - - -def cql_setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict) -> None: - setup_late_mixins(policy, obs_space, action_space, config) - if config["lagrangian"]: - policy.model.log_alpha_prime = policy.model.log_alpha_prime.to( - policy.device) - - -# Build a child class of `TorchPolicy`, given the custom functions defined -# above. 
-CQLSACTorchPolicy = build_policy_class( - name="CQLSACTorchPolicy", - framework="torch", - loss_fn=cql_loss, - get_default_config=lambda: ray.rllib.agents.cql.cql.CQLSAC_DEFAULT_CONFIG, - stats_fn=cql_stats, - postprocess_fn=postprocess_trajectory, - extra_grad_process_fn=apply_grad_clipping, - optimizer_fn=cql_optimizer_fn, - validate_spaces=validate_spaces, - before_loss_init=cql_setup_late_mixins, - make_model_and_action_dist=build_sac_model_and_action_dist, - mixins=[TargetNetworkMixin, ComputeTDErrorMixin], - action_distribution_fn=action_distribution_fn, -) diff --git a/rllib/agents/cql/cql_tf_policy.py b/rllib/agents/cql/cql_tf_policy.py new file mode 100644 index 000000000000..447254d8be77 --- /dev/null +++ b/rllib/agents/cql/cql_tf_policy.py @@ -0,0 +1,425 @@ +""" +TensorFlow policy class used for CQL. +""" +from functools import partial +import numpy as np +import gym +import logging +import tree +from typing import Dict, List, Type, Union + +import ray +import ray.experimental.tf_utils +from ray.rllib.agents.sac.sac_tf_policy import ( + apply_gradients as sac_apply_gradients, + compute_and_clip_gradients as sac_compute_and_clip_gradients, + get_distribution_inputs_and_class, + _get_dist_class, + build_sac_model, + postprocess_trajectory, + setup_late_mixins, + stats, + validate_spaces, + ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, + ComputeTDErrorMixin, + TargetNetworkMixin, +) +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.exploration.random import Random +from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_tfp +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelGradients, + TensorType, + TrainerConfigDict, +) + +tf1, tf, tfv = try_import_tf() +tfp = try_import_tfp() + +logger = logging.getLogger(__name__) + +MEAN_MIN = -9.0 +MEAN_MAX = 9.0 + + +def _repeat_tensor(t: TensorType, n: int): + # Insert new axis at position 1 into tensor t + t_rep = tf.expand_dims(t, 1) + # Repeat tensor t_rep along new axis n times + multiples = tf.concat([[1, n], tf.tile([1], tf.expand_dims(tf.rank(t) - 1, 0))], 0) + t_rep = tf.tile(t_rep, multiples) + # Merge new axis into batch axis + t_rep = tf.reshape(t_rep, tf.concat([[-1], tf.shape(t)[1:]], 0)) + return t_rep + + +# Returns policy tiled actions and log probabilities for CQL Loss +def policy_actions_repeat(model, action_dist, obs, num_repeat=1): + batch_size = tf.shape(tree.flatten(obs)[0])[0] + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) + logits, _ = model.get_action_model_outputs(obs_temp) + policy_dist = action_dist(logits, model) + actions, logp_ = policy_dist.sample_logp() + logp = tf.expand_dims(logp_, -1) + return actions, tf.reshape(logp, [batch_size, num_repeat, 1]) + + +def q_values_repeat(model, obs, actions, twin=False): + action_shape = tf.shape(actions)[0] + obs_shape = tf.shape(tree.flatten(obs)[0])[0] + num_repeat = action_shape // obs_shape + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) + if not twin: + preds_, _ = model.get_q_values(obs_temp, actions) + else: + preds_, _ = model.get_twin_q_values(obs_temp, actions) + preds = tf.reshape(preds_, [obs_shape, num_repeat, 1]) + return preds + + +def cql_loss( + policy: Policy, + model: ModelV2, + dist_class: 
Type[TFActionDistribution], + train_batch: SampleBatch, +) -> Union[TensorType, List[TensorType]]: + logger.info(f"Current iteration = {policy.cur_iter}") + policy.cur_iter += 1 + + # For best performance, turn deterministic off + deterministic = policy.config["_deterministic_loss"] + assert not deterministic + twin_q = policy.config["twin_q"] + discount = policy.config["gamma"] + + # CQL Parameters + bc_iters = policy.config["bc_iters"] + cql_temp = policy.config["temperature"] + num_actions = policy.config["num_actions"] + min_q_weight = policy.config["min_q_weight"] + use_lagrange = policy.config["lagrangian"] + target_action_gap = policy.config["lagrangian_thresh"] + + obs = train_batch[SampleBatch.CUR_OBS] + actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32) + rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) + next_obs = train_batch[SampleBatch.NEXT_OBS] + terminals = train_batch[SampleBatch.DONES] + + model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None) + + model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None) + + target_model_out_tp1, _ = policy.target_model( + SampleBatch(obs=next_obs, _is_training=True), [], None + ) + + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + action_dist_t = action_dist_class(action_dist_inputs_t, model) + policy_t, log_pis_t = action_dist_t.sample_logp() + log_pis_t = tf.expand_dims(log_pis_t, -1) + + # Unlike original SAC, Alpha and Actor Loss are computed first. + # Alpha Loss + alpha_loss = -tf.reduce_mean( + model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy) + ) + + # Policy Loss (Either Behavior Clone Loss or SAC Loss) + alpha = tf.math.exp(model.log_alpha) + if policy.cur_iter >= bc_iters: + min_q, _ = model.get_q_values(model_out_t, policy_t) + if twin_q: + twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t) + min_q = tf.math.minimum(min_q, twin_q_) + actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - min_q) + else: + bc_logp = action_dist_t.logp(actions) + actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - bc_logp) + # actor_loss = -tf.reduce_mean(bc_logp) + + # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) + # SAC Loss: + # Q-values for the batched actions. + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model) + policy_tp1, _ = action_dist_tp1.sample_logp() + + q_t, _ = model.get_q_values(model_out_t, actions) + q_t_selected = tf.squeeze(q_t, axis=-1) + if twin_q: + twin_q_t, _ = model.get_twin_q_values(model_out_t, actions) + twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1) + + # Target q network evaluation. + q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) + if twin_q: + twin_q_tp1, _ = policy.target_model.get_twin_q_values( + target_model_out_tp1, policy_tp1 + ) + # Take min over both twin-NNs. 
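+        # (Clipped double-Q: using the elementwise minimum of the two target
+        # critics counteracts Q-value overestimation.)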
+ q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1) + + q_tp1_best = tf.squeeze(input=q_tp1, axis=-1) + q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best + + # compute RHS of bellman equation + q_t_target = tf.stop_gradient( + rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked + ) + + # Compute the TD-error (potentially clipped), for priority replay buffer + base_td_error = tf.math.abs(q_t_selected - q_t_target) + if twin_q: + twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target) + td_error = 0.5 * (base_td_error + twin_td_error) + else: + td_error = base_td_error + + critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target) + if twin_q: + critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target) + + # CQL Loss (We are using Entropy version of CQL (the best version)) + rand_actions, _ = policy._random_action_generator.get_exploration_action( + action_distribution=action_dist_class( + tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model + ), + timestep=0, + explore=True, + ) + curr_actions, curr_logp = policy_actions_repeat( + model, action_dist_class, model_out_t, num_actions + ) + next_actions, next_logp = policy_actions_repeat( + model, action_dist_class, model_out_tp1, num_actions + ) + + q1_rand = q_values_repeat(model, model_out_t, rand_actions) + q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions) + q1_next_actions = q_values_repeat(model, model_out_t, next_actions) + + if twin_q: + q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True) + q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True) + q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True) + + random_density = np.log(0.5 ** int(curr_actions.shape[-1])) + cat_q1 = tf.concat( + [ + q1_rand - random_density, + q1_next_actions - tf.stop_gradient(next_logp), + q1_curr_actions - tf.stop_gradient(curr_logp), + ], + 1, + ) + if twin_q: + cat_q2 = tf.concat( + [ + q2_rand - random_density, + q2_next_actions - tf.stop_gradient(next_logp), + q2_curr_actions - tf.stop_gradient(curr_logp), + ], + 1, + ) + + min_qf1_loss_ = ( + tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1)) + * min_q_weight + * cql_temp + ) + min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight) + if twin_q: + min_qf2_loss_ = ( + tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1)) + * min_q_weight + * cql_temp + ) + min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) * min_q_weight) + + if use_lagrange: + alpha_prime = tf.clip_by_value(model.log_alpha_prime.exp(), 0.0, 1000000.0)[0] + min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) + if twin_q: + min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) + alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) + else: + alpha_prime_loss = -min_qf1_loss + + cql_loss = [min_qf1_loss] + if twin_q: + cql_loss.append(min_qf2_loss) + + critic_loss = [critic_loss_1 + min_qf1_loss] + if twin_q: + critic_loss.append(critic_loss_2 + min_qf2_loss) + + # Save for stats function. 
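+    # (`cql_stats` and the base SAC `stats` function read these attributes
+    # off the policy when building the learner metrics.)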
+ policy.q_t = q_t_selected + policy.policy_t = policy_t + policy.log_pis_t = log_pis_t + policy.td_error = td_error + policy.actor_loss = actor_loss + policy.critic_loss = critic_loss + policy.alpha_loss = alpha_loss + policy.log_alpha_value = model.log_alpha + policy.alpha_value = alpha + policy.target_entropy = model.target_entropy + # CQL Stats + policy.cql_loss = cql_loss + if use_lagrange: + policy.log_alpha_prime_value = model.log_alpha_prime[0] + policy.alpha_prime_value = alpha_prime + policy.alpha_prime_loss = alpha_prime_loss + + # Return all loss terms corresponding to our optimizers. + if use_lagrange: + return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + alpha_prime_loss + return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + + +def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + sac_dict = stats(policy, train_batch) + sac_dict["cql_loss"] = tf.reduce_mean(tf.stack(policy.cql_loss)) + if policy.config["lagrangian"]: + sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value + sac_dict["alpha_prime_value"] = policy.alpha_prime_value + sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss + return sac_dict + + +class ActorCriticOptimizerMixin(SACActorCriticOptimizerMixin): + def __init__(self, config): + super().__init__(config) + if config["lagrangian"]: + # Eager mode. + if config["framework"] in ["tf2", "tfe"]: + self._alpha_prime_optimizer = tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + # Static graph mode. + else: + self._alpha_prime_optimizer = tf1.train.AdamOptimizer( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + + +def setup_early_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, +) -> None: + """Call mixin classes' constructors before Policy's initialization. + + Adds the necessary optimizers to the given Policy. + + Args: + policy (Policy): The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config (TrainerConfigDict): The Policy's config. + """ + policy.cur_iter = 0 + ActorCriticOptimizerMixin.__init__(policy, config) + if config["lagrangian"]: + policy.model.log_alpha_prime = get_variable( + 0.0, framework="tf", trainable=True, tf_name="log_alpha_prime" + ) + policy.alpha_prime_optim = tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["critic_learning_rate"], + ) + # Generic random action generator for calculating CQL-loss. + policy._random_action_generator = Random( + action_space, + model=None, + framework="tf2", + policy_config=config, + num_workers=0, + worker_index=0, + ) + + +def compute_gradients_fn( + policy: Policy, optimizer: LocalOptimizer, loss: TensorType +) -> ModelGradients: + grads_and_vars = sac_compute_and_clip_gradients(policy, optimizer, loss) + + if policy.config["lagrangian"]: + # Eager: Use GradientTape (which is a property of the `optimizer` + # object (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py). 
+ if policy.config["framework"] in ["tf2", "tfe"]: + tape = optimizer.tape + log_alpha_prime = [policy.model.log_alpha_prime] + alpha_prime_grads_and_vars = list( + zip( + tape.gradient(policy.alpha_prime_loss, log_alpha_prime), + log_alpha_prime, + ) + ) + # Tf1.x: Use optimizer.compute_gradients() + else: + alpha_prime_grads_and_vars = ( + policy._alpha_prime_optimizer.compute_gradients( + policy.alpha_prime_loss, var_list=[policy.model.log_alpha_prime] + ) + ) + + # Clip if necessary. + if policy.config["grad_clip"]: + clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"]) + else: + clip_func = tf.identity + + # Save grads and vars for later use in `build_apply_op`. + policy._alpha_prime_grads_and_vars = [ + (clip_func(g), v) for (g, v) in alpha_prime_grads_and_vars if g is not None + ] + + grads_and_vars += policy._alpha_prime_grads_and_vars + return grads_and_vars + + +def apply_gradients_fn(policy, optimizer, grads_and_vars): + sac_results = sac_apply_gradients(policy, optimizer, grads_and_vars) + + if policy.config["lagrangian"]: + # Eager mode -> Just apply and return None. + if policy.config["framework"] in ["tf2", "tfe"]: + policy._alpha_prime_optimizer.apply_gradients( + policy._alpha_prime_grads_and_vars + ) + return + # Tf static graph -> Return grouped op. + else: + alpha_prime_apply_op = policy._alpha_prime_optimizer.apply_gradients( + policy._alpha_prime_grads_and_vars, + global_step=tf1.train.get_or_create_global_step(), + ) + return tf.group([sac_results, alpha_prime_apply_op]) + return sac_results + + +# Build a child class of `TFPolicy`, given the custom functions defined +# above. +CQLTFPolicy = build_tf_policy( + name="CQLTFPolicy", + loss_fn=cql_loss, + get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG, + validate_spaces=validate_spaces, + stats_fn=cql_stats, + postprocess_fn=postprocess_trajectory, + before_init=setup_early_mixins, + after_init=setup_late_mixins, + make_model=build_sac_model, + mixins=[ActorCriticOptimizerMixin, TargetNetworkMixin, ComputeTDErrorMixin], + action_distribution_fn=get_distribution_inputs_and_class, + compute_gradients_fn=compute_gradients_fn, + apply_gradients_fn=apply_gradients_fn, +) diff --git a/rllib/agents/cql/cql_torch_policy.py b/rllib/agents/cql/cql_torch_policy.py new file mode 100644 index 000000000000..dcf5e58769dc --- /dev/null +++ b/rllib/agents/cql/cql_torch_policy.py @@ -0,0 +1,403 @@ +""" +PyTorch policy class used for CQL. 
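+
+It computes the standard SAC losses plus the conservative (logsumexp) CQL
+regularizer on the Q-function and, unlike most RLlib policies, applies the
+optimizer steps directly inside the loss function (see `cql_loss` below).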
+""" +import numpy as np +import gym +import logging +import tree +from typing import Dict, List, Tuple, Type, Union + +import ray +import ray.experimental.tf_utils +from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, validate_spaces +from ray.rllib.agents.sac.sac_torch_policy import ( + _get_dist_class, + stats, + build_sac_model_and_action_dist, + optimizer_fn, + ComputeTDErrorMixin, + TargetNetworkMixin, + setup_late_mixins, + action_distribution_fn, +) +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.typing import LocalOptimizer, TensorType, TrainerConfigDict +from ray.rllib.utils.torch_utils import ( + apply_grad_clipping, + convert_to_torch_tensor, + concat_multi_gpu_td_errors, +) + +torch, nn = try_import_torch() +F = nn.functional + +logger = logging.getLogger(__name__) + +MEAN_MIN = -9.0 +MEAN_MAX = 9.0 + + +def _repeat_tensor(t: TensorType, n: int): + # Insert new dimension at posotion 1 into tensor t + t_rep = t.unsqueeze(1) + # Repeat tensor t_rep along new dimension n times + t_rep = torch.repeat_interleave(t_rep, n, dim=1) + # Merge new dimension into batch dimension + t_rep = t_rep.view(-1, *t.shape[1:]) + return t_rep + + +# Returns policy tiled actions and log probabilities for CQL Loss +def policy_actions_repeat(model, action_dist, obs, num_repeat=1): + batch_size = tree.flatten(obs)[0].shape[0] + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) + logits, _ = model.get_action_model_outputs(obs_temp) + policy_dist = action_dist(logits, model) + actions, logp_ = policy_dist.sample_logp() + logp = logp_.unsqueeze(-1) + return actions, logp.view(batch_size, num_repeat, 1) + + +def q_values_repeat(model, obs, actions, twin=False): + action_shape = actions.shape[0] + obs_shape = tree.flatten(obs)[0].shape[0] + num_repeat = int(action_shape / obs_shape) + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) + if not twin: + preds_, _ = model.get_q_values(obs_temp, actions) + else: + preds_, _ = model.get_twin_q_values(obs_temp, actions) + preds = preds_.view(obs_shape, num_repeat, 1) + return preds + + +def cql_loss( + policy: Policy, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, +) -> Union[TensorType, List[TensorType]]: + logger.info(f"Current iteration = {policy.cur_iter}") + policy.cur_iter += 1 + + # Look up the target model (tower) using the model tower. 
+ target_model = policy.target_models[model] + + # For best performance, turn deterministic off + deterministic = policy.config["_deterministic_loss"] + assert not deterministic + twin_q = policy.config["twin_q"] + discount = policy.config["gamma"] + action_low = model.action_space.low[0] + action_high = model.action_space.high[0] + + # CQL Parameters + bc_iters = policy.config["bc_iters"] + cql_temp = policy.config["temperature"] + num_actions = policy.config["num_actions"] + min_q_weight = policy.config["min_q_weight"] + use_lagrange = policy.config["lagrangian"] + target_action_gap = policy.config["lagrangian_thresh"] + + obs = train_batch[SampleBatch.CUR_OBS] + actions = train_batch[SampleBatch.ACTIONS] + rewards = train_batch[SampleBatch.REWARDS].float() + next_obs = train_batch[SampleBatch.NEXT_OBS] + terminals = train_batch[SampleBatch.DONES] + + model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None) + + model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None) + + target_model_out_tp1, _ = target_model( + SampleBatch(obs=next_obs, _is_training=True), [], None + ) + + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + action_dist_t = action_dist_class(action_dist_inputs_t, model) + policy_t, log_pis_t = action_dist_t.sample_logp() + log_pis_t = torch.unsqueeze(log_pis_t, -1) + + # Unlike original SAC, Alpha and Actor Loss are computed first. + # Alpha Loss + alpha_loss = -(model.log_alpha * (log_pis_t + model.target_entropy).detach()).mean() + + batch_size = tree.flatten(obs)[0].shape[0] + if batch_size == policy.config["train_batch_size"]: + policy.alpha_optim.zero_grad() + alpha_loss.backward() + policy.alpha_optim.step() + + # Policy Loss (Either Behavior Clone Loss or SAC Loss) + alpha = torch.exp(model.log_alpha) + if policy.cur_iter >= bc_iters: + min_q, _ = model.get_q_values(model_out_t, policy_t) + if twin_q: + twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t) + min_q = torch.min(min_q, twin_q_) + actor_loss = (alpha.detach() * log_pis_t - min_q).mean() + else: + bc_logp = action_dist_t.logp(actions) + actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean() + # actor_loss = -bc_logp.mean() + + if batch_size == policy.config["train_batch_size"]: + policy.actor_optim.zero_grad() + actor_loss.backward(retain_graph=True) + policy.actor_optim.step() + + # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) + # SAC Loss: + # Q-values for the batched actions. + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model) + policy_tp1, _ = action_dist_tp1.sample_logp() + + q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) + q_t_selected = torch.squeeze(q_t, dim=-1) + if twin_q: + twin_q_t, _ = model.get_twin_q_values( + model_out_t, train_batch[SampleBatch.ACTIONS] + ) + twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1) + + # Target q network evaluation. + q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1) + if twin_q: + twin_q_tp1, _ = target_model.get_twin_q_values(target_model_out_tp1, policy_tp1) + # Take min over both twin-NNs. 
+        q_tp1 = torch.min(q_tp1, twin_q_tp1)
+
+    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
+    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best
+
+    # Compute RHS of the Bellman equation.
+    q_t_target = (
+        rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked
+    ).detach()
+
+    # Compute the TD-error (potentially clipped) for the prioritized replay buffer.
+    base_td_error = torch.abs(q_t_selected - q_t_target)
+    if twin_q:
+        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
+        td_error = 0.5 * (base_td_error + twin_td_error)
+    else:
+        td_error = base_td_error
+
+    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
+    if twin_q:
+        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)
+
+    # CQL loss (entropy-regularized variant of CQL): penalize an
+    # importance-sampled estimate of logsumexp_a Q(s, a) and reward the
+    # Q-values of the dataset actions, both scaled by `min_q_weight`.
+    rand_actions = convert_to_torch_tensor(
+        torch.FloatTensor(actions.shape[0] * num_actions, actions.shape[-1]).uniform_(
+            action_low, action_high
+        ),
+        policy.device,
+    )
+    curr_actions, curr_logp = policy_actions_repeat(
+        model, action_dist_class, model_out_t, num_actions
+    )
+    next_actions, next_logp = policy_actions_repeat(
+        model, action_dist_class, model_out_tp1, num_actions
+    )
+
+    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
+    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
+    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)
+
+    if twin_q:
+        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
+        q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True)
+        q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True)
+
+    random_density = np.log(0.5 ** curr_actions.shape[-1])
+    cat_q1 = torch.cat(
+        [
+            q1_rand - random_density,
+            q1_next_actions - next_logp.detach(),
+            q1_curr_actions - curr_logp.detach(),
+        ],
+        1,
+    )
+    if twin_q:
+        cat_q2 = torch.cat(
+            [
+                q2_rand - random_density,
+                q2_next_actions - next_logp.detach(),
+                q2_curr_actions - curr_logp.detach(),
+            ],
+            1,
+        )
+
+    min_qf1_loss_ = (
+        torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
+    )
+    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
+    if twin_q:
+        min_qf2_loss_ = (
+            torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
+        )
+        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)
+
+    if use_lagrange:
+        alpha_prime = torch.clamp(model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[
+            0
+        ]
+        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
+        if twin_q:
+            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
+            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
+        else:
+            alpha_prime_loss = -min_qf1_loss
+
+    cql_loss = [min_qf1_loss]
+    if twin_q:
+        cql_loss.append(min_qf2_loss)
+
+    critic_loss = [critic_loss_1 + min_qf1_loss]
+    if twin_q:
+        critic_loss.append(critic_loss_2 + min_qf2_loss)
+
+    if batch_size == policy.config["train_batch_size"]:
+        policy.critic_optims[0].zero_grad()
+        critic_loss[0].backward(retain_graph=True)
+        policy.critic_optims[0].step()
+
+        if twin_q:
+            policy.critic_optims[1].zero_grad()
+            critic_loss[1].backward(retain_graph=False)
+            policy.critic_optims[1].step()
+
+    # Store values for stats function in model (tower), such that for
+    # multi-GPU, we do not override them during the parallel loss phase.
+    # SAC stats.
+ model.tower_stats["q_t"] = q_t_selected + model.tower_stats["policy_t"] = policy_t + model.tower_stats["log_pis_t"] = log_pis_t + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss + model.tower_stats["log_alpha_value"] = model.log_alpha + model.tower_stats["alpha_value"] = alpha + model.tower_stats["target_entropy"] = model.target_entropy + # CQL stats. + model.tower_stats["cql_loss"] = cql_loss + + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["td_error"] = td_error + + if use_lagrange: + model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0] + model.tower_stats["alpha_prime_value"] = alpha_prime + model.tower_stats["alpha_prime_loss"] = alpha_prime_loss + + if batch_size == policy.config["train_batch_size"]: + policy.alpha_prime_optim.zero_grad() + alpha_prime_loss.backward() + policy.alpha_prime_optim.step() + + # Return all loss terms corresponding to our optimizers. + return tuple( + [actor_loss] + + critic_loss + + [alpha_loss] + + ([alpha_prime_loss] if use_lagrange else []) + ) + + +def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + # Get SAC loss stats. + stats_dict = stats(policy, train_batch) + + # Add CQL loss stats to the dict. + stats_dict["cql_loss"] = torch.mean( + torch.stack(*policy.get_tower_stats("cql_loss")) + ) + + if policy.config["lagrangian"]: + stats_dict["log_alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("log_alpha_prime_value")) + ) + stats_dict["alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_value")) + ) + stats_dict["alpha_prime_loss"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_loss")) + ) + return stats_dict + + +def cql_optimizer_fn( + policy: Policy, config: TrainerConfigDict +) -> Tuple[LocalOptimizer]: + policy.cur_iter = 0 + opt_list = optimizer_fn(policy, config) + if config["lagrangian"]: + log_alpha_prime = nn.Parameter(torch.zeros(1, requires_grad=True).float()) + policy.model.register_parameter("log_alpha_prime", log_alpha_prime) + policy.alpha_prime_optim = torch.optim.Adam( + params=[policy.model.log_alpha_prime], + lr=config["optimization"]["critic_learning_rate"], + eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default + ) + return tuple( + [policy.actor_optim] + + policy.critic_optims + + [policy.alpha_optim] + + [policy.alpha_prime_optim] + ) + return opt_list + + +def cql_setup_late_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, +) -> None: + setup_late_mixins(policy, obs_space, action_space, config) + if config["lagrangian"]: + policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(policy.device) + + +def compute_gradients_fn(policy, postprocessed_batch): + batches = [policy._lazy_tensor_dict(postprocessed_batch)] + model = policy.model + policy._loss(policy, model, policy.dist_class, batches[0]) + stats = {LEARNER_STATS_KEY: policy._convert_to_numpy(cql_stats(policy, batches[0]))} + return [None, stats] + + +def apply_gradients_fn(policy, gradients): + return + + +# Build a child class of `TorchPolicy`, given the custom functions defined +# above. 
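+# Note: `cql_loss` above already applies the optimizer steps itself, so the
+# `compute_gradients_fn` / `apply_gradients_fn` overrides passed in here only
+# run the loss for its stats; they neither produce nor apply any gradients.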
+CQLTorchPolicy = build_policy_class( + name="CQLTorchPolicy", + framework="torch", + loss_fn=cql_loss, + get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG, + stats_fn=cql_stats, + postprocess_fn=postprocess_trajectory, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=cql_optimizer_fn, + validate_spaces=validate_spaces, + before_loss_init=cql_setup_late_mixins, + make_model_and_action_dist=build_sac_model_and_action_dist, + extra_learn_fetches_fn=concat_multi_gpu_td_errors, + mixins=[TargetNetworkMixin, ComputeTDErrorMixin], + action_distribution_fn=action_distribution_fn, + compute_gradients_fn=compute_gradients_fn, + apply_gradients_fn=apply_gradients_fn, +) diff --git a/rllib/agents/cql/tests/test_cql.py b/rllib/agents/cql/tests/test_cql.py new file mode 100644 index 000000000000..92890fdaebe6 --- /dev/null +++ b/rllib/agents/cql/tests/test_cql.py @@ -0,0 +1,143 @@ +import numpy as np +from pathlib import Path +import os +import unittest + +import ray +import ray.rllib.agents.cql as cql +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.test_utils import ( + check_compute_single_action, + check_train_results, + framework_iterator, +) + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +class TestCQL(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init() + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def test_cql_compilation(self): + """Test whether a CQLTrainer can be built with all frameworks.""" + + # Learns from a historic-data file. + # To generate this data, first run: + # $ ./train.py --run=SAC --env=Pendulum-v1 \ + # --stop='{"timesteps_total": 50000}' \ + # --config='{"output": "/tmp/out"}' + rllib_dir = Path(__file__).parent.parent.parent.parent + print("rllib dir={}".format(rllib_dir)) + data_file = os.path.join(rllib_dir, "tests/data/pendulum/small.json") + print("data_file={} exists={}".format(data_file, os.path.isfile(data_file))) + + config = { + "env": "Pendulum-v1", + "input": [data_file], + # In the files, we use here for testing, actions have already + # been normalized. + # This is usually the case when the file was generated by another + # RLlib algorithm (e.g. PPO or SAC). + "actions_in_input_normalized": False, + "clip_actions": True, + "train_batch_size": 2000, + "twin_q": True, + "learning_starts": 0, + "bc_iters": 2, # 2 BC iters, 2 CQL iters. + "rollout_fragment_length": 1, + # Switch on off-policy evaluation. + "input_evaluation": ["is"], + "always_attach_evaluation_results": True, + "evaluation_interval": 2, + "evaluation_duration": 10, + "evaluation_config": { + "input": "sampler", + }, + "evaluation_parallel_to_training": False, + "evaluation_num_workers": 2, + } + num_iterations = 4 + + # Test for tf/torch frameworks. + for fw in framework_iterator(config, with_eager_tracing=True): + trainer = cql.CQLTrainer(config=config) + for i in range(num_iterations): + results = trainer.train() + check_train_results(results) + print(results) + eval_results = results["evaluation"] + print( + f"iter={trainer.iteration} " + f"R={eval_results['episode_reward_mean']}" + ) + + check_compute_single_action(trainer) + + # Get policy and model. + pol = trainer.get_policy() + cql_model = pol.model + if fw == "tf": + pol.get_session().__enter__() + + # Example on how to do evaluation on the trained Trainer + # using the data from CQL's global replay buffer. + # Get a sample (MultiAgentBatch -> SampleBatch). 
+            batch = trainer.local_replay_buffer.replay().policy_batches[
+                "default_policy"
+            ]
+
+            if fw == "torch":
+                obs = torch.from_numpy(batch["obs"])
+            else:
+                obs = batch["obs"]
+                batch["actions"] = batch["actions"].astype(np.float32)
+
+            # Pass the observations through our model to get the
+            # features, which we then pass through the Q-head.
+            model_out, _ = cql_model({"obs": obs})
+            # The estimated Q-values from the (historic) actions in the batch.
+            if fw == "torch":
+                q_values_old = cql_model.get_q_values(
+                    model_out, torch.from_numpy(batch["actions"])
+                )
+            else:
+                q_values_old = cql_model.get_q_values(
+                    tf.convert_to_tensor(model_out), batch["actions"]
+                )
+
+            # The estimated Q-values for the new actions computed
+            # by our trainer policy.
+            actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
+            if fw == "torch":
+                q_values_new = cql_model.get_q_values(
+                    model_out, torch.from_numpy(actions_new)
+                )
+            else:
+                q_values_new = cql_model.get_q_values(model_out, actions_new)
+
+            if fw == "tf":
+                q_values_old, q_values_new = pol.get_session().run(
+                    [q_values_old, q_values_new]
+                )
+
+            print(f"Q-val batch={q_values_old}")
+            print(f"Q-val policy={q_values_new}")
+
+            if fw == "tf":
+                pol.get_session().__exit__(None, None, None)
+
+            trainer.stop()
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/agents/cql/tests/test_cql_sac.py b/rllib/agents/cql/tests/test_cql_sac.py
deleted file mode 100644
index ca74b6c86945..000000000000
--- a/rllib/agents/cql/tests/test_cql_sac.py
+++ /dev/null
@@ -1,622 +0,0 @@
-from gym import Env
-from gym.spaces import Box, Discrete, Tuple
-import numpy as np
-import re
-import unittest
-
-import ray
-import ray.rllib.agents.sac as sac
-from ray.rllib.agents.cql import CQLSACTrainer, CQLSAC_DEFAULT_CONFIG
-from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
-from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
-    loss_torch
-from ray.rllib.env.wrappers.moab_wrapper import MOAB_MOVE_TO_CENTER_ENV_NAME
-from ray.rllib.examples.env.random_env import RandomEnv
-from ray.rllib.examples.models.batch_norm_model import KerasBatchNormModel, \
-    TorchBatchNormModel
-from ray.rllib.models.catalog import ModelCatalog
-from ray.rllib.models.tf.tf_action_dist import Dirichlet
-from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
-from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.numpy import fc, huber_loss, relu
-from ray.rllib.utils.spaces.simplex import Simplex
-from ray.rllib.utils.test_utils import check, check_compute_single_action, \
-    framework_iterator
-from ray.rllib.utils.torch_ops import convert_to_torch_tensor
-
-tf1, tf, tfv = try_import_tf()
-torch, _ = try_import_torch()
-
-
-class SimpleEnv(Env):
-    def __init__(self, config):
-        if config.get("simplex_actions", False):
-            self.action_space = Simplex((2, ))
-        else:
-            self.action_space = Box(0.0, 1.0, (1, ))
-        self.observation_space = Box(0.0, 1.0, (1, ))
-        self.max_steps = config.get("max_steps", 100)
-        self.state = None
-        self.steps = None
-
-    def reset(self):
-        self.state = self.observation_space.sample()
-        self.steps = 0
-        return self.state
-
-    def step(self, action):
-        self.steps += 1
-        # Reward is 1.0 - (max(actions) - state).
- [r] = 1.0 - np.abs(np.max(action) - self.state) - d = self.steps >= self.max_steps - self.state = self.observation_space.sample() - return self.state, r, d, {} - - -class TestCQLSAC(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - ray.init(local_mode=True) - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - - def test_cqlsac_compilation(self): - """Tests whether an SACTrainer can be built with all frameworks.""" - config = CQLSAC_DEFAULT_CONFIG.copy() - config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy() - config["num_workers"] = 0 # Run locally. - config["twin_q"] = True - config["clip_actions"] = False - config["normalize_actions"] = True - config["learning_starts"] = 0 - config["prioritized_replay"] = False - config["train_batch_size"] = 256 #10 - config["input"] = "rllib/tests/data/moab/*.json" - config["input_evaluation"] = [] - config["bc_iters"] = 5 - config["temperature"] = 1.0 - config["num_actions"] = 10 - config["lagrangian"] = True # False - # Lagrangian Threshold - config["lagrangian_thresh"] = 5.0 - config["min_q_weight"] = 5.0 - # Initial value to use for the Alpha Prime (in CQL Loss). - config["initial_alpha_prime"] = 1.0 - config["evaluation_config"] = { - "input": "sampler", - "explore": False, - } - config["evaluation_interval"] = 1 - config["evaluation_num_episodes"] = 10 - config["evaluation_num_workers"] = 1 - - num_iterations = 1 - - ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel) - ModelCatalog.register_custom_model("batch_norm_torch", - TorchBatchNormModel) - - image_space = Box(-1.0, 1.0, shape=(84, 84, 3)) - simple_space = Box(-1.0, 1.0, shape=(3, )) - - # frameworks=("tf2", "tf", "tfe", "torch") - for fw in framework_iterator(config=config, frameworks=("tf2", "tf", "tfe")): - # Test for different env types (discrete w/ and w/o image, + cont). - for env in [ - # RandomEnv, - # "MsPacmanNoFrameskip-v4", - # "CartPole-v0", - MOAB_MOVE_TO_CENTER_ENV_NAME, - ]: - print("Env={}".format(env)) - if env == RandomEnv: - config["env_config"] = { - "observation_space": Tuple( - [simple_space, - Discrete(2), image_space]), - "action_space": Box(-1.0, 1.0, shape=(1, )), - } - else: - config["env_config"] = {} - # Test making the Q-model a custom one for CartPole, otherwise, - # use the default model. - config["Q_model"]["custom_model"] = "batch_norm{}".format( - "_torch" - if fw == "torch" else "") if env == "CartPole-v0" else None - trainer = CQLSACTrainer(config=config, env=env) - for i in range(num_iterations): - results = trainer.train() - print(results) - check_compute_single_action(trainer) - trainer.stop() - - @unittest.skip("TODO(Edi): Adapt...") - def test_cqlsac_loss_function(self): - self.skipTest("TODO(Edi): Adapt...") - """Tests SAC loss function results across all frameworks.""" - config = sac.DEFAULT_CONFIG.copy() - # Run locally. - config["num_workers"] = 0 - config["learning_starts"] = 0 - config["twin_q"] = False - config["gamma"] = 0.99 - # Switch on deterministic loss so we can compare the loss values. - config["_deterministic_loss"] = True - # Use very simple nets. - config["Q_model"]["fcnet_hiddens"] = [10] - config["policy_model"]["fcnet_hiddens"] = [10] - # Make sure, timing differences do not affect trainer.train(). - config["min_iter_time_s"] = 0 - # Test SAC with Simplex action space. - config["env_config"] = {"simplex_actions": True} - - map_ = { - # Action net. - "default_policy/fc_1/kernel": "action_model._hidden_layers.0." 
- "_model.0.weight", - "default_policy/fc_1/bias": "action_model._hidden_layers.0." - "_model.0.bias", - "default_policy/fc_out/kernel": "action_model." - "_logits._model.0.weight", - "default_policy/fc_out/bias": "action_model._logits._model.0.bias", - "default_policy/value_out/kernel": "action_model." - "_value_branch._model.0.weight", - "default_policy/value_out/bias": "action_model." - "_value_branch._model.0.bias", - # Q-net. - "default_policy/fc_1_1/kernel": "q_net." - "_hidden_layers.0._model.0.weight", - "default_policy/fc_1_1/bias": "q_net." - "_hidden_layers.0._model.0.bias", - "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight", - "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias", - "default_policy/value_out_1/kernel": "q_net." - "_value_branch._model.0.weight", - "default_policy/value_out_1/bias": "q_net." - "_value_branch._model.0.bias", - "default_policy/log_alpha": "log_alpha", - # Target action-net. - "default_policy/fc_1_2/kernel": "action_model." - "_hidden_layers.0._model.0.weight", - "default_policy/fc_1_2/bias": "action_model." - "_hidden_layers.0._model.0.bias", - "default_policy/fc_out_2/kernel": "action_model." - "_logits._model.0.weight", - "default_policy/fc_out_2/bias": "action_model." - "_logits._model.0.bias", - "default_policy/value_out_2/kernel": "action_model." - "_value_branch._model.0.weight", - "default_policy/value_out_2/bias": "action_model." - "_value_branch._model.0.bias", - # Target Q-net - "default_policy/fc_1_3/kernel": "q_net." - "_hidden_layers.0._model.0.weight", - "default_policy/fc_1_3/bias": "q_net." - "_hidden_layers.0._model.0.bias", - "default_policy/fc_out_3/kernel": "q_net." - "_logits._model.0.weight", - "default_policy/fc_out_3/bias": "q_net." - "_logits._model.0.bias", - "default_policy/value_out_3/kernel": "q_net." - "_value_branch._model.0.weight", - "default_policy/value_out_3/bias": "q_net." - "_value_branch._model.0.bias", - "default_policy/log_alpha_1": "log_alpha", - } - - env = SimpleEnv - batch_size = 100 - if env is SimpleEnv: - obs_size = (batch_size, 1) - actions = np.random.random(size=(batch_size, 2)) - elif env == "CartPole-v0": - obs_size = (batch_size, 4) - actions = np.random.randint(0, 2, size=(batch_size, )) - else: - obs_size = (batch_size, 3) - actions = np.random.random(size=(batch_size, 1)) - - # Batch of size=n. - input_ = self._get_batch_helper(obs_size, actions, batch_size) - - # Simply compare loss values AND grads of all frameworks with each - # other. - prev_fw_loss = weights_dict = None - expect_c, expect_a, expect_e, expect_t = None, None, None, None - # History of tf-updated NN-weights over n training steps. - tf_updated_weights = [] - # History of input batches used. - tf_inputs = [] - for fw, sess in framework_iterator( - config, frameworks=("tf", "torch"), session=True): - # Generate Trainer and get its default Policy object. - trainer = sac.SACTrainer(config=config, env=env) - policy = trainer.get_policy() - p_sess = None - if sess: - p_sess = policy.get_session() - - # Set all weights (of all nets) to fixed values. - if weights_dict is None: - # Start with the tf vars-dict. - assert fw in ["tf2", "tf", "tfe"] - weights_dict = policy.get_weights() - if fw == "tfe": - log_alpha = weights_dict[10] - weights_dict = self._translate_tfe_weights( - weights_dict, map_) - else: - assert fw == "torch" # Then transfer that to torch Model. 
- model_dict = self._translate_weights_to_torch( - weights_dict, map_) - policy.model.load_state_dict(model_dict) - policy.target_model.load_state_dict(model_dict) - - if fw == "tf": - log_alpha = weights_dict["default_policy/log_alpha"] - elif fw == "torch": - # Actually convert to torch tensors (by accessing everything). - input_ = policy._lazy_tensor_dict(input_) - input_ = {k: input_[k] for k in input_.keys()} - log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0] - - # Only run the expectation once, should be the same anyways - # for all frameworks. - if expect_c is None: - expect_c, expect_a, expect_e, expect_t = \ - self._sac_loss_helper(input_, weights_dict, - sorted(weights_dict.keys()), - log_alpha, fw, - gamma=config["gamma"], sess=sess) - - # Get actual outs and compare to expectation AND previous - # framework. c=critic, a=actor, e=entropy, t=td-error. - if fw == "tf": - c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \ - p_sess.run([ - policy.critic_loss, - policy.actor_loss, - policy.alpha_loss, - policy.td_error, - policy.optimizer().compute_gradients( - policy.critic_loss[0], - [v for v in policy.model.q_variables() if - "value_" not in v.name]), - policy.optimizer().compute_gradients( - policy.actor_loss, - [v for v in policy.model.policy_variables() if - "value_" not in v.name]), - policy.optimizer().compute_gradients( - policy.alpha_loss, policy.model.log_alpha)], - feed_dict=policy._get_loss_inputs_dict( - input_, shuffle=False)) - tf_c_grads = [g for g, v in tf_c_grads] - tf_a_grads = [g for g, v in tf_a_grads] - tf_e_grads = [g for g, v in tf_e_grads] - - elif fw == "tfe": - with tf.GradientTape() as tape: - tf_loss(policy, policy.model, None, input_) - c, a, e, t = policy.critic_loss, policy.actor_loss, \ - policy.alpha_loss, policy.td_error - vars = tape.watched_variables() - tf_c_grads = tape.gradient(c[0], vars[6:10]) - tf_a_grads = tape.gradient(a, vars[2:6]) - tf_e_grads = tape.gradient(e, vars[10]) - - elif fw == "torch": - loss_torch(policy, policy.model, None, input_) - c, a, e, t = policy.critic_loss, policy.actor_loss, \ - policy.alpha_loss, policy.td_error - - # Test actor gradients. - policy.actor_optim.zero_grad() - assert all(v.grad is None for v in policy.model.q_variables()) - assert all( - v.grad is None for v in policy.model.policy_variables()) - assert policy.model.log_alpha.grad is None - a.backward() - # `actor_loss` depends on Q-net vars (but these grads must - # be ignored and overridden in critic_loss.backward!). - assert not all( - torch.mean(v.grad) == 0 - for v in policy.model.policy_variables()) - assert not all( - torch.min(v.grad) == 0 - for v in policy.model.policy_variables()) - assert policy.model.log_alpha.grad is None - # Compare with tf ones. - torch_a_grads = [ - v.grad for v in policy.model.policy_variables() - if v.grad is not None - ] - check(tf_a_grads[2], - np.transpose(torch_a_grads[0].detach().cpu())) - - # Test critic gradients. - policy.critic_optims[0].zero_grad() - assert all( - torch.mean(v.grad) == 0.0 - for v in policy.model.q_variables() if v.grad is not None) - assert all( - torch.min(v.grad) == 0.0 - for v in policy.model.q_variables() if v.grad is not None) - assert policy.model.log_alpha.grad is None - c[0].backward() - assert not all( - torch.mean(v.grad) == 0 - for v in policy.model.q_variables() if v.grad is not None) - assert not all( - torch.min(v.grad) == 0 for v in policy.model.q_variables() - if v.grad is not None) - assert policy.model.log_alpha.grad is None - # Compare with tf ones. 
- torch_c_grads = [v.grad for v in policy.model.q_variables()] - check(tf_c_grads[0], - np.transpose(torch_c_grads[2].detach().cpu())) - # Compare (unchanged(!) actor grads) with tf ones. - torch_a_grads = [ - v.grad for v in policy.model.policy_variables() - ] - check(tf_a_grads[2], - np.transpose(torch_a_grads[0].detach().cpu())) - - # Test alpha gradient. - policy.alpha_optim.zero_grad() - assert policy.model.log_alpha.grad is None - e.backward() - assert policy.model.log_alpha.grad is not None - check(policy.model.log_alpha.grad, tf_e_grads) - - check(c, expect_c) - check(a, expect_a) - check(e, expect_e) - check(t, expect_t) - - # Store this framework's losses in prev_fw_loss to compare with - # next framework's outputs. - if prev_fw_loss is not None: - check(c, prev_fw_loss[0]) - check(a, prev_fw_loss[1]) - check(e, prev_fw_loss[2]) - check(t, prev_fw_loss[3]) - - prev_fw_loss = (c, a, e, t) - - # Update weights from our batch (n times). - for update_iteration in range(5): - print("train iteration {}".format(update_iteration)) - if fw == "tf": - in_ = self._get_batch_helper(obs_size, actions, batch_size) - tf_inputs.append(in_) - # Set a fake-batch to use - # (instead of sampling from replay buffer). - buf = LocalReplayBuffer.get_instance_for_testing() - buf._fake_batch = in_ - trainer.train() - updated_weights = policy.get_weights() - # Net must have changed. - if tf_updated_weights: - check( - updated_weights["default_policy/fc_1/kernel"], - tf_updated_weights[-1][ - "default_policy/fc_1/kernel"], - false=True) - tf_updated_weights.append(updated_weights) - - # Compare with updated tf-weights. Must all be the same. - else: - tf_weights = tf_updated_weights[update_iteration] - in_ = tf_inputs[update_iteration] - # Set a fake-batch to use - # (instead of sampling from replay buffer). - buf = LocalReplayBuffer.get_instance_for_testing() - buf._fake_batch = in_ - trainer.train() - # Compare updated model. - for tf_key in sorted(tf_weights.keys()): - if re.search("_[23]|alpha", tf_key): - continue - tf_var = tf_weights[tf_key] - torch_var = policy.model.state_dict()[map_[tf_key]] - if tf_var.shape != torch_var.shape: - check( - tf_var, - np.transpose(torch_var.detach().cpu()), - rtol=0.05) - else: - check(tf_var, torch_var, rtol=0.05) - # And alpha. - check(policy.model.log_alpha, - tf_weights["default_policy/log_alpha"]) - # Compare target nets. 
- for tf_key in sorted(tf_weights.keys()): - if not re.search("_[23]", tf_key): - continue - tf_var = tf_weights[tf_key] - torch_var = policy.target_model.state_dict()[map_[ - tf_key]] - if tf_var.shape != torch_var.shape: - check( - tf_var, - np.transpose(torch_var.detach().cpu()), - rtol=0.05) - else: - check(tf_var, torch_var, rtol=0.05) - - def _get_batch_helper(self, obs_size, actions, batch_size): - return { - SampleBatch.CUR_OBS: np.random.random(size=obs_size), - SampleBatch.ACTIONS: actions, - SampleBatch.REWARDS: np.random.random(size=(batch_size, )), - SampleBatch.DONES: np.random.choice( - [True, False], size=(batch_size, )), - SampleBatch.NEXT_OBS: np.random.random(size=obs_size), - "weights": np.random.random(size=(batch_size, )), - } - - def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma, - sess): - """Emulates SAC loss functions for tf and torch.""" - # ks: - # 0=log_alpha - # 1=target log-alpha (not used) - - # 2=action hidden bias - # 3=action hidden kernel - # 4=action out bias - # 5=action out kernel - - # 6=Q hidden bias - # 7=Q hidden kernel - # 8=Q out bias - # 9=Q out kernel - - # 14=target Q hidden bias - # 15=target Q hidden kernel - # 16=target Q out bias - # 17=target Q out kernel - alpha = np.exp(log_alpha) - # cls = TorchSquashedGaussian if fw == "torch" else SquashedGaussian - cls = TorchDirichlet if fw == "torch" else Dirichlet - model_out_t = train_batch[SampleBatch.CUR_OBS] - model_out_tp1 = train_batch[SampleBatch.NEXT_OBS] - target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS] - - # get_policy_output - action_dist_t = cls( - fc( - relu( - fc(model_out_t, - weights[ks[1]], - weights[ks[0]], - framework=fw)), weights[ks[9]], weights[ks[8]]), None) - policy_t = action_dist_t.deterministic_sample() - log_pis_t = action_dist_t.logp(policy_t) - if sess: - log_pis_t = sess.run(log_pis_t) - policy_t = sess.run(policy_t) - log_pis_t = np.expand_dims(log_pis_t, -1) - - # Get policy output for t+1. - action_dist_tp1 = cls( - fc( - relu( - fc(model_out_tp1, - weights[ks[1]], - weights[ks[0]], - framework=fw)), weights[ks[9]], weights[ks[8]]), None) - policy_tp1 = action_dist_tp1.deterministic_sample() - log_pis_tp1 = action_dist_tp1.logp(policy_tp1) - if sess: - log_pis_tp1 = sess.run(log_pis_tp1) - policy_tp1 = sess.run(policy_tp1) - log_pis_tp1 = np.expand_dims(log_pis_tp1, -1) - - # Q-values for the actually selected actions. - # get_q_values - q_t = fc( - relu( - fc(np.concatenate( - [model_out_t, train_batch[SampleBatch.ACTIONS]], -1), - weights[ks[3]], - weights[ks[2]], - framework=fw)), - weights[ks[11]], - weights[ks[10]], - framework=fw) - - # Q-values for current policy in given current state. - # get_q_values - q_t_det_policy = fc( - relu( - fc(np.concatenate([model_out_t, policy_t], -1), - weights[ks[3]], - weights[ks[2]], - framework=fw)), - weights[ks[11]], - weights[ks[10]], - framework=fw) - - # Target q network evaluation. 
- # target_model.get_q_values - if fw == "tf": - q_tp1 = fc( - relu( - fc(np.concatenate([target_model_out_tp1, policy_tp1], -1), - weights[ks[7]], - weights[ks[6]], - framework=fw)), - weights[ks[15]], - weights[ks[14]], - framework=fw) - else: - assert fw == "tfe" - q_tp1 = fc( - relu( - fc(np.concatenate([target_model_out_tp1, policy_tp1], -1), - weights[ks[7]], - weights[ks[6]], - framework=fw)), - weights[ks[9]], - weights[ks[8]], - framework=fw) - - q_t_selected = np.squeeze(q_t, axis=-1) - q_tp1 -= alpha * log_pis_tp1 - q_tp1_best = np.squeeze(q_tp1, axis=-1) - dones = train_batch[SampleBatch.DONES] - rewards = train_batch[SampleBatch.REWARDS] - if fw == "torch": - dones = dones.float().numpy() - rewards = rewards.numpy() - q_tp1_best_masked = (1.0 - dones) * q_tp1_best - q_t_selected_target = rewards + gamma * q_tp1_best_masked - base_td_error = np.abs(q_t_selected - q_t_selected_target) - td_error = base_td_error - critic_loss = [ - np.mean(train_batch["weights"] * - huber_loss(q_t_selected_target - q_t_selected)) - ] - target_entropy = -np.prod((1, )) - alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy)) - actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy) - - return critic_loss, actor_loss, alpha_loss, td_error - - def _translate_weights_to_torch(self, weights_dict, map_): - model_dict = { - map_[k]: convert_to_torch_tensor( - np.transpose(v) if re.search("kernel", k) else np.array([v]) - if re.search("log_alpha", k) else v) - for i, (k, v) in enumerate(weights_dict.items()) if i < 13 - } - - return model_dict - - def _translate_tfe_weights(self, weights_dict, map_): - model_dict = { - "default_policy/log_alpha": None, - "default_policy/log_alpha_target": None, - "default_policy/sequential/action_1/kernel": weights_dict[2], - "default_policy/sequential/action_1/bias": weights_dict[3], - "default_policy/sequential/action_out/kernel": weights_dict[4], - "default_policy/sequential/action_out/bias": weights_dict[5], - "default_policy/sequential_1/q_hidden_0/kernel": weights_dict[6], - "default_policy/sequential_1/q_hidden_0/bias": weights_dict[7], - "default_policy/sequential_1/q_out/kernel": weights_dict[8], - "default_policy/sequential_1/q_out/bias": weights_dict[9], - "default_policy/value_out/kernel": weights_dict[0], - "default_policy/value_out/bias": weights_dict[1], - } - return model_dict - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/registry.py b/rllib/agents/registry.py index a141783c5b2a..78b7834b995a 100644 --- a/rllib/agents/registry.py +++ b/rllib/agents/registry.py @@ -41,11 +41,6 @@ def _import_bc(): return marwil.BCTrainer, marwil.DEFAULT_CONFIG -def _import_cql_sac(): - from ray.rllib.agents import cql - return cql.CQLSACTrainer, cql.CQLSAC_DEFAULT_CONFIG - - def _import_ddpg(): from ray.rllib.agents import ddpg return ddpg.DDPGTrainer, ddpg.DEFAULT_CONFIG @@ -121,6 +116,11 @@ def _import_apex_sac(): return sac.ApexSACTrainer, sac.APEX_SAC_DEFAULT_CONFIG +def _import_cql(): + from ray.rllib.agents import cql + return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG + + def _import_simple_q(): from ray.rllib.agents import dqn return dqn.SimpleQTrainer, dqn.simple_q.DEFAULT_CONFIG @@ -144,7 +144,6 @@ def _import_td3(): "APPO": _import_appo, "ARS": _import_ars, "BC": _import_bc, - "CQL_SAC": _import_cql_sac, "ES": _import_es, "DDPG": _import_ddpg, "DDPPO": _import_ddppo, @@ -161,6 +160,7 @@ def _import_td3(): "R2D2": _import_r2d2, "SAC": _import_sac, "APEX_SAC": 
_import_apex_sac, + "CQL": _import_cql, "SimpleQ": _import_simple_q, "TD3": _import_td3, }
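
For context, a minimal usage sketch of the renamed trainer (not part of the diff above; the offline-data path, iteration count, and config values are placeholders, and any RLlib-generated JSON experience file such as the Pendulum data used in the test would work):

from ray.rllib.agents import cql

config = cql.CQL_DEFAULT_CONFIG.copy()
config["env"] = "Pendulum-v1"
# CQL trains purely from offline data: point `input` at previously
# recorded JSON experiences (placeholder path below).
config["input"] = ["/tmp/out/output-0.json"]
# Number of behavior-cloning warm-up iterations before switching to the
# regular (CQL-regularized) actor loss.
config["bc_iters"] = 2000
config["twin_q"] = True

trainer = cql.CQLTrainer(config=config)
for _ in range(10):
    print(trainer.train())
trainer.stop()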