Skip to content

Commit

Permalink
Merge pull request #16 from zombie-einstein/update_dependencies
Browse files Browse the repository at this point in the history
Update dependencies
  • Loading branch information
zombie-einstein committed Apr 26, 2024
2 parents 4422d22 + 66b77b5 commit 2bdf54d
Show file tree
Hide file tree
Showing 3 changed files with 1,481 additions and 1,098 deletions.
6 changes: 4 additions & 2 deletions jax_ppo/lstm/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class _LSTMLayer(linen.Module):
)
@linen.compact
def __call__(self, carry, x):
return linen.OptimizedLSTMCell()(carry, x)
return linen.OptimizedLSTMCell(x.shape[0])(carry, x)


class RecurrentActorCritic(linen.Module):
Expand Down Expand Up @@ -68,7 +68,9 @@ def initialise_carry(

return tuple(
[
linen.OptimizedLSTMCell.initialize_carry(k, batch_dims, hidden_size)
linen.OptimizedLSTMCell(features=hidden_size).initialize_carry(
k, batch_dims + (hidden_size,)
)
for _ in range(n_layers)
]
)
Loading

0 comments on commit 2bdf54d

Please sign in to comment.