First steps towards automatic differentiation of RBDAs #54
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR works around some problems I encountered when trying to make automatic differentiation (AD) run on our rigid-body dynamics algorithms (RBDAs).
This is not the first time I've tried this in the past, but every time there were
NaN
s generated somewhere in the code that were quite difficult to spot. I never had enough time to dedicate to dissect small sections and proceed with a pragmatic debugging campaign.Luckily, JAX development progressed considerably, and now thanks to the following features:
JAX_DEBUG_NANS=1
that points directly to the origin of any NaN ,JAX_DISABLE_JIT=1
that disabled JIT entirely also inside e.g.jax.lax.scan
, allowing to spot NaNs occurring there,I finally managed to make the tests against finite difference (with
jax.test_util.check_grad
) work, both for the first-order and second-order derivatives! 🎉 I guess we are one step closer to a fully differentiable simulator, and definitely towards #4.For the moment, this PR checks that JAX is able to differentiate through the following algorithms:
I still have to check properly our integrators in a future PR. I suspect that the integration of quaternion with Baumgarte stabilization might be suboptimal in these cases, but it can be simply solved by implementing the integration on SO(3). I'm considering giving it a try.