Update CQL
takuseno committed Jan 3, 2025
1 parent 532d87c commit 34f68d8
Showing 8 changed files with 282 additions and 85 deletions.
50 changes: 40 additions & 10 deletions d3rlpy/algos/qlearning/awac.py
@@ -14,8 +14,11 @@
from ...optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.awac_impl import AWACImpl
from .torch.sac_impl import SACModules
from .torch.ddpg_impl import DDPGValuePredictor
from .torch.awac_impl import AWACActorLossFn
from .torch.sac_impl import SACModules, SACUpdater, SACCriticLossFn
from .functional import FunctionalQLearningAlgoImplBase
from .functional_utils import DeterministicContinuousActionSampler, GaussianContinuousActionSampler

__all__ = ["AWACConfig", "AWAC"]

@@ -153,17 +156,44 @@ def inner_create_impl(
temp_optim=None,
)

self._impl = AWACImpl(
updater = SACUpdater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
critic_optim=critic_optim,
actor_optim=actor_optim,
critic_loss_fn=SACCriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
policy=policy,
log_temp=dummy_log_temp,
gamma=self._config.gamma,
),
actor_loss_fn=AWACActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
n_action_samples=self._config.n_action_samples,
lam=self._config.lam,
action_size=action_size,
),
tau=self._config.tau,
compiled=self.compiled,
)
exploit_action_sampler = DeterministicContinuousActionSampler(policy)
explore_action_sampler = GaussianContinuousActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
compiled=self.compiled,
updater=updater,
exploit_action_sampler=exploit_action_sampler,
explore_action_sampler=explore_action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)

5 changes: 3 additions & 2 deletions d3rlpy/algos/qlearning/ddpg.py
@@ -11,8 +11,9 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGActionSampler, DDPGValuePredictor, DDPGCriticLossFn, DDPGActorLossFn, DDPGUpdater, DDPGModules
from .torch.ddpg_impl import DDPGValuePredictor, DDPGCriticLossFn, DDPGActorLossFn, DDPGUpdater, DDPGModules
from .functional import FunctionalQLearningAlgoImplBase
from .functional_utils import DeterministicContinuousActionSampler

__all__ = ["DDPGConfig", "DDPG"]

@@ -171,7 +172,7 @@ def inner_create_impl(
tau=self._config.tau,
compiled=self.compiled,
)
action_sampler = DDPGActionSampler(policy)
action_sampler = DeterministicContinuousActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
34 changes: 34 additions & 0 deletions d3rlpy/algos/qlearning/functional_utils.py
@@ -0,0 +1,34 @@
import torch

from ...models.torch import Policy, build_squashed_gaussian_distribution, build_gaussian_distribution
from ...types import TorchObservation
from .functional import ActionSampler

__all__ = ["DeterministicContinuousActionSampler", "GaussianContinuousActionSampler", "SquashedGaussianContinuousActionSampler"]


class DeterministicContinuousActionSampler(ActionSampler):
def __init__(self, policy: Policy):
self._policy = policy

def __call__(self, x: TorchObservation) -> torch.Tensor:
action = self._policy(x)
return action.squashed_mu


class GaussianContinuousActionSampler(ActionSampler):
def __init__(self, policy: Policy):
self._policy = policy

def __call__(self, x: TorchObservation) -> torch.Tensor:
dist = build_gaussian_distribution(self._policy(x))
return dist.sample()


class SquashedGaussianContinuousActionSampler(ActionSampler):
def __init__(self, policy: Policy):
self._policy = policy

def __call__(self, x: TorchObservation) -> torch.Tensor:
dist = build_squashed_gaussian_distribution(self._policy(x))
return dist.sample()
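
The three samplers above differ only in how a policy output is turned into an action: the deterministic variant returns the squashed mean, while the Gaussian variants draw a sample from the plain or tanh-squashed Gaussian head. The sketch below is a plain-PyTorch illustration of the same strategy pattern, not d3rlpy code; ToyPolicy, DeterministicSampler, and GaussianSampler are hypothetical names used only for this example.

import torch
import torch.nn as nn


class ToyPolicy(nn.Module):
    # Minimal Gaussian policy head: maps observations to a mean and a log-std.
    def __init__(self, obs_dim: int, action_dim: int):
        super().__init__()
        self.mu = nn.Linear(obs_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        mu = self.mu(x)
        return mu, self.log_std.expand_as(mu)


class DeterministicSampler:
    # Exploitation path: return the tanh-squashed mean
    # (analogous to DeterministicContinuousActionSampler).
    def __init__(self, policy: ToyPolicy):
        self._policy = policy

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        mu, _ = self._policy(x)
        return torch.tanh(mu)


class GaussianSampler:
    # Exploration path: draw a sample from N(mu, std)
    # (analogous to GaussianContinuousActionSampler).
    def __init__(self, policy: ToyPolicy):
        self._policy = policy

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        mu, log_std = self._policy(x)
        return torch.distributions.Normal(mu, log_std.exp()).sample()


policy = ToyPolicy(obs_dim=4, action_dim=2)
obs = torch.randn(8, 4)
exploit_action = DeterministicSampler(policy)(obs)  # (8, 2), deterministic
explore_action = GaussianSampler(policy)(obs)       # (8, 2), stochastic

Swapping the sampler object is all that distinguishes the exploit and explore paths passed to FunctionalQLearningAlgoImplBase in the AWAC and SAC diffs above.
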
8 changes: 4 additions & 4 deletions d3rlpy/algos/qlearning/sac.py
@@ -15,17 +15,17 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGActionSampler, DDPGValuePredictor
from .torch.ddpg_impl import DDPGValuePredictor
from .torch.sac_impl import (
DiscreteSACImpl,
DiscreteSACModules,
SACModules,
SACActionSampler,
SACCriticLossFn,
SACActorLossFn,
SACUpdater,
)
from .functional import FunctionalQLearningAlgoImplBase
from .functional_utils import SquashedGaussianContinuousActionSampler, DeterministicContinuousActionSampler

__all__ = ["SACConfig", "SAC", "DiscreteSACConfig", "DiscreteSAC"]

@@ -214,8 +214,8 @@ def inner_create_impl(
tau=self._config.tau,
compiled=self.compiled,
)
exploit_action_sampler = DDPGActionSampler(policy)
explore_action_sampler = SACActionSampler(policy)
exploit_action_sampler = DeterministicContinuousActionSampler(policy)
explore_action_sampler = SquashedGaussianContinuousActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
62 changes: 17 additions & 45 deletions d3rlpy/algos/qlearning/torch/awac_impl.py
@@ -2,7 +2,7 @@
import torch.nn.functional as F

from ....models.torch import (
ActionOutput,
Policy,
ContinuousEnsembleQFunctionForwarder,
build_gaussian_distribution,
)
@@ -12,60 +12,36 @@
flatten_left_recursively,
get_batch_size,
)
from ....types import Shape, TorchObservation
from .sac_impl import SACActorLoss, SACImpl, SACModules
from ....types import TorchObservation
from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseActorLossFn

__all__ = ["AWACImpl"]
__all__ = ["AWACActorLossFn"]


class AWACImpl(SACImpl):
_lam: float
_n_action_samples: int

class AWACActorLossFn(DDPGBaseActorLossFn):
def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: SACModules,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
gamma: float,
tau: float,
lam: float,
policy: Policy,
n_action_samples: int,
compiled: bool,
device: str,
lam: float,
action_size: int,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
compiled=compiled,
device=device,
)
self._lam = lam
self._q_func_forwarder = q_func_forwarder
self._policy = policy
self._n_action_samples = n_action_samples
self._lam = lam
self._action_size = action_size

def compute_actor_loss(
self, batch: TorchMiniBatch, action: ActionOutput
) -> SACActorLoss:
def __call__(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss:
# compute log probability
action = self._policy(batch.observations)
dist = build_gaussian_distribution(action)
log_probs = dist.log_prob(batch.actions)
# compute exponential weight
weights = self._compute_weights(batch.observations, batch.actions)
loss = -(log_probs * weights).sum()
return SACActorLoss(
actor_loss=loss,
temp_loss=torch.tensor(
0.0, dtype=torch.float32, device=loss.device
),
temp=torch.tensor(0.0, dtype=torch.float32, device=loss.device),
)
return DDPGBaseActorLoss(actor_loss=loss)

def _compute_weights(
self, obs_t: TorchObservation, act_t: torch.Tensor
@@ -80,9 +80,9 @@ def _compute_weights(

# sample actions
# (batch_size * N, action_size)
dist = build_gaussian_distribution(self._modules.policy(obs_t))
dist = build_gaussian_distribution(self._policy(obs_t))
policy_actions = dist.sample_n(self._n_action_samples)
flat_actions = policy_actions.reshape(-1, self.action_size)
flat_actions = policy_actions.reshape(-1, self._action_size)

# repeat observation
# (batch_size, obs_size) -> (batch_size, N, obs_size)
@@ -104,7 +80,3 @@ def _compute_weights(
weights = F.softmax(adv_values / self._lam, dim=0).view(-1, 1)

return weights * adv_values.numel()

def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
dist = build_gaussian_distribution(self._modules.policy(x))
return dist.sample()
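
For reference, the weights returned by _compute_weights implement AWAC's advantage-weighted regression: N policy samples estimate V(s), the dataset action's advantage is passed through a softmax with temperature lam, and the result is rescaled by the batch size so the weights average to one. The sketch below reproduces that arithmetic on bare tensors; awac_weights, q_data, and q_samples are illustrative names, not d3rlpy API.

import torch
import torch.nn.functional as F


def awac_weights(q_data: torch.Tensor, q_samples: torch.Tensor, lam: float) -> torch.Tensor:
    # q_data: (batch,) Q-values of the dataset actions.
    # q_samples: (batch, N) Q-values of N actions sampled from the current policy.
    v = q_samples.mean(dim=1)                    # V(s) ~= E_{a~pi}[Q(s, a)]
    adv = q_data - v                             # advantage of the dataset action
    w = F.softmax(adv / lam, dim=0).view(-1, 1)  # exponential weights, normalized over the batch
    return w * adv.numel()                       # rescale so the weights average to one


q_data = torch.tensor([1.0, 0.5, 2.0])
q_samples = torch.tensor([[0.8, 1.2], [0.6, 0.4], [1.0, 1.0]])
print(awac_weights(q_data, q_samples, lam=1.0))  # larger advantage -> larger weight
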