2 changes: 2 additions & 0 deletions rllib/agents/dqn/dqn.py
@@ -234,6 +234,8 @@ def execution_plan(workers: WorkerSet,
         # which means that we don't need the sampling pipeline setup
         for batch in input_reader.get_all():
             local_replay_buffer.add_batch(batch)
+        config["bc_iters"] = input_reader.total_iterations_count
+        workers.local_worker().policy_map['default_policy'].update_config(config)
     else:
         parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
         num_async = config.get("parallel_rollouts_num_async")
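The dqn.py change drains the offline input once into the local replay buffer, records the number of input iterations as bc_iters, and pushes the updated config into the default policy. A minimal sketch of the same pattern, counting batches explicitly instead of relying on the reader's counter (function and attribute names here are illustrative, not the exact RLlib call chain):

# Sketch: derive the BC iteration budget from the offline data and hand it
# to the policy via its config (update_config mirrors the helper this PR uses).
def prefill_and_set_bc_iters(input_reader, local_replay_buffer, config, policy):
    num_batches = 0
    # Drain the offline input once; every batch lands in the replay buffer.
    for batch in input_reader.get_all():
        local_replay_buffer.add_batch(batch)
        num_batches += 1
    # One BC iteration per offline batch, repeated for each BC epoch.
    config["bc_iters"] = num_batches * config.get("bc_epochs", 1)
    policy.update_config(config)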
2 changes: 2 additions & 0 deletions rllib/agents/sac/sac.py
@@ -84,6 +84,8 @@
"normalize_actions": True,
# Number of iterations to perform in the Behavior Cloning Pretraining
"bc_iters": None,
# Number of epochs to perform in the Behavior Cloning Pretraining
"bc_epochs": 1,

# === Learning ===
# Disable setting done=True at end of episode. This should be set to True
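With the two new keys in place, a user-side SAC config for offline BC pretraining could look roughly like the sketch below (paths, env, and values are illustrative; leaving bc_iters at None lets the dqn.py change above fill it in from the offline input size):

from ray.rllib.agents.sac import SACTrainer

config = {
    "input": "/tmp/offline-demos",  # offline experience to pretrain from
    "bc_iters": None,               # None -> derived from the offline input size
    "bc_epochs": 2,                 # two BC passes over the offline data
}
trainer = SACTrainer(config=config, env="Pendulum-v0")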
7 changes: 5 additions & 2 deletions rllib/agents/sac/sac_tf_policy.py
@@ -228,8 +228,11 @@ def sac_actor_critic_loss(
     # Should be True only for debugging purposes (e.g. test cases)!
     deterministic = policy.config["_deterministic_loss"]
     bc_iters = policy.config["bc_iters"]
-    bc_iters_const = (tf.constant(bc_iters, dtype=policy.global_step.dtype)
-                      if bc_iters else None)
+    bc_iters_const = tf1.placeholder_with_default(
+        tf.constant(bc_iters, dtype=policy.global_step.dtype),
+        shape=None,
+        name="bc_iters_const")
+    policy.bc_iters_const = bc_iters_const
 
     # Get the base model output from the train batch.
     model_out_t, _ = model({
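Replacing the plain constant with tf1.placeholder_with_default means the BC budget baked into the graph can still be overridden at session-run time, after the offline input has actually been counted. How the constant then gates the loss is not shown in this hunk; one plausible way to consume it is a step-based blend between a BC objective and SAC's usual actor loss (the names below are assumptions, not the PR's code):

import tensorflow as tf

def blend_bc_and_sac_actor_loss(global_step, bc_iters_const, bc_loss, sac_actor_loss):
    # 1.0 while global_step is still inside the BC pretraining window, else 0.0.
    in_bc_phase = tf.cast(global_step < bc_iters_const, sac_actor_loss.dtype)
    # Optimize the BC objective first, then hand over to the SAC actor loss.
    return in_bc_phase * bc_loss + (1.0 - in_bc_phase) * sac_actor_loss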
1 change: 1 addition & 0 deletions rllib/policy/tf_policy.py
@@ -211,6 +211,7 @@ def __init__(self,
                 self.dist_class is not None:
             self._log_likelihood = self.dist_class(
                 self._dist_inputs, self.model).logp(self._action_input)
+        self.bc_iters_const: Optional[tf.Tensor] = None
 
     def variables(self):
         """Return the list of all savable variables for this policy."""
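Declaring bc_iters_const on the TFPolicy base class gives later code a stable handle on the placeholder. Because it is a placeholder_with_default, its value can be swapped in through the feed dict of any session run, for example when the config is updated after the offline prefill (the train_op plumbing below is an assumption about how the PR wires this up):

# Sketch: override the default BC budget at run time via the feed dict.
feed_dict = {}
if policy.bc_iters_const is not None:
    feed_dict[policy.bc_iters_const] = policy.config["bc_iters"]
policy.get_session().run(train_op, feed_dict=feed_dict)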