diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 000000000..161cbd2c6 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,66 @@ +import io +from contextlib import redirect_stderr, redirect_stdout + +import jax +import jax.numpy as jnp +import jaxlib.xla_extension +import pytest + +from jaxsim import exceptions + + +def test_exceptions_in_jit_functions(capsys): + + msg_during_jit = "Compiling jit_compiled_function" + + @jax.jit + def jit_compiled_function(data: jax.Array) -> jax.Array: + + # This message is compiled only during JIT compilation. + print(msg_during_jit) + + # Condition that will trigger the exception. + failed_if_42 = jnp.allclose(data, 42) + + # Raise a ValueError if the condition is met. + exceptions.raise_value_error_if( + condition=failed_if_42, + msg="Raising ValueError since data={num}", + num=data, + ) + + # Condition that will trigger the exception. + failed_if_43 = jnp.allclose(data, 43) + + # Raise a RuntimeError if the condition is met. + exceptions.raise_runtime_error_if( + condition=failed_if_43, + msg="Raising RuntimeError since data={num}", + num=data, + ) + + return data + + # In the first call, the function will be compiled and print the message. + with jax.log_compiles(): + with io.StringIO() as buf, redirect_stdout(buf): + + data = 40 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data + + assert msg_during_jit in stdout + assert jit_compiled_function._cache_size() == 1 + + # In the second call, the function won't be compiled and won't print the message. + with jax.log_compiles(): + with io.StringIO() as buf, redirect_stdout(buf): + + data = 41 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data + + assert msg_during_jit not in stdout + assert jit_compiled_function._cache_size() == 1