diff --git a/rllib/agents/cql/cql_sac.py b/rllib/agents/cql/cql_sac.py
index ed1d7021e57c..4aba097220d9 100644
--- a/rllib/agents/cql/cql_sac.py
+++ b/rllib/agents/cql/cql_sac.py
@@ -43,6 +43,8 @@
     "alpha_upper_bound": 1.0,
     # Lower bound for alpha value during the lagrangian constraint
     "alpha_lower_bound": 0.0,
+    # custom replay buffer
+    "replay_buffer": None,
 })
 # __sphinx_doc_end__
 # yapf: enable
diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index 696633d96459..acd3866a03c9 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -216,40 +216,39 @@ def execution_plan(workers: WorkerSet,
     else:
         prio_args = {}
 
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        buffer_size=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config.get("replay_sequence_length", 1),
-        replay_burn_in=config.get("burn_in", 0),
-        replay_zero_init_states=config.get("zero_init_states", True),
-        **prio_args)
-
-    input_reader = workers.local_worker().input_reader
-    is_offline_training = isinstance(input_reader, InMemoryInputReader)
-    if is_offline_training:
-        # if we have an InMemoryInputReader, then we are in Offline Training
-        # which means that we don't need the sampling pipeline setup
-        for batch in input_reader.get_all():
-            local_replay_buffer.add_batch(batch)
+    if config.get("replay_buffer"):
+        local_replay_buffer = config.get("replay_buffer")
+        input_reader = workers.local_worker().input_reader
+        assert isinstance(input_reader, InMemoryInputReader)
+        local_replay_buffer = local_replay_buffer(config, prio_args, input_reader)
+        is_offline_training = True
     else:
-        parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
-        num_async = config.get("parallel_rollouts_num_async")
-        # This could be set to None explicitly
-        if not num_async:
-            num_async = 1
-        rollouts = ParallelRollouts(workers, mode=parallel_rollouts_mode, num_async=num_async)
-
-        # We execute the following steps concurrently:
-        # (1) Generate rollouts and store them in our local replay buffer. Calling
-        # next() on store_op drives this.
-        store_op = rollouts.for_each(
-            StoreToReplayBuffer(local_buffer=local_replay_buffer))
-        if config.get("execution_plan_custom_store_ops"):
-            custom_store_ops = config["execution_plan_custom_store_ops"]
-            store_op = store_op.for_each(custom_store_ops(workers, config))
+        local_replay_buffer = LocalReplayBuffer(
+            num_shards=1,
+            learning_starts=config["learning_starts"],
+            buffer_size=config["buffer_size"],
+            replay_batch_size=config["train_batch_size"],
+            replay_mode=config["multiagent"]["replay_mode"],
+            replay_sequence_length=config.get("replay_sequence_length", 1),
+            replay_burn_in=config.get("burn_in", 0),
+            replay_zero_init_states=config.get("zero_init_states", True),
+            **prio_args)
+
+        parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
+        num_async = config.get("parallel_rollouts_num_async")
+        # This could be set to None explicitly
+        if not num_async:
+            num_async = 1
+        rollouts = ParallelRollouts(workers, mode=parallel_rollouts_mode, num_async=num_async)
+
+        # We execute the following steps concurrently:
+        # (1) Generate rollouts and store them in our local replay buffer. Calling
+        # next() on store_op drives this.
+        store_op = rollouts.for_each(
+            StoreToReplayBuffer(local_buffer=local_replay_buffer))
+        if config.get("execution_plan_custom_store_ops"):
+            custom_store_ops = config["execution_plan_custom_store_ops"]
+            store_op = store_op.for_each(custom_store_ops(workers, config))
 
     def update_prio(item):
         samples, info_dict = item