diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py new file mode 100644 index 000000000..8b127494d --- /dev/null +++ b/src/jaxsim/exceptions.py @@ -0,0 +1,63 @@ +import jax + + +def raise_if( + condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs +) -> None: + """ + Raise a host-side exception if a condition is met. Useful in jit-compiled functions. + + Args: + condition: + The boolean condition of the evaluated expression that triggers + the exception during runtime. + exception: The type of exception to raise. + msg: + The message to display when the exception is raised. The message can be a + format string (fmt), whose fields are filled with the args and kwargs. + """ + + # Check early that the format string is well-formed. + try: + _ = msg.format(*args, **kwargs) + except Exception as e: + msg = "Error in formatting exception message with args={} and kwargs={}" + raise ValueError(msg.format(args, kwargs)) from e + + def _raise_exception(condition: bool, *args, **kwargs) -> None: + """The function called by the JAX callback.""" + + if condition: + raise exception(msg.format(*args, **kwargs)) + + def _callback(args, kwargs) -> None: + """The function that calls the JAX callback, executed only when needed.""" + + jax.debug.callback(_raise_exception, condition, *args, **kwargs) + + # Since running a callable on the host is expensive, we prevent its execution + # if the condition is False with a low-level conditional expression. + def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None: + return jax.lax.cond( + condition, + _callback, + lambda args, kwargs: None, + args, + kwargs, + ) + + return _run_callback_only_if_condition_is_true(*args, **kwargs) + + +def raise_runtime_error_if( + condition: bool | jax.Array, msg: str, *args, **kwargs +) -> None: + + return raise_if(condition, RuntimeError, msg, *args, **kwargs) + + +def raise_value_error_if( + condition: bool | jax.Array, msg: str, *args, **kwargs +) -> None: + + return raise_if(condition, ValueError, msg, *args, **kwargs) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 000000000..c2700e70f --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,88 @@ +import io +from contextlib import redirect_stdout + +import jax +import jax.numpy as jnp +import jaxlib.xla_extension +import pytest + +from jaxsim import exceptions + + +def test_exceptions_in_jit_functions(): + + msg_during_jit = "Compiling jit_compiled_function" + + @jax.jit + def jit_compiled_function(data: jax.Array) -> jax.Array: + + # This message is compiled only during JIT compilation. + print(msg_during_jit) + + # Condition that will trigger the exception. + failed_if_42_plus = jnp.allclose(data, 42) + + # Raise a ValueError if the condition is met. + # The fmt string is built from kwargs. + exceptions.raise_value_error_if( + condition=failed_if_42_plus, + msg="Raising ValueError since data={num}", + num=data, + ) + + # Condition that will trigger the exception. + failed_if_42_minus = jnp.allclose(data, -42) + + # Raise a RuntimeError if the condition is met. + # The fmt string is built from args. + exceptions.raise_runtime_error_if( + failed_if_42_minus, + "Raising RuntimeError since data={}", + data, + ) + + return data + + # In the first call, the function will be compiled and print the message. + with jax.log_compiles(): + with io.StringIO() as buf, redirect_stdout(buf): + + data = 40 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data + + assert msg_during_jit in stdout + assert jit_compiled_function._cache_size() == 1 + + # In the second call, the function won't be compiled and won't print the message. + with jax.log_compiles(): + with io.StringIO() as buf, redirect_stdout(buf): + + data = 41 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data + + assert msg_during_jit not in stdout + assert jit_compiled_function._cache_size() == 1 + + # Let's trigger a ValueError exception by passing 42. + # Note: the real ValueError is printed in a stream that I couldn't figure out + # how to capture in pytest. + with pytest.raises(jaxlib.xla_extension.XlaRuntimeError): + + data = 42 + _ = jit_compiled_function(data=data) + + assert jit_compiled_function._cache_size() == 1 + + # Let's trigger a RuntimeError exception by passing -42. + # Note: the real RuntimeError is printed in a stream that I couldn't figure out + # how to capture in pytest. + with pytest.raises(jaxlib.xla_extension.XlaRuntimeError): + + data = -42 + _ = jit_compiled_function(data=data) + + assert jit_compiled_function._cache_size() == 1