Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce new functional APIs #88

Merged
merged 57 commits into from
Mar 1, 2024
Merged

Conversation

diegoferigo
Copy link
Member

@diegoferigo diegoferigo commented Feb 20, 2024

These new functional APIs (jaxsim.api) will soon replace those based on OOP (jaxsim.high_level).

  • Switch to kiss. The former OOP pattern, although more user friendly, fights back the functional JAX approach and required complex decorators to maintain and debug.
  • Easier for users to extend / modify the logic. Instead of changing something deep down in the framework, they just have to update / rewrite a single function.
  • Applying jax transforms like jax.vmap and jax.grad is more simple being more vanilla.

Furthermore, this PR:

  • Introduces new functions to operate on links and joints through their indices. Although I never liked this approach (and using names is still an option), having also functions operating on indices enables the use of jax.vmap to iterate on links and joints.
  • Introduces a new logic to get a good initial set of soft-contacts parameters based on the model inertial properties. In the past, tuning these parameters by hand has been quite difficult. The new logic provides a good starting point from which users can start iterating.
  • The existing fixed-step integrators have been rewritten from scratch. The new structure is now compatible with the introduction of variable-step integrators (in a future PR). So far, only the explicit schemes have been ported. This time, all of them with variants that integrate the quaternion on $\text{SO}(3)$.

Some implementation detail that might ease the review process and the transition to the new APIs:

  • The key data structures are now JaxSimModel and JaxSimModelData.
  • Most of API functions take them as first two arguments. In order to have more stable APIs, and allow us to change the signature of functions in future versions, all others arguments are enforced to be kwargs.
  • The velocity representation is now part of data (still as static attribute).
  • The soft-contacts parameters is now part of data. Therefore, in a vectorized scenario, there could be different parameters for each data instance, enabling domain randomization on these parameters.
  • The terrain instead is global and stored inside the model (therefore, it is not possible to parallelize a model over different terrains, that makes sense at least for now).

Next steps:

  • Port semi-implicit integration schemes to the new APIs.
  • Introduce variable-step integrators.
  • Merge jaxsim.api.model.JaxSimModel with jaxsim.physics.model.PhysicsModel.
  • Introduce a new pytree KynDynParameters to store inside JaxSimModel the model parameters, so that we can differentiate against them as well (we need also decent setters/getters).
  • Introduce error handling for link/joint indices out of bound, and data objects not compatible with model.

📚 Documentation preview 📚: https://jaxsim--88.org.readthedocs.build//88/

@diegoferigo
Copy link
Member Author

diegoferigo commented Feb 23, 2024

The MWE of the new functional APIs is the following script. Resulting video attached.

MWE
import pathlib

import jax.numpy as jnp
import jaxsim.api as js
import numpy as np
import resolve_robotics_uri_py
import rod
from jaxsim import VelRepr, integrators

# Find the urdf file.
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
    uri="model://ergoCubSN001/model.urdf"
)

# Build the ROD model.
rod_sdf = rod.Sdf.load(sdf=urdf_path)

# Build the model.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_sdf.model,
    gravity=jnp.array([0, 0, -10.0]),
)

# Reduce the model.
model = js.model.reduce(
    model=model,
    considered_joints=tuple(
        [
            j
            for j in model.joint_names()
            if "camera" not in j
            and "neck" not in j
            and "wrist" not in j
            and "thumb" not in j
            and "index" not in j
            and "middle" not in j
            and "ring" not in j
            and "pinkie" not in j
        ]
    ),
)

# Build the model's data.
# Set already here the initial base position.
data0 = js.data.JaxSimModelData.build(
    model=model,
    base_position=jnp.array([0, 0, 0.85]),
    velocity_representation=VelRepr.Inertial,
)

# Update the soft-contact parameters.
# By default, only 1 support point is used as worst-case scenario.
# Feel free to tune this with more points to get a less stiff system.
data0 = data0.replace(
    soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(
        model, number_of_active_collidable_points_steady_state=2
    )
)

# =====================
# Create the integrator
# =====================

# Create a RK4 integrator integrating the quaternion on SO(3).
integrator = integrators.fixed_step.RungeKutta4SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model,
        data=data0,
        system_dynamics=js.ode.system_dynamics,
    ),
)

# =========================================
# Visualization in Mujoco viewer / renderer
# =========================================

from jaxsim.mujoco import MujocoModelHelper, MujocoVideoRecorder, RodModelToMjcf

# Convert the ROD model to a Mujoco model.
mjcf_string, assets = RodModelToMjcf.convert(
    rod_model=rod_sdf.models()[0],
    considered_joints=list(model.joint_names()),
)

# Build the Mujoco model helper.
mj_model_helper = self = MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string, assets=assets
)

# Create the video recorder.
recorder = MujocoVideoRecorder(
    model=mj_model_helper.model,
    data=mj_model_helper.data,
    fps=int(1 / 0.010),
    width=320 * 4,
    height=240 * 4,
)

# ==============
# Recording loop
# ==============

# Initialize the integrator.
t0 = 0.0
tf = 5.0
dt = 0.001_000
integrator_state = integrator.init(x0=data0.state, t0=t0, dt=dt)

# Initialize the loop.
data = data0.copy()
joint_names = list(model.joint_names())

while data.time_ns < tf * 1e9:

    # Integrate the dynamics.
    data, integrator_state = js.model.step(
        dt=dt,
        model=model,
        data=data,
        integrator=integrator,
        integrator_state=integrator_state,
        # Optional inputs
        joint_forces=None,
        external_forces=None,
    )

    # Extract the generalized position.
    s = data.state.physics_model.joint_positions
    W_p_B = data.state.physics_model.base_position
    W_Q_B = data.state.physics_model.base_quaternion

    # Update the data object stored in the helper, which is shared with the recorder.
    mj_model_helper.set_base_position(position=np.array(W_p_B))
    mj_model_helper.set_base_orientation(orientation=np.array(W_Q_B), dcm=False)
    mj_model_helper.set_joint_positions(positions=np.array(s), joint_names=joint_names)

    # Record the frame if the time is right to get the desired fps.
    if data.time_ns % jnp.array(1e9 / recorder.fps).astype(jnp.uint64) == 0:
        recorder.render_frame(camera_name=None)

# Store the video.
video_path = pathlib.Path("~/video.mp4").expanduser()
recorder.write_video(path=video_path, exist_ok=True)

# Clean up the recorder.
recorder.frames = []
recorder.renderer.close()
Video

Resulting soft-contacts parameters:

SoftContactsParams(
    K=Array(465529.79648, dtype=float64),
    D=Array(10270.41887, dtype=float64),
    mu=Array(0.5, dtype=float64)
)
video.mp4

Some comment:

  • With the automatic computation of the soft-contact parameters, using the fixed-step RK4 integrator with quaternion on $\text{SO}(3)$ with $\Delta t = 0.001 s$ is enough to have a stable simulation.
  • Stepping the ergoCub robot, reduced to have 23 DoFs, takes:
    %timeit js.model.step(...)
    968 µs ± 137 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    
    obtaining therefore a real-time-factor of 1.

@diegoferigo diegoferigo force-pushed the functional_api branch 3 times, most recently from bd8b164 to 5965f1c Compare February 23, 2024 14:27
@diegoferigo diegoferigo marked this pull request as ready for review February 23, 2024 14:27
src/jaxsim/api/model.py Outdated Show resolved Hide resolved
src/jaxsim/api/data.py Show resolved Hide resolved
src/jaxsim/api/link.py Show resolved Hide resolved
src/jaxsim/integrators/common.py Show resolved Hide resolved
src/jaxsim/physics/algos/soft_contacts.py Show resolved Hide resolved
src/jaxsim/api/contact.py Outdated Show resolved Hide resolved
src/jaxsim/api/data.py Outdated Show resolved Hide resolved
src/jaxsim/api/model.py Outdated Show resolved Hide resolved
@flferretti
Copy link
Collaborator

Thanks for working on this @diegoferigo, it looks pretty neat! Could you please add the new modules in the docs?

@diegoferigo diegoferigo force-pushed the functional_api branch 2 times, most recently from 8477d1b to ac1532c Compare February 23, 2024 16:54
@diegoferigo
Copy link
Member Author

Thanks for working on this @diegoferigo, it looks pretty neat! Could you please add the new modules in the docs?

I don't want to waste time now to fix the documentation, I'll do it as soon as things get more stable. I expect problem with math rendering etc, and I don't want to dedicate time to it at this early stage :) It has to be done before the next release.

@DanielePucci
Copy link
Member

I don't want to dedicate time to it at this early stage :)

Agreed!

Copy link
Collaborator

@flferretti flferretti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot Diego, LGTM

@flferretti
Copy link
Collaborator

flferretti commented Feb 27, 2024

I have a couple more comments:

  • If we want the old OOP structure to follow a deprecation cycle, I believe it can be worth adapting the tests in order to validate the new API. It's probably worth testing it anyway
  • It's not clear to me how to change the pose parameters. In the older API you could call e.g. model.reset_joint_positions(...), what should a user modify now, the JaxsimModelData using the replace method? In that case, if I'm not mistaken, it would be something like:
data.state.physics_model.replace(
    joint_positions=new_position
)

C.C. @traversaro

@diegoferigo
Copy link
Member Author

diegoferigo commented Feb 28, 2024

I have a couple more comments:

  • If we want the old OOP structure to follow a deprecation cycle, I believe it can be worth adapting the tests in order to validate the new API. It's probably worth testing it anyway

I already have local tests, but not yet ready. They will be added in a new PR.

  • It's not clear to me how to change the pose parameters. In the older API you could call e.g. model.reset_joint_positions(...), what should a user modify now, the JaxsimModelData using the replace method? In that case, if I'm not mistaken, it would be something like:

Although the new functional APIs will get merged into main, they are not yet 1:1 with the OOP ones. I'm still working on resources to simplify reset and actuation. Now you need to alter directly data.state as you found out.

@diegoferigo
Copy link
Member Author

@flferretti I just added JaxSimModelData.reset_* methods.

diegoferigo and others added 24 commits March 1, 2024 12:35
Co-authored-by: Filippo Luca Ferretti <[email protected]>
@diegoferigo diegoferigo merged commit b9160b7 into ami-iit:main Mar 1, 2024
10 checks passed
@diegoferigo diegoferigo deleted the functional_api branch March 1, 2024 11:45
@traversaro
Copy link
Contributor

Great!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants