
Implement RBDA with jax.lax.scan #12

Closed
6 tasks done
diegoferigo opened this issue Sep 20, 2022 · 2 comments · Fixed by #38

@diegoferigo
Member

diegoferigo commented Sep 20, 2022

All our Rigid Body Dynamics Algorithms (RBDAs) have been implemented either with plain Python for loops, which are unrolled during JIT compilation and therefore lead to long compilation times, or with jax.experimental.loops, which provided convenient syntactic sugar over the low-level jax.lax.{scan|fori_loop|while_loop} primitives. Unfortunately, jax.experimental.loops was removed in jax-ml/jax#11607 and is no longer part of JAX starting from v0.3.16.
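To make the compilation-time issue concrete, here is a minimal sketch (not taken from the project's code base; `product_unrolled` and `product_scan` are hypothetical helpers) that traces the same running product in two ways and inspects the resulting jaxprs with `jax.make_jaxpr`. The Python loop emits one multiplication per element, so the program handed to XLA grows with the loop length, while `jax.lax.scan` keeps a single `scan` primitive whose body is traced only once.

```python
import jax
import jax.numpy as jnp


def product_unrolled(xs):
    # Plain Python loop: JAX traces every iteration, so the traced program
    # contains separate indexing and multiplication ops for each element.
    acc = jnp.ones((), dtype=xs.dtype)
    for i in range(xs.shape[0]):
        acc = acc * xs[i]
    return acc


def product_scan(xs):
    # jax.lax.scan: the body is traced once and the whole loop is kept as a
    # single `scan` primitive, regardless of the number of iterations.
    def body(acc, x):
        return acc * x, None

    acc, _ = jax.lax.scan(body, jnp.ones((), dtype=xs.dtype), xs)
    return acc


xs = jnp.linspace(0.5, 1.5, 64)
jaxpr_unrolled = jax.make_jaxpr(product_unrolled)(xs).jaxpr
jaxpr_scan = jax.make_jaxpr(product_scan)(xs).jaxpr
print("unrolled top-level ops:", len(jaxpr_unrolled.eqns))
print("scan top-level ops:", len(jaxpr_scan.eqns))
```

Both functions compute the same value; only the size of the traced program differs, and it is that size that drives compilation time as the number of links in a model grows.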

This issue tracks updating all the code that still relies on the removed jax.experimental.loops. At this point, it seems wise to use jax.lax.scan throughout the algorithms. Readability will definitely suffer, but this way we can also ensure that the code supports both forward- and reverse-mode differentiation (#4): jax.lax.scan supports both, whereas jax.lax.while_loop (and jax.lax.fori_loop with a trip count not known at trace time) supports only forward mode.
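As an illustration of the base-to-tip pattern and of differentiating through it, the following is a minimal sketch under simplifying assumptions: a toy planar serial chain with revolute joints stands in for a real multibody model, and `planar_chain_tip` is a hypothetical function, not part of the project. The chain is traversed with `jax.lax.scan`, and the tip position is then differentiated with both reverse mode (`jax.grad`, `jax.jacrev`) and forward mode (`jax.jacfwd`).

```python
import jax
import jax.numpy as jnp


def planar_chain_tip(q, lengths):
    """Tip position of a planar serial chain with revolute joints (toy model).

    The links are visited from base to tip with jax.lax.scan, carrying the
    accumulated joint angle and the current position along the chain.
    """

    def step(carry, link):
        angle, pos = carry
        q_i, l_i = link
        angle = angle + q_i  # absolute orientation of link i
        pos = pos + l_i * jnp.stack([jnp.cos(angle), jnp.sin(angle)])
        return (angle, pos), None

    init = (jnp.zeros((), dtype=q.dtype), jnp.zeros(2, dtype=q.dtype))
    (_, tip), _ = jax.lax.scan(step, init, (q, lengths))
    return tip


q = jnp.array([0.3, -0.2, 0.5])
lengths = jnp.array([1.0, 0.8, 0.5])

# Reverse mode (jax.grad / jax.jacrev) and forward mode (jax.jacfwd)
# both differentiate through the scan.
grad_x = jax.grad(lambda q: planar_chain_tip(q, lengths)[0])(q)
jac_rev = jax.jacrev(planar_chain_tip)(q, lengths)
jac_fwd = jax.jacfwd(planar_chain_tip)(q, lengths)
print(jac_fwd)
```

A real RBDA forward pass would carry spatial transforms over a kinematic tree (e.g. indexed by parent arrays) instead of a 2D angle and position, but the loop structure and its differentiability would be the same.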

@diegoferigo diegoferigo self-assigned this Sep 20, 2022
@diegoferigo
Member Author

Reminder: once all jax.experimental.loops invocations have been removed, we can drop the following version pin:

jax < 0.3.16

@traversaro
Contributor

cc @fl-ferr
