diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py
index 7bc68be5b9ed..46743c1c91ed 100644
--- a/rllib/agents/dqn/apex.py
+++ b/rllib/agents/dqn/apex.py
@@ -40,7 +40,7 @@
 # yapf: disable
 # __sphinx_doc_begin__
-APEX_DEFAULT_CONFIG = merge_dicts(
+APEX_DEFAULT_CONFIG = DQNTrainer.merge_trainer_configs(
     DQN_CONFIG,  # see also the options in dqn.py, which are also supported
     {
         "optimizer": merge_dicts(
@@ -75,7 +75,10 @@
         # we report metrics from the workers with the lowest
         # 1/worker_amount_to_collect_metrics_from of epsilons
         "worker_amount_to_collect_metrics_from": 3,
+        "custom_resources_per_replay_buffer": {},
     },
+    _allow_unknown_configs=True,
+    _allow_unknown_subkeys=["custom_resources_per_replay_buffer"],
 )
 # __sphinx_doc_end__
 # yapf: enable
@@ -154,19 +157,36 @@ def apex_execution_plan(workers: WorkerSet,
     num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
     replay_actor_cls = ReplayActor if config[
         "prioritized_replay"] else VanillaReplayActor
-    replay_actors = create_colocated(
-        replay_actor_cls,
-        [
-            num_replay_buffer_shards,
-            config["learning_starts"],
-            config["buffer_size"],
-            config["train_batch_size"],
-            config["prioritized_replay_alpha"],
-            config["prioritized_replay_beta"],
-            config["prioritized_replay_eps"],
-            config["multiagent"]["replay_mode"],
-            config.get("replay_sequence_length", 1),
-        ], num_replay_buffer_shards)
+    custom_resources = config.get("custom_resources_per_replay_buffer")
+    if custom_resources:
+        replay_actors = [
+            replay_actor_cls.options(resources=custom_resources).remote(
+                num_replay_buffer_shards,
+                config["learning_starts"],
+                config["buffer_size"],
+                config["train_batch_size"],
+                config["prioritized_replay_alpha"],
+                config["prioritized_replay_beta"],
+                config["prioritized_replay_eps"],
+                config["multiagent"]["replay_mode"],
+                config.get("replay_sequence_length", 1),
+            )
+            for _ in range(num_replay_buffer_shards)
+        ]
+    else:
+        replay_actors = create_colocated(
+            replay_actor_cls,
+            [
+                num_replay_buffer_shards,
+                config["learning_starts"],
+                config["buffer_size"],
+                config["train_batch_size"],
+                config["prioritized_replay_alpha"],
+                config["prioritized_replay_beta"],
+                config["prioritized_replay_eps"],
+                config["multiagent"]["replay_mode"],
+                config.get("replay_sequence_length", 1),
+            ], num_replay_buffer_shards)
 
     # Start the learner thread.
     learner_thread = LearnerThread(workers.local_worker())
@@ -285,4 +305,5 @@ def apex_validate_config(config):
     validate_config=apex_validate_config,
     execution_plan=apex_execution_plan,
     mixins=[OverrideDefaultResourceRequest],
+    allow_unknown_subkeys=["custom_resources_per_replay_buffer"]
 )
diff --git a/rllib/agents/sac/apex.py b/rllib/agents/sac/apex.py
index cb2d43103040..fade18628226 100644
--- a/rllib/agents/sac/apex.py
+++ b/rllib/agents/sac/apex.py
@@ -36,7 +36,10 @@
         # This only applies if async mode is used (above config setting).
         # Controls the max number of async requests in flight per actor
         "parallel_rollouts_num_async": 2,
+        "custom_resources_per_replay_buffer": {},
     },
+    _allow_unknown_configs=True,
+    _allow_unknown_subkeys=["custom_resources_per_replay_buffer"],
 )
@@ -48,4 +51,5 @@
     name="APEX_SAC",
     default_config=APEX_SAC_DEFAULT_CONFIG,
     execution_plan=apex_execution_plan,
+    allow_unknown_subkeys=["custom_resources_per_replay_buffer"]
 )
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 1ec3a5ad7978..d6ff21261146 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -1174,7 +1174,8 @@ def resource_help(cls, config: TrainerConfigDict) -> str:
     def merge_trainer_configs(cls,
                               config1: TrainerConfigDict,
                               config2: PartialTrainerConfigDict,
-                              _allow_unknown_configs: Optional[bool] = None
+                              _allow_unknown_configs: Optional[bool] = None,
+                              _allow_unknown_subkeys: Optional[List[str]] = None,
                               ) -> TrainerConfigDict:
         config1 = copy.deepcopy(config1)
         if "callbacks" in config2 and type(config2["callbacks"]) is dict:
@@ -1188,8 +1189,10 @@ def make_callbacks():
             config2["callbacks"] = make_callbacks
         if _allow_unknown_configs is None:
             _allow_unknown_configs = cls._allow_unknown_configs
+        if _allow_unknown_subkeys is None:
+            _allow_unknown_subkeys = []
         return deep_update(config1, config2, _allow_unknown_configs,
-                           cls._allow_unknown_subkeys,
+                           cls._allow_unknown_subkeys + _allow_unknown_subkeys,
                            cls._override_all_subkeys_if_type_changes)
 
     @staticmethod
diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py
index 1192d4751e2c..b194330fa18f 100644
--- a/rllib/agents/trainer_template.py
+++ b/rllib/agents/trainer_template.py
@@ -64,7 +64,8 @@ def build_trainer(
         mixins: Optional[List[type]] = None,
         execution_plan: Optional[Callable[[
             WorkerSet, TrainerConfigDict
-        ], Iterable[ResultDict]]] = default_execution_plan) -> Type[Trainer]:
+        ], Iterable[ResultDict]]] = default_execution_plan,
+        allow_unknown_subkeys: Optional[List[str]] = None) -> Type[Trainer]:
     """Helper function for defining a custom trainer.
 
     Functions will be run in this order to initialize the trainer:
@@ -112,6 +113,8 @@ def build_trainer(
     original_kwargs = locals().copy()
 
     base = add_mixins(Trainer, mixins)
+    if allow_unknown_subkeys:
+        Trainer._allow_unknown_subkeys += allow_unknown_subkeys
 
     class trainer_cls(base):
         _name = name
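For reference, a minimal usage sketch of the new "custom_resources_per_replay_buffer" option. When the dict is non-empty, each replay buffer shard is created via .options(resources=...) and placed according to those custom resources instead of being co-located through create_colocated(). The "replay_node" resource name and the CartPole environment below are illustrative assumptions, not part of this patch.

import ray
from ray import tune

if __name__ == "__main__":
    # Sketch only: "replay_node" is a hypothetical custom resource that must be
    # declared on the target nodes when starting Ray, e.g.
    #   ray start --resources='{"replay_node": 1}'
    ray.init()
    tune.run(
        "APEX",
        stop={"training_iteration": 1},
        config={
            "env": "CartPole-v0",
            "num_workers": 4,
            # New key added by this patch: extra Ray resources requested by
            # every replay buffer shard actor; with 4 shards at 0.25 each,
            # all shards fit on a single node advertising {"replay_node": 1}.
            "custom_resources_per_replay_buffer": {"replay_node": 0.25},
        },
    )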