Commit 0f37065: Add compiled property

takuseno committed Nov 3, 2024
1 parent 7820873

Showing 17 changed files with 84 additions and 133 deletions.
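Across the diffs below, this commit removes the per-algorithm compile_graph config field and the repeated local check compiled = self._config.compile_graph and "cuda" in self._device, and has each algorithm consult a shared self.compiled property instead. The base-class side of the change is in files that are not expanded in this view; the following is only a minimal sketch of what such a property could look like, with the class layout and attribute placement assumed rather than taken from the actual d3rlpy source.

# Hypothetical sketch -- the real definition lives in base classes that are not
# expanded in this diff, so the names and placement below are assumptions.
from dataclasses import dataclass


@dataclass
class LearnableConfig:
    # Assumed to move to the shared config base, since every algorithm-specific
    # config in this commit drops its own compile_graph field.
    compile_graph: bool = False


class QLearningAlgoBase:
    _config: LearnableConfig
    _device: str

    @property
    def compiled(self) -> bool:
        # Mirrors the removed per-algorithm checks: compilation is active only
        # when requested in the config and a CUDA device is selected.
        return self._config.compile_graph and "cuda" in self._device

Centralizing the check this way avoids repeating the same device-dependent logic in every inner_create_impl, which is exactly what the deletions below take out.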
9 changes: 3 additions & 6 deletions d3rlpy/algos/qlearning/awac.py
@@ -86,7 +86,6 @@ class AWACConfig(LearnableConfig):
lam: float = 1.0
n_action_samples: int = 1
n_critics: int = 2
- compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -102,8 +101,6 @@ class AWAC(QLearningAlgoBase[AWACImpl, AWACConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
- compiled = self._config.compile_graph and "cuda" in self._device
-
policy = create_normal_policy(
observation_shape,
action_size,
@@ -136,12 +133,12 @@ def inner_create_impl(
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)

dummy_log_temp = Parameter(torch.zeros(1, 1))
@@ -166,7 +163,7 @@ def inner_create_impl(
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

18 changes: 6 additions & 12 deletions d3rlpy/algos/qlearning/bcq.py
@@ -160,7 +160,6 @@ class BCQConfig(LearnableConfig):
action_flexibility: float = 0.05
rl_start_step: int = 0
beta: float = 0.5
- compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -176,8 +175,6 @@ class BCQ(QLearningAlgoBase[BCQImpl, BCQConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
- compiled = self._config.compile_graph and "cuda" in self._device
-
policy = create_deterministic_residual_policy(
observation_shape,
action_size,
@@ -234,18 +231,18 @@ def inner_create_impl(
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)

modules = BCQModules(
@@ -273,7 +270,7 @@ def inner_create_impl(
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
rl_start_step=self._config.rl_start_step,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

@@ -355,7 +352,6 @@ class DiscreteBCQConfig(LearnableConfig):
beta: float = 0.5
target_update_interval: int = 8000
share_encoder: bool = True
- compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -371,8 +367,6 @@ class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
- compiled = self._config.compile_graph and "cuda" in self._device
-
q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
@@ -418,7 +412,7 @@ def inner_create_impl(
optim = self._config.optim_factory.create(
q_func_params + imitator_params,
lr=self._config.learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)

modules = DiscreteBCQModules(
@@ -438,7 +432,7 @@ def inner_create_impl(
gamma=self._config.gamma,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

15 changes: 6 additions & 9 deletions d3rlpy/algos/qlearning/bear.py
@@ -146,7 +146,6 @@ class BEARConfig(LearnableConfig):
mmd_sigma: float = 20.0
vae_kl_weight: float = 0.5
warmup_steps: int = 40000
- compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -162,8 +161,6 @@ class BEAR(QLearningAlgoBase[BEARImpl, BEARConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
- compiled = self._config.compile_graph and "cuda" in self._device
-
policy = create_normal_policy(
observation_shape,
action_size,
@@ -223,28 +220,28 @@ def inner_create_impl(
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(),
lr=self._config.actor_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)

modules = BEARModules(
@@ -279,7 +276,7 @@ def inner_create_impl(
mmd_sigma=self._config.mmd_sigma,
vae_kl_weight=self._config.vae_kl_weight,
warmup_steps=self._config.warmup_steps,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

12 changes: 5 additions & 7 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -89,8 +89,6 @@ def inner_create_impl(
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."
- compiled = self._config.compile_graph and "cuda" in self._device
-
policy = create_normal_policy(
observation_shape,
action_size,
@@ -132,26 +130,26 @@ def inner_create_impl(
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
if self._config.temp_learning_rate > 0:
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
else:
temp_optim = None
if self._config.alpha_learning_rate > 0:
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(),
lr=self._config.alpha_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
else:
alpha_optim = None
@@ -181,7 +179,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

20 changes: 7 additions & 13 deletions d3rlpy/algos/qlearning/cql.py
@@ -125,7 +125,6 @@ class CQLConfig(LearnableConfig):
n_action_samples: int = 10
soft_q_backup: bool = False
max_q_backup: bool = False
- compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -144,8 +143,6 @@ def inner_create_impl(
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."
- compiled = self._config.compile_graph and "cuda" in self._device
-
policy = create_normal_policy(
observation_shape,
action_size,
@@ -187,26 +184,26 @@ def inner_create_impl(
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
if self._config.temp_learning_rate > 0:
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
else:
temp_optim = None
if self._config.alpha_learning_rate > 0:
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(),
lr=self._config.alpha_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
else:
alpha_optim = None
@@ -236,7 +233,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

@@ -296,7 +293,6 @@ class DiscreteCQLConfig(LearnableConfig):
n_critics: int = 1
target_update_interval: int = 8000
alpha: float = 1.0
- compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -312,8 +308,6 @@ class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
- compiled = self._config.compile_graph and "cuda" in self._device
-
q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
@@ -336,7 +330,7 @@ def inner_create_impl(
optim = self._config.optim_factory.create(
q_funcs.named_modules(),
lr=self._config.learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)

modules = DQNModules(
@@ -354,7 +348,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
alpha=self._config.alpha,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

9 changes: 3 additions & 6 deletions d3rlpy/algos/qlearning/crr.py
@@ -121,7 +121,6 @@ class CRRConfig(LearnableConfig):
tau: float = 5e-3
target_update_interval: int = 100
update_actor_interval: int = 1
- compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -137,8 +136,6 @@ class CRR(QLearningAlgoBase[CRRImpl, CRRConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
- compiled = self._config.compile_graph and "cuda" in self._device
-
policy = create_normal_policy(
observation_shape,
action_size,
@@ -175,12 +172,12 @@ def inner_create_impl(
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
- compiled=compiled,
+ compiled=self.compiled,
)

modules = CRRModules(
@@ -207,7 +204,7 @@ def inner_create_impl(
tau=self._config.tau,
target_update_type=self._config.target_update_type,
target_update_interval=self._config.target_update_interval,
- compile_graph=compiled,
+ compile_graph=self.compiled,
device=self._device,
)

(The remaining 11 changed files are not expanded in this view.)
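As a closing illustration, this is how the flag would presumably be set after this commit, assuming compile_graph now lives on the shared config base rather than on each algorithm-specific config; that relocation happens in files not expanded above, so treat it as an assumption.

import d3rlpy

# Assumption: compile_graph is still a valid config keyword because it moved to
# the shared base config; the diffs above only show the per-algorithm removals.
cql = d3rlpy.algos.CQLConfig(compile_graph=True).create(device="cuda:0")

# The new property would then report whether compilation is actually in effect,
# i.e. the flag is set and the selected device is a CUDA device.
print(cql.compiled)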
