Skip to content

Commit 1d28d0b

Browse files
committed
Add SimBaEncoderFactory
1 parent 3e57c75 commit 1d28d0b

File tree

4 files changed

+224
-0
lines changed

4 files changed

+224
-0
lines changed

Diff for: d3rlpy/models/encoders.py

+53
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
VectorEncoder,
1313
VectorEncoderWithAction,
1414
)
15+
from .torch.encoders import SimBaEncoder, SimBaEncoderWithAction
1516
from .utility import create_activation
1617

1718
__all__ = [
1819
"EncoderFactory",
1920
"PixelEncoderFactory",
2021
"VectorEncoderFactory",
2122
"DefaultEncoderFactory",
23+
"SimBaEncoderFactory",
2224
"register_encoder_factory",
2325
"make_encoder_field",
2426
]
@@ -263,6 +265,56 @@ def get_type() -> str:
263265
return "default"
264266

265267

268+
@dataclass()
class SimBaEncoderFactory(EncoderFactory):
    """SimBa encoder factory class.

    This class implements SimBa encoder architecture.

    References:
        * `Lee et al., SimBa: Simplicity Bias for Scaling Up Parameters in Deep
          Reinforcement Learning, <https://arxiv.org/abs/2410.09754>`_

    Args:
        feature_size (int): Feature unit size.
        hidden_size (int): Hidden expansion layer unit size.
        n_blocks (int): Number of SimBa blocks.
    """

    feature_size: int = 256
    hidden_size: int = 1024
    n_blocks: int = 1

    def create(self, observation_shape: Shape) -> SimBaEncoder:
        # SimBa operates on flat (vector) observations only.
        assert len(observation_shape) == 1
        return SimBaEncoder(
            observation_shape=cast_flat_shape(observation_shape),
            hidden_size=self.hidden_size,
            output_size=self.feature_size,
            n_blocks=self.n_blocks,
        )

    def create_with_action(
        self,
        observation_shape: Shape,
        action_size: int,
        discrete_action: bool = False,
    ) -> SimBaEncoderWithAction:
        # SimBa operates on flat (vector) observations only.
        assert len(observation_shape) == 1
        return SimBaEncoderWithAction(
            observation_shape=cast_flat_shape(observation_shape),
            action_size=action_size,
            hidden_size=self.hidden_size,
            output_size=self.feature_size,
            n_blocks=self.n_blocks,
            discrete_action=discrete_action,
        )

    @staticmethod
    def get_type() -> str:
        return "simba"
317+
266318
register_encoder_factory, make_encoder_field = generate_config_registration(
267319
EncoderFactory, lambda: DefaultEncoderFactory()
268320
)
@@ -271,3 +323,4 @@ def get_type() -> str:
271323
register_encoder_factory(VectorEncoderFactory)
272324
register_encoder_factory(PixelEncoderFactory)
273325
register_encoder_factory(DefaultEncoderFactory)
326+
register_encoder_factory(SimBaEncoderFactory)

Diff for: d3rlpy/models/torch/encoders.py

+68
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"PixelEncoderWithAction",
1616
"VectorEncoder",
1717
"VectorEncoderWithAction",
18+
"SimBaEncoder",
19+
"SimBaEncoderWithAction",
1820
"compute_output_size",
1921
]
2022

@@ -290,6 +292,72 @@ def forward(
290292
return self._layers(x)
291293

292294

295+
class SimBaBlock(nn.Module):  # type: ignore
    """Residual MLP block used by the SimBa architecture.

    Applies ``LayerNorm -> Linear -> ReLU -> Linear`` and adds the result
    back onto the input (pre-norm residual connection). Note the residual
    add requires ``out_size == input_size`` to broadcast correctly.
    """

    def __init__(self, input_size: int, hidden_size: int, out_size: int):
        super().__init__()
        self._layers = nn.Sequential(
            nn.LayerNorm(input_size),
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # residual connection: out = x + MLP(LayerNorm(x))
        return x + self._layers(x)
308+
309+
310+
class SimBaEncoder(Encoder):
    """SimBa encoder for flat observations.

    Projects the observation to ``output_size`` units, applies ``n_blocks``
    residual SimBa blocks, and normalizes the result with a final LayerNorm.
    """

    def __init__(
        self,
        observation_shape: Sequence[int],
        hidden_size: int,
        output_size: int,
        n_blocks: int,
    ):
        super().__init__()
        blocks = [
            SimBaBlock(output_size, hidden_size, output_size)
            for _ in range(n_blocks)
        ]
        self._layers = nn.Sequential(
            nn.Linear(observation_shape[0], output_size),
            *blocks,
            nn.LayerNorm(output_size),
        )

    def forward(self, x: TorchObservation) -> torch.Tensor:
        # only flat tensor observations are supported by this encoder
        assert isinstance(x, torch.Tensor)
        return self._layers(x)
329+
330+
331+
class SimBaEncoderWithAction(EncoderWithAction):
    """SimBa encoder that consumes an observation-action pair.

    The action (one-hot encoded first when ``discrete_action`` is ``True``)
    is concatenated to the observation before the SimBa network is applied.
    """

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        hidden_size: int,
        output_size: int,
        n_blocks: int,
        discrete_action: bool,
    ):
        super().__init__()
        blocks = [
            SimBaBlock(output_size, hidden_size, output_size)
            for _ in range(n_blocks)
        ]
        self._layers = nn.Sequential(
            nn.Linear(observation_shape[0] + action_size, output_size),
            *blocks,
            nn.LayerNorm(output_size),
        )
        self._action_size = action_size
        self._discrete_action = discrete_action

    def forward(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor:
        # only flat tensor observations are supported by this encoder
        assert isinstance(x, torch.Tensor)
        if self._discrete_action:
            # turn integer action indices into one-hot float vectors
            action = F.one_hot(
                action.view(-1).long(), num_classes=self._action_size
            ).float()
        h = torch.cat([x, action], dim=1)
        return self._layers(h)
359+
360+
293361
def compute_output_size(
294362
input_shapes: Sequence[Shape], encoder: nn.Module
295363
) -> int:

Diff for: tests/models/test_encoders.py

+33
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
from d3rlpy.models.encoders import (
77
DefaultEncoderFactory,
88
PixelEncoderFactory,
9+
SimBaEncoderFactory,
910
VectorEncoderFactory,
1011
)
1112
from d3rlpy.models.torch.encoders import (
1213
PixelEncoder,
1314
PixelEncoderWithAction,
15+
SimBaEncoder,
16+
SimBaEncoderWithAction,
1417
VectorEncoder,
1518
VectorEncoderWithAction,
1619
)
@@ -104,3 +107,33 @@ def test_default_encoder_factory(
104107

105108
# check serization and deserialization
106109
DefaultEncoderFactory.deserialize(factory.serialize())
110+
111+
112+
@pytest.mark.parametrize("observation_shape", [(100,)])
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("discrete_action", [False, True])
def test_simba_encoder_factory(
    observation_shape: Sequence[int],
    action_size: int,
    discrete_action: bool,
) -> None:
    factory = SimBaEncoderFactory()

    # the state-only factory method must build a SimBaEncoder
    state_encoder = factory.create(observation_shape)
    assert isinstance(state_encoder, SimBaEncoder)

    # the state-action factory method must build a SimBaEncoderWithAction
    action_encoder = factory.create_with_action(
        observation_shape, action_size, discrete_action
    )
    assert isinstance(action_encoder, SimBaEncoderWithAction)
    assert action_encoder._discrete_action == discrete_action

    assert factory.get_type() == "simba"

    # check serialization and deserialization round-trip
    restored = SimBaEncoderFactory.deserialize(factory.serialize())
    assert restored.hidden_size == factory.hidden_size
    assert restored.feature_size == factory.feature_size
    assert restored.n_blocks == factory.n_blocks

Diff for: tests/models/torch/test_encoders.py

+70
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from d3rlpy.models.torch.encoders import (
88
PixelEncoder,
99
PixelEncoderWithAction,
10+
SimBaEncoder,
11+
SimBaEncoderWithAction,
1012
VectorEncoder,
1113
VectorEncoderWithAction,
1214
)
@@ -212,3 +214,71 @@ def test_vector_encoder_with_action(
212214

213215
# check layer connection
214216
check_parameter_updates(encoder, (x, action))
217+
218+
219+
@pytest.mark.parametrize("observation_shape", [(100,)])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("output_size", [256])
@pytest.mark.parametrize("n_blocks", [2])
@pytest.mark.parametrize("batch_size", [32])
def test_simba_encoder(
    observation_shape: Sequence[int],
    hidden_size: int,
    output_size: int,
    n_blocks: int,
    batch_size: int,
) -> None:
    encoder = SimBaEncoder(
        observation_shape=observation_shape,
        hidden_size=hidden_size,
        output_size=output_size,
        n_blocks=n_blocks,
    )

    observations = torch.rand((batch_size, *observation_shape))
    features = encoder(observations)

    # each observation must map to a feature vector of output_size units
    assert features.shape == (batch_size, output_size)

    # every parameter should receive gradient updates
    check_parameter_updates(encoder, (observations,))
246+
247+
248+
@pytest.mark.parametrize("observation_shape", [(100,)])
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("output_size", [256])
@pytest.mark.parametrize("n_blocks", [2])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("discrete_action", [False, True])
def test_simba_encoder_with_action(
    observation_shape: Sequence[int],
    action_size: int,
    hidden_size: int,
    output_size: int,
    n_blocks: int,
    batch_size: int,
    discrete_action: bool,
) -> None:
    encoder = SimBaEncoderWithAction(
        observation_shape=observation_shape,
        action_size=action_size,
        hidden_size=hidden_size,
        output_size=output_size,
        n_blocks=n_blocks,
        discrete_action=discrete_action,
    )

    observations = torch.rand((batch_size, *observation_shape))
    # discrete actions are integer indices; continuous actions are vectors
    if discrete_action:
        actions = torch.randint(0, action_size, size=(batch_size, 1))
    else:
        actions = torch.rand(batch_size, action_size)
    features = encoder(observations, actions)

    # each observation-action pair must map to output_size features
    assert features.shape == (batch_size, output_size)

    # every parameter should receive gradient updates
    check_parameter_updates(encoder, (observations, actions))

0 commit comments

Comments
 (0)