Functional components #444

Open · wants to merge 22 commits into base: master
55 changes: 44 additions & 11 deletions d3rlpy/algos/qlearning/awac.py
@@ -14,8 +14,14 @@
from ...optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.awac_impl import AWACImpl
from .torch.sac_impl import SACModules
from .functional import FunctionalQLearningAlgoImplBase
from .functional_utils import (
DeterministicContinuousActionSampler,
GaussianContinuousActionSampler,
)
from .torch.awac_impl import AWACActorLossFn
from .torch.ddpg_impl import DDPGValuePredictor
from .torch.sac_impl import SACCriticLossFn, SACModules, SACUpdater

__all__ = ["AWACConfig", "AWAC"]

@@ -97,7 +103,7 @@ def get_type() -> str:
return "awac"


class AWAC(QLearningAlgoBase[AWACImpl, AWACConfig]):
class AWAC(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, AWACConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -153,17 +159,44 @@ def inner_create_impl(
temp_optim=None,
)

self._impl = AWACImpl(
updater = SACUpdater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
critic_optim=critic_optim,
actor_optim=actor_optim,
critic_loss_fn=SACCriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
policy=policy,
log_temp=dummy_log_temp,
gamma=self._config.gamma,
),
actor_loss_fn=AWACActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
n_action_samples=self._config.n_action_samples,
lam=self._config.lam,
action_size=action_size,
),
tau=self._config.tau,
compiled=self.compiled,
)
exploit_action_sampler = DeterministicContinuousActionSampler(policy)
explore_action_sampler = GaussianContinuousActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
compiled=self.compiled,
updater=updater,
exploit_action_sampler=exploit_action_sampler,
explore_action_sampler=explore_action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)

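Reviewer note on the overall pattern: this diff drops the algorithm-specific `AWACImpl` and instead assembles a generic `FunctionalQLearningAlgoImplBase` from pluggable parts: an updater (gradient steps plus target synchronization), exploit and explore action samplers, and a value predictor. The sketch below is a minimal illustration of how such a composition could be wired together; the protocol names and signatures are assumptions made for illustration, not the interfaces added by this PR.

```python
# Minimal sketch of the composition pattern, assuming hypothetical protocol
# names; the real interfaces live in .functional and .torch.* in this PR.
from dataclasses import dataclass
from typing import Any, Dict, Protocol

import torch


class ActionSampler(Protocol):
    """Maps a batch of observations to a batch of actions."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...


class ValuePredictor(Protocol):
    """Predicts Q-values for observation-action pairs."""

    def __call__(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor: ...


class Updater(Protocol):
    """Runs one gradient step on a mini-batch and returns scalar metrics."""

    def __call__(self, batch: Any, grad_step: int) -> Dict[str, float]: ...


@dataclass(frozen=True)
class FunctionalImpl:
    """Hypothetical stand-in for FunctionalQLearningAlgoImplBase."""

    updater: Updater
    exploit_action_sampler: ActionSampler
    explore_action_sampler: ActionSampler
    value_predictor: ValuePredictor

    def update(self, batch: Any, grad_step: int) -> Dict[str, float]:
        return self.updater(batch, grad_step)

    def predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
        return self.exploit_action_sampler(x)

    def sample_action(self, x: torch.Tensor) -> torch.Tensor:
        return self.explore_action_sampler(x)

    def predict_value(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.value_predictor(x, action)
```

With this split, each algorithm differs only in which loss functions and samplers it plugs in, which is what the updated `inner_create_impl` bodies in this PR express.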
119 changes: 96 additions & 23 deletions d3rlpy/algos/qlearning/bcq.py
@@ -16,12 +16,20 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .functional import FunctionalQLearningAlgoImplBase
from .functional_utils import VAELossFn
from .torch.bcq_impl import (
BCQImpl,
BCQActionSampler,
BCQActorLossFn,
BCQCriticLossFn,
BCQModules,
DiscreteBCQImpl,
BCQUpdater,
DiscreteBCQActionSampler,
DiscreteBCQLossFn,
DiscreteBCQModules,
)
from .torch.ddpg_impl import DDPGValuePredictor
from .torch.dqn_impl import DQNUpdater, DQNValuePredictor

__all__ = ["BCQConfig", "BCQ", "DiscreteBCQConfig", "DiscreteBCQ"]

@@ -171,7 +179,7 @@ def get_type() -> str:
return "bcq"


class BCQ(QLearningAlgoBase[BCQImpl, BCQConfig]):
class BCQ(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, BCQConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -257,20 +265,60 @@ def inner_create_impl(
vae_optim=vae_optim,
)

self._impl = BCQImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
updater = BCQUpdater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
policy=policy,
targ_policy=targ_policy,
critic_optim=critic_optim,
actor_optim=actor_optim,
imitator_optim=vae_optim,
critic_loss_fn=BCQCriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
targ_policy=targ_policy,
vae_decoder=vae_decoder,
gamma=self._config.gamma,
n_action_samples=self._config.n_action_samples,
lam=self._config.lam,
action_size=action_size,
),
actor_loss_fn=BCQActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
vae_decoder=vae_decoder,
action_size=action_size,
),
imitator_loss_fn=VAELossFn(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
kl_weight=self._config.beta,
),
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
rl_start_step=self._config.rl_start_step,
compiled=self.compiled,
)
action_sampler = BCQActionSampler(
policy=policy,
q_func_forwarder=q_func_forwarder,
vae_decoder=vae_decoder,
n_action_samples=self._config.n_action_samples,
action_size=action_size,
)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
updater=updater,
exploit_action_sampler=action_sampler,
explore_action_sampler=action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)
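Reviewer note: BCQ wires the same `BCQActionSampler` in for both exploitation and exploration, since action selection in BCQ is already stochastic through the VAE. For reference, the selection rule from the original BCQ paper (Fujimoto et al., 2019) is sketched below: draw candidate actions from the VAE decoder, perturb them with the learned policy, and keep the candidate with the highest Q-value. The `decode`, `perturb`, and `q_value` callables and the `latent_dim` argument are hypothetical placeholders, not this PR's classes.

```python
# Sketch of the BCQ action-selection rule, assuming placeholder callables.
from typing import Callable

import torch


def bcq_select_action(
    x: torch.Tensor,            # (batch, obs_dim)
    decode: Callable,           # VAE decoder: (x, latent) -> candidate actions
    perturb: Callable,          # perturbation policy: (x, action) -> adjusted action
    q_value: Callable,          # Q-function: (x, action) -> (batch * n, 1)
    n_action_samples: int,
    latent_dim: int,            # commonly 2 * action_size in BCQ
) -> torch.Tensor:
    batch_size = x.shape[0]
    # repeat each observation so every candidate is scored against it
    x_rep = x.repeat_interleave(n_action_samples, dim=0)
    # sample candidates from the generative model, with the clipped latent used in the paper
    latent = torch.randn(x_rep.shape[0], latent_dim, device=x.device).clamp(-0.5, 0.5)
    candidates = perturb(x_rep, decode(x_rep, latent))
    # score candidates and keep the best one per observation
    values = q_value(x_rep, candidates).view(batch_size, n_action_samples)
    best = values.argmax(dim=1)
    candidates = candidates.view(batch_size, n_action_samples, -1)
    return candidates[torch.arange(batch_size), best]
```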

@@ -363,7 +411,9 @@ def get_type() -> str:
return "discrete_bcq"


class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]):
class DiscreteBCQ(
QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DiscreteBCQConfig]
):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -422,17 +472,40 @@ def inner_create_impl(
optim=optim,
)

self._impl = DiscreteBCQImpl(
# build functional components
updater = DQNUpdater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
optim=optim,
dqn_loss_fn=DiscreteBCQLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
imitator=imitator,
gamma=self._config.gamma,
beta=self._config.beta,
),
target_update_interval=self._config.target_update_interval,
compiled=self.compiled,
)
action_sampler = DiscreteBCQActionSampler(
q_func_forwarder=q_func_forwarder,
imitator=imitator,
action_flexibility=self._config.action_flexibility,
)
value_predictor = DQNValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
compiled=self.compiled,
updater=updater,
exploit_action_sampler=action_sampler,
explore_action_sampler=action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=optim.optim,
policy=None,
policy_optim=None,
device=self._device,
)

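Reviewer note: for discrete BCQ, `DiscreteBCQActionSampler` bundles the Q-function forwarder, the imitator, and `action_flexibility`. The filtering rule from the discrete BCQ formulation is sketched below for reference: actions whose imitator probability falls below `action_flexibility` times the probability of the most likely action are masked out before the greedy argmax. The tensor arguments are hypothetical stand-ins for the forwarder and imitator outputs, not this PR's code.

```python
# Sketch of the discrete BCQ filtering rule, assuming placeholder inputs.
import torch


def discrete_bcq_select_action(
    q_values: torch.Tensor,            # (batch, n_actions)
    imitator_log_probs: torch.Tensor,  # (batch, n_actions), log pi_imitator(a|s)
    action_flexibility: float,         # threshold tau from the discrete BCQ paper
) -> torch.Tensor:
    # keep only actions whose imitator probability is within a factor tau of
    # the most likely action: pi(a|s) / max_a' pi(a'|s) >= tau
    max_log_probs = imitator_log_probs.max(dim=1, keepdim=True).values
    ratio = (imitator_log_probs - max_log_probs).exp()
    allowed = ratio >= action_flexibility
    # mask out disallowed actions before the greedy argmax over Q-values
    masked_q = q_values.masked_fill(~allowed, float("-inf"))
    return masked_q.argmax(dim=1)
```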
95 changes: 78 additions & 17 deletions d3rlpy/algos/qlearning/bear.py
@@ -15,7 +15,17 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.bear_impl import BEARImpl, BEARModules
from .functional import FunctionalQLearningAlgoImplBase
from .functional_utils import SquashedGaussianContinuousActionSampler, VAELossFn
from .torch.bear_impl import (
BEARActorLossFn,
BEARCriticLossFn,
BEARModules,
BEARSquashedGaussianContinuousActionSampler,
BEARUpdater,
BEARWarmupActorLossFn,
)
from .torch.ddpg_impl import DDPGValuePredictor

__all__ = ["BEARConfig", "BEAR"]

@@ -157,7 +167,7 @@ def get_type() -> str:
return "bear"


class BEAR(QLearningAlgoBase[BEARImpl, BEARConfig]):
class BEAR(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, BEARConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -259,24 +269,75 @@ def inner_create_impl(
alpha_optim=alpha_optim,
)

self._impl = BEARImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
updater = BEARUpdater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
critic_optim=critic_optim,
actor_optim=actor_optim,
imitator_optim=vae_optim,
critic_loss_fn=BEARCriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
policy=policy,
log_temp=log_temp,
gamma=self._config.gamma,
n_target_samples=self._config.n_target_samples,
lam=self._config.lam,
),
actor_loss_fn=BEARActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
vae_decoder=vae_decoder,
log_temp=log_temp,
log_alpha=log_alpha,
temp_optim=temp_optim,
alpha_optim=alpha_optim,
n_mmd_action_samples=self._config.n_mmd_action_samples,
mmd_kernel=self._config.mmd_kernel,
mmd_sigma=self._config.mmd_sigma,
alpha_threshold=self._config.alpha_threshold,
action_size=action_size,
),
warmup_actor_loss_fn=BEARWarmupActorLossFn(
policy=policy,
vae_decoder=vae_decoder,
log_alpha=log_alpha,
action_size=action_size,
n_mmd_action_samples=self._config.n_mmd_action_samples,
mmd_kernel=self._config.mmd_kernel,
mmd_sigma=self._config.mmd_sigma,
alpha_threshold=self._config.alpha_threshold,
),
imitator_loss_fn=VAELossFn(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
kl_weight=self._config.vae_kl_weight,
),
tau=self._config.tau,
alpha_threshold=self._config.alpha_threshold,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
n_target_samples=self._config.n_target_samples,
n_mmd_action_samples=self._config.n_mmd_action_samples,
mmd_kernel=self._config.mmd_kernel,
mmd_sigma=self._config.mmd_sigma,
vae_kl_weight=self._config.vae_kl_weight,
warmup_steps=self._config.warmup_steps,
compiled=self.compiled,
)
exploit_action_sampler = BEARSquashedGaussianContinuousActionSampler(
policy=policy,
q_func_forwarder=q_func_forwarder,
n_action_samples=self._config.n_action_samples,
action_size=action_size,
)
explore_action_sampler = SquashedGaussianContinuousActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
updater=updater,
exploit_action_sampler=exploit_action_sampler,
explore_action_sampler=explore_action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)

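Reviewer note: the distinctive piece carried by `BEARActorLossFn` and `BEARWarmupActorLossFn` is the MMD constraint between policy actions and the behavior distribution reconstructed by the VAE, controlled by `mmd_kernel`, `mmd_sigma`, and `alpha_threshold`. A rough sketch of that penalty is below for reference; the function is an illustrative stand-in under the stated kernel assumptions, not the implementation in this PR.

```python
# Sketch of the MMD penalty used by BEAR-style actor losses (Kumar et al., 2019),
# written against placeholder tensors rather than this PR's modules.
import torch


def mmd_penalty(
    behavior_actions: torch.Tensor,  # (batch, n, action_dim) sampled from the VAE
    policy_actions: torch.Tensor,    # (batch, m, action_dim) sampled from the policy
    kernel: str = "laplacian",
    sigma: float = 20.0,
) -> torch.Tensor:
    def k(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # pairwise kernel matrix between two action sample sets
        diff = a.unsqueeze(2) - b.unsqueeze(1)   # (batch, n, m, action_dim)
        if kernel == "gaussian":
            dist = (diff**2).sum(dim=-1)
            return torch.exp(-dist / (2.0 * sigma))
        # laplacian kernel
        dist = diff.abs().sum(dim=-1)
        return torch.exp(-dist / sigma)

    # biased per-batch MMD^2 estimate between the two sample sets
    mmd_sq = (
        k(behavior_actions, behavior_actions).mean(dim=(1, 2))
        - 2.0 * k(behavior_actions, policy_actions).mean(dim=(1, 2))
        + k(policy_actions, policy_actions).mean(dim=(1, 2))
    )
    return mmd_sq.clamp(min=1e-6).sqrt()
```

During the first `warmup_steps` gradient steps the policy is trained against this constraint alone, which appears to be why a separate `warmup_actor_loss_fn` is threaded through `BEARUpdater`.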