rllib/agents/cql/cql_sac.py (2 additions, 0 deletions)
@@ -43,6 +43,8 @@
     "alpha_upper_bound": 1.0,
     # Lower bound for alpha value during the lagrangian constraint
     "alpha_lower_bound": 0.0,
+    # Optional callable that creates a custom replay buffer; None uses the default.
+    "replay_buffer": None,
 })
 # __sphinx_doc_end__
 # yapf: enable
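For context, here is a minimal sketch of how the new "replay_buffer" option could be wired up, based on the call site in dqn.py below: the configured value is called as creator(config, prio_args, input_reader) and must return a buffer exposing the LocalReplayBuffer interface. The creator name, the import path, and the pre-filling logic are illustrative assumptions, not part of this PR.

from ray.rllib.execution.replay_buffer import LocalReplayBuffer

def make_prefilled_buffer(config, prio_args, input_reader):
    # Hypothetical creator matching the (config, prio_args, input_reader)
    # signature used by execution_plan() in dqn.py. It mirrors the default
    # buffer construction, then pre-fills the buffer with the offline data
    # held by the InMemoryInputReader (as the previous code did inline).
    buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config.get("replay_sequence_length", 1),
        replay_burn_in=config.get("burn_in", 0),
        replay_zero_init_states=config.get("zero_init_states", True),
        **prio_args)
    for batch in input_reader.get_all():
        buffer.add_batch(batch)
    return buffer

# Enable it via the new config key:
config["replay_buffer"] = make_prefilled_buffer

Passing a creator rather than a ready-made buffer keeps buffer construction inside the execution plan, so a fork can swap in a custom offline buffer without patching dqn.py.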
rllib/agents/dqn/dqn.py (32 additions, 33 deletions)
@@ -216,40 +216,39 @@ def execution_plan(workers: WorkerSet,
     else:
         prio_args = {}
 
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        buffer_size=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config.get("replay_sequence_length", 1),
-        replay_burn_in=config.get("burn_in", 0),
-        replay_zero_init_states=config.get("zero_init_states", True),
-        **prio_args)
-
-    input_reader = workers.local_worker().input_reader
-    is_offline_training = isinstance(input_reader, InMemoryInputReader)
-    if is_offline_training:
-        # if we have an InMemoryInputReader, then we are in Offline Training
-        # which means that we don't need the sampling pipeline setup
-        for batch in input_reader.get_all():
-            local_replay_buffer.add_batch(batch)
+    if config.get("replay_buffer"):
+        local_replay_buffer = config.get("replay_buffer")
+        input_reader = workers.local_worker().input_reader
+        assert isinstance(input_reader, InMemoryInputReader)
+        local_replay_buffer = local_replay_buffer(config, prio_args, input_reader)
+        is_offline_training = True
     else:
-        parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
-        num_async = config.get("parallel_rollouts_num_async")
-        # This could be set to None explicitly
-        if not num_async:
-            num_async = 1
-        rollouts = ParallelRollouts(workers, mode=parallel_rollouts_mode, num_async=num_async)
-
-        # We execute the following steps concurrently:
-        # (1) Generate rollouts and store them in our local replay buffer. Calling
-        # next() on store_op drives this.
-        store_op = rollouts.for_each(
-            StoreToReplayBuffer(local_buffer=local_replay_buffer))
-        if config.get("execution_plan_custom_store_ops"):
-            custom_store_ops = config["execution_plan_custom_store_ops"]
-            store_op = store_op.for_each(custom_store_ops(workers, config))
+        local_replay_buffer = LocalReplayBuffer(
+            num_shards=1,
+            learning_starts=config["learning_starts"],
+            buffer_size=config["buffer_size"],
+            replay_batch_size=config["train_batch_size"],
+            replay_mode=config["multiagent"]["replay_mode"],
+            replay_sequence_length=config.get("replay_sequence_length", 1),
+            replay_burn_in=config.get("burn_in", 0),
+            replay_zero_init_states=config.get("zero_init_states", True),
+            **prio_args)
+
+        parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
+        num_async = config.get("parallel_rollouts_num_async")
+        # This could be set to None explicitly
+        if not num_async:
+            num_async = 1
+        rollouts = ParallelRollouts(workers, mode=parallel_rollouts_mode, num_async=num_async)
+
+        # We execute the following steps concurrently:
+        # (1) Generate rollouts and store them in our local replay buffer. Calling
+        # next() on store_op drives this.
+        store_op = rollouts.for_each(
+            StoreToReplayBuffer(local_buffer=local_replay_buffer))
+        if config.get("execution_plan_custom_store_ops"):
+            custom_store_ops = config["execution_plan_custom_store_ops"]
+            store_op = store_op.for_each(custom_store_ops(workers, config))
 
     def update_prio(item):
         samples, info_dict = item
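Similarly, the execution_plan_custom_store_ops hook exercised above expects a factory that is invoked once as custom_store_ops(workers, config) and whose result is chained onto the pipeline via store_op.for_each(...). A minimal sketch under those assumptions (the factory name and the counting behavior are illustrative, not part of this PR):

def counting_store_ops(workers, config):
    # Hypothetical factory: called once as custom_store_ops(workers, config)
    # in dqn.py; must return a per-item callable usable with for_each().
    def _count(batch):
        # Example side effect: track how many env steps have been stored.
        _count.steps_stored += batch.count
        return batch  # pass the batch through unchanged

    _count.steps_stored = 0
    return _count

# Enable it in the trainer config:
config["execution_plan_custom_store_ops"] = counting_store_ops

Because for_each passes each stored batch through, the callable must return its input (or a transformed batch) to keep downstream ops, such as update_prio, working.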