From 1903889f3ecd84f885e25df7c3412095348dea3b Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 20 Feb 2024 10:57:32 +0100 Subject: [PATCH] Add model.free_floating_bias_forces function --- src/jaxsim/api/model.py | 60 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 63bd75d70..83e7ded08 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -859,3 +859,63 @@ def free_floating_gravity_forces( external_forces=jnp.zeros(shape=(model.number_of_links(), 6)), ) ).astype(float) + + +@jax.jit +def free_floating_bias_forces( + model: JaxSimModel, data: js.data.JaxSimModelData +) -> jtp.Vector: + """ + Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})` + of the model. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + The free-floating bias forces of the model. + """ + + # Build a zeroed state + data_rnea = js.data.JaxSimModelData.zero(model=model) + + # Set the generalized position and generalized velocity + with data_rnea.mutable_context( + mutability=Mutability.MUTABLE, restore_after_exception=False + ): + + data_rnea.state.physics_model.base_position = ( + data.state.physics_model.base_position + ) + + data_rnea.state.physics_model.base_quaternion = ( + data.state.physics_model.base_quaternion + ) + + data_rnea.state.physics_model.joint_positions = ( + data.state.physics_model.joint_positions + ) + + data_rnea.state.physics_model.base_linear_velocity = ( + data.state.physics_model.base_linear_velocity + ) + + data_rnea.state.physics_model.base_angular_velocity = ( + data.state.physics_model.base_angular_velocity + ) + + data_rnea.state.physics_model.joint_velocities = ( + data.state.physics_model.joint_velocities + ) + + return jnp.hstack( + inverse_dynamics( + model=model, + data=data_rnea, + # Set zero inputs: + joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())), + base_acceleration=jnp.zeros(6), + external_forces=jnp.zeros(shape=(model.number_of_links(), 6)), + ) + ).astype(float)