Skip to content

Commit

Permalink
Rename compile_graph to compiled
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent 03888be commit b41be5a
Show file tree
Hide file tree
Showing 30 changed files with 68 additions and 68 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/awac.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def inner_create_impl(
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def inner_create_impl(
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
rl_start_step=self._config.rl_start_step,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down Expand Up @@ -432,7 +432,7 @@ def inner_create_impl(
gamma=self._config.gamma,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/bear.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def inner_create_impl(
mmd_sigma=self._config.mmd_sigma,
vae_kl_weight=self._config.vae_kl_weight,
warmup_steps=self._config.warmup_steps,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/cal_ql.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down Expand Up @@ -348,7 +348,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
alpha=self._config.alpha,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/crr.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def inner_create_impl(
tau=self._config.tau,
target_update_type=self._config.target_update_type,
target_update_interval=self._config.target_update_interval,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
modules=modules,
gamma=self._config.gamma,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down Expand Up @@ -220,7 +220,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_forwarder,
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def inner_create_impl(
expectile=self._config.expectile,
weight_temp=self._config.weight_temp,
max_weight=self._config.max_weight,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/nfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
target_update_interval=1,
gamma=self._config.gamma,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/plas.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def inner_create_impl(
lam=self._config.lam,
beta=self._config.beta,
warmup_steps=self._config.warmup_steps,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down Expand Up @@ -386,7 +386,7 @@ def inner_create_impl(
lam=self._config.lam,
beta=self._config.beta,
warmup_steps=self._config.warmup_steps,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/rebrac.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def inner_create_impl(
actor_beta=self._config.actor_beta,
critic_beta=self._config.critic_beta,
update_actor_interval=self._config.update_actor_interval,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down Expand Up @@ -362,7 +362,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def inner_create_impl(
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
update_actor_interval=self._config.update_actor_interval,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/td3_plus_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def inner_create_impl(
target_smoothing_clip=self._config.target_smoothing_clip,
alpha=self._config.alpha,
update_actor_interval=self._config.update_actor_interval,
compile_graph=self.compiled,
compiled=self.compiled,
device=self._device,
)

Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/awac_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
tau: float,
lam: float,
n_action_samples: int,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -44,7 +44,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
self._lam = lam
Expand Down
10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
action_flexibility: float,
beta: float,
rl_start_step: int,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -82,7 +82,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
self._lam = lam
Expand All @@ -92,7 +92,7 @@ def __init__(
self._rl_start_step = rl_start_step
self._compute_imitator_grad = (
CudaGraphWrapper(self.compute_imitator_grad)
if compile_graph
if compiled
else self.compute_imitator_grad
)

Expand Down Expand Up @@ -256,7 +256,7 @@ def __init__(
gamma: float,
action_flexibility: float,
beta: float,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -267,7 +267,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
target_update_interval=target_update_interval,
gamma=gamma,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
self._action_flexibility = action_flexibility
Expand Down
8 changes: 4 additions & 4 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(
mmd_sigma: float,
vae_kl_weight: float,
warmup_steps: int,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -103,7 +103,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
self._alpha_threshold = alpha_threshold
Expand All @@ -117,12 +117,12 @@ def __init__(
self._warmup_steps = warmup_steps
self._compute_warmup_actor_grad = (
CudaGraphWrapper(self.compute_warmup_actor_grad)
if compile_graph
if compiled
else self.compute_warmup_actor_grad
)
self._compute_imitator_grad = (
CudaGraphWrapper(self.compute_imitator_grad)
if compile_graph
if compiled
else self.compute_imitator_grad
)

Expand Down
8 changes: 4 additions & 4 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
n_action_samples: int,
soft_q_backup: bool,
max_q_backup: bool,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -71,7 +71,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
self._alpha_threshold = alpha_threshold
Expand Down Expand Up @@ -247,7 +247,7 @@ def __init__(
target_update_interval: int,
gamma: float,
alpha: float,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -258,7 +258,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
target_update_interval=target_update_interval,
gamma=gamma,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
self._alpha = alpha
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/crr_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
tau: float,
target_update_type: str,
target_update_interval: int,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -66,7 +66,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
self._beta = beta
Expand Down
10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/torch/ddpg_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
gamma: float,
tau: float,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -88,12 +88,12 @@ def __init__(
self._targ_q_func_forwarder = targ_q_func_forwarder
self._compute_critic_grad = (
CudaGraphWrapper(self.compute_critic_grad)
if compile_graph
if compiled
else self.compute_critic_grad
)
self._compute_actor_grad = (
CudaGraphWrapper(self.compute_actor_grad)
if compile_graph
if compiled
else self.compute_actor_grad
)
hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs)
Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
gamma: float,
tau: float,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -211,7 +211,7 @@ def __init__(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
compile_graph=compile_graph,
compiled=compiled,
device=device,
)
hard_sync(self._modules.targ_policy, self._modules.policy)
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/dqn_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
target_update_interval: int,
gamma: float,
compile_graph: bool,
compiled: bool,
device: str,
):
super().__init__(
Expand All @@ -65,7 +65,7 @@ def __init__(
self._target_update_interval = target_update_interval
self._compute_grad = (
CudaGraphWrapper(self.compute_grad)
if compile_graph
if compiled
else self.compute_grad
)
hard_sync(modules.targ_q_funcs, modules.q_funcs)
Expand Down
Loading

0 comments on commit b41be5a

Please sign in to comment.