Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid to use as Static attributes classes that do not have a __eq__ method that returns a scalar bool #105

Merged
merged 2 commits into from
Mar 11, 2024

Conversation

traversaro
Copy link
Contributor

@traversaro traversaro commented Mar 10, 2024

Fix #103 .

Problem description

The problem reported in #103 (and I guess #84, even if I did not checked directly that one) is specifically caused by some Static attributes of jaxsim classes not having __eq__ methods that return objects that can be casted to bools. In particular, according to https://jax.readthedocs.io/en/latest/pytrees.html#pytrees any value that is returned as part of aux_data second return value of the tree_flatten method of a class passed to jax.tree_util.register_pytree_node_class must:

When defining an unflattening functions, in general children should contain all the dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while aux_data should contain all the static elements that will be rolled into the treedef structure. JAX sometimes needs to compare treedef for equality, or compute its hash for use in the JIT cache, and so care must be taken to ensure that the auxiliary data specified in the flattening recipe supports meaningful hashing and equality comparisons.

As it is made even more explicit in google/jax#19547 (comment) :

<...> aux_data must contain hashable static entries, that can be evaluated for equality using normal bool(a1 == a2).

The problem was triggered only in the second run of a jit function with the same instance, as that was the only time in which the program actually compared the value of static attributes.

Solution proposed in this PR

This condition is not respected in jaxsim before this PR. In this PR, I fixed the jaxsim classes to fix the minimal example provided in #103 . This is done in two ways:

  • For CollidablePoint, BoxCollision, SphereCollision, LinkDescription and RootPose, as these classes contained np.array or jnp.array attributes, I defined custom __eq__ methods to insert appropriatly the .all() method when comparing for equality arrays (commit: 1c34033)
  • For GroundContact the situation was a bit more complex, as in that case the problematic attribute was the body that was a np.array and was itself marked as Static attribute , and I could not re-define its __eq__ method to return a scalar bool. So, just for this case I decided to change the body attribute to be a list instead of nd.array .

Realistically, other classes are affected by the same problem, and I did not noticed them as the test reported in #103 is quite minimal (for example, no joint was involved). However, I think that for fixing those it is just a matter of having more complete tests (such as the one added in #102) and just iterating on those tests until no ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() appears anymore.

Requires #104 to be merged before.

@traversaro traversaro changed the title Fix103 Avoid to use as Static attributes classes that do not have a __eq__ method that returns a scalar bool Mar 10, 2024
Copy link
Collaborator

@flferretti flferretti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @traversaro! This should solve the longstanding issue we had.
I believe that if GroundContact.body does not need to be modified, we could prefer to make it a tuple instead of a list, what do you think @diegoferigo?

It LGTM anyway 🚀

src/jaxsim/physics/model/ground_contact.py Outdated Show resolved Hide resolved
src/jaxsim/physics/algos/soft_contacts.py Outdated Show resolved Hide resolved
@flferretti flferretti linked an issue Mar 11, 2024 that may be closed by this pull request
@flferretti
Copy link
Collaborator

I'm thinking to send a PR to jax_dataclasses to check this before JIT is applied

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem reported in #103 (and I guess #84, even if I did not checked directly that one) is specifically caused by some Static attributes of jaxsim classes not having __eq__ methods that return objects that can be casted to bools.

Wow thanks a lot for this investigation @traversaro, that's great! As mentioned in #103 (comment), I was suspecting something related to __hash__ or __eq__ was involved, but as first attempt I blamed the implementation of the pytree's class, not the one of its static attributes. Good to know! Super happy if we can finally have this solved. It simplifies prototyping in a Jupyter notebook (generally, in interactive iPython).

I'm thinking to send a PR to jax_dataclasses to check this before JIT is applied

I was about to suggest it, you anticipated me :) I would start investigating if we can raise an error in _flatten if not isinstance(treedef, collections.abc.Hashable) (see collections.abc.Hashable).

@traversaro
Copy link
Contributor Author

After sleeping on this, I guess that using Python introspection also the __eq__ implementation can be improved to avoid needing to manually enumerate all class attributes, that is error prone if a new attribute is added. However, this can be improvements in the future.

@flferretti
Copy link
Collaborator

flferretti commented Mar 11, 2024

After sleeping on this, I guess that using Python introspection also the __eq__ implementation can be improved to avoid needing to manually enumerate all class attributes, that is error prone if a new attribute is added. However, this can be improvements in the future.

What about comparing self.__dict__ == other.__dict__? Like:

def __eq__(self, other):
    if not isinstance(other, type(self)):
        return False

    return self.__dict__ == other.__dict__

@traversaro
Copy link
Contributor Author

After sleeping on this, I guess that using Python introspection also the __eq__ implementation can be improved to avoid needing to manually enumerate all class attributes, that is error prone if a new attribute is added. However, this can be improvements in the future.

What about comparing self.__dict__ == other.__dict__? Like:

def __eq__(self, other):
    if not isinstance(other, type(self)):
        return False

    return self.__dict__ == other.__dict__

Don't you get the same problem of ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() if the dictionary contains a np.array ? However, probably we can iterate over the dictonary's elements and use (a == b).all() if an element is a np.array or jnp.array (or similar).

@traversaro
Copy link
Contributor Author

I'm thinking to send a PR to jax_dataclasses to check this before JIT is applied

I was about to suggest it, you anticipated me :) I would start investigating if we can raise an error in _flatten if not isinstance(treedef, collections.abc.Hashable) (see collections.abc.Hashable).

Note that the problem here is not that the fact that the structure is hashable, but that it can be compare with bool(a1 == a2)

@diegoferigo
Copy link
Member

diegoferigo commented Mar 11, 2024

I'm thinking to send a PR to jax_dataclasses to check this before JIT is applied

I was about to suggest it, you anticipated me :) I would start investigating if we can raise an error in _flatten if not isinstance(treedef, collections.abc.Hashable) (see collections.abc.Hashable).

Note that the problem here is not that the fact that the structure is hashable, but that it can be compare with bool(a1 == a2)

Yes sorry for the confusion, I missed a step. I guess that if a class is hashable, its __eq__ method could just compare instances using it. What's important for our applications is that the hash does not include any object-specific data like its id, otherwise in our case different objects of the same model would trigger jit recompilations.

You can refer to Python data model for further details on the interaction between __hash__ and __eq__.

@diegoferigo
Copy link
Member

I don't like too much the idea of __dict__ since there might be custom pytrees implementations that do not have a compliant dictionary (or not have a __dict__ at all). I believe that jax_dataclasses should just check if Static attributes can be compared with __eq__, and raise an exception (or a warning if the author wants to preserve backward compatiblity) otherwise.

@traversaro
Copy link
Contributor Author

I don't like too much the idea of __dict__ since there might be custom pytrees implementations that do not have a compliant dictionary (or not have a __dict__ at all). I believe that jax_dataclasses should just check if Static attributes can be compared with __eq__, and raise an exception (or a warning if the author wants to preserve backward compatiblity) otherwise.

Ack, probably then we can merge the custom __eq__ until something better come along?

@diegoferigo
Copy link
Member

I don't like too much the idea of __dict__ since there might be custom pytrees implementations that do not have a compliant dictionary (or not have a __dict__ at all). I believe that jax_dataclasses should just check if Static attributes can be compared with __eq__, and raise an exception (or a warning if the author wants to preserve backward compatiblity) otherwise.

Ack, probably then we can merge the custom __eq__ until something better come along?

Definitely, feel free to merge this PR if it's ready to be merged.

@diegoferigo diegoferigo merged commit 4fd2032 into ami-iit:main Mar 11, 2024
11 checks passed
@traversaro traversaro deleted the fix103 branch March 11, 2024 11:17
@flferretti
Copy link
Collaborator

flferretti commented Mar 11, 2024

Regarding this, I made some additional tests and the problem doesn't seem to be related to Static attributes only. In fact, trying with:

Test script
import jax.numpy as jnp
import jaxsim.api as js
import rod.builder.primitives
import rod.urdf.exporter

rod_model = (
    rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box")
    .build_model()
    .add_link()
    .add_inertial()
    .add_visual()
    .add_collision()
    .build()
)

# Export the URDF string.
urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
    sdf=rod_model, pretty=True
)

model1 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_string,
    gravity=jnp.array([0, 0, -10]),
    is_urdf=True,
)

model2 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_string,
    gravity=jnp.array([0, 0, -10]),
    is_urdf=True,
)

data1 = js.data.JaxSimModelData.build(model=model1)
data2 = js.data.JaxSimModelData.build(model=model2)
data1 == data2

I still obtain:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

What solved the issue was to implement a JaxsimDataclass.__eq__:

    def __eq__(self, other):
        if self.__class__ is not other.__class__:
            return False

        return all(
            (key in other.__dict__)
            and (
                np.array_equal(self.__dict__[key], other.__dict__[key])
                if isinstance(self.__dict__[key], np.ndarray)
                or isinstance(self.__dict__[key], jnp.ndarray)
                else self.__dict__[key] == other.__dict__[key]
            )
            for key in self.__dict__
        )

and then inside JaxsimModelData:

def __eq__(self, other):
    return super().__eq__(other)

which actually I thought it wasn't necessary as:

>>> js.data.JaxSimModelData.__mro__
(<class 'jaxsim.api.data.JaxSimModelData'>, <class 'jaxsim.api.common.ModelDataWithVelocityRepresentation'>, <class 'jaxsim.utils.jaxsim_dataclass.JaxsimDataclass'>, <class 'abc.ABC'>, <class 'object'>)

Edit: the JaxsimDataclass.__eq__ wasn't called at all, still investigating on this

@diegoferigo
Copy link
Member

While you are investigating, if not necessary, I'd suggest to use dataclasses.fields to iterate on the attributes instead of using the low-level __dict__.

@traversaro
Copy link
Contributor Author

Edit: the JaxsimDataclass.__eq__ wasn't called at all, still investigating on this

Note that some dataclasses/jax_dataclasses decorator take in input a eq argument, in theory you can pass eq=False to avoid an __eq__ being generator. I did not found it was useful when fixing #103, but that may be related.

Regarding this, I made some additional tests and the problem doesn't seem to be related to Static attributes only.

Yes, but non-Static attributes do not the constraint that bool(a1 == a2) works when being passed to jitted functions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants