From f82927ea11877de8b78048dfb2247465f023acf1 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 17 Jun 2024 13:27:19 +0200 Subject: [PATCH 1/3] Add jaxsim.exceptions module It contains utility functions to raise exceptions within jit-compiled functions --- src/jaxsim/exceptions.py | 55 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 src/jaxsim/exceptions.py diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py new file mode 100644 index 000000000..382f493ab --- /dev/null +++ b/src/jaxsim/exceptions.py @@ -0,0 +1,55 @@ +import jax + + +def raise_if( + condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs +) -> None: + """ + Raise a host-side exception if a condition is met. Useful in jit-compiled functions. + + Args: + condition: + The boolean condition of the evaluated expression that triggers + the exception during runtime. + exception: The type of exception to raise. + msg: + The message to display when the exception is raised. It can be a fmt string, + and users can pass additional arguments to format the string. + """ + + # Check early that the fmt string is well-formed. + _ = msg.format(*args, **kwargs) + + def _raise_exception(condition: bool, *args, **kwargs) -> None: + if condition: + raise exception(msg.format(*args, **kwargs)) + + def _callback(args, kwargs) -> None: + jax.debug.callback(_raise_exception, condition, *args, **kwargs) + + # Since running a callable on the host is expensive, we prevent its execution + # if the condition is False with a low-level conditional expression. + def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None: + return jax.lax.cond( + condition, + _callback, + lambda args, kwargs: None, + args, + kwargs, + ) + + return _run_callback_only_if_condition_is_true(*args, **kwargs) + + +def raise_runtime_error_if( + condition: bool | jax.Array, msg: str, *args, **kwargs +) -> None: + + return raise_if(condition, RuntimeError, msg, *args, **kwargs) + + +def raise_value_error_if( + condition: bool | jax.Array, msg: str, *args, **kwargs +) -> None: + + return raise_if(condition, ValueError, msg, *args, **kwargs) From 827e5d3113add8c1eddfb752dbe432f3b57f72e9 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 14 Jun 2024 17:34:32 +0200 Subject: [PATCH 2/3] Add test for exceptions raised by jit-compiled functions --- tests/test_exceptions.py | 88 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/test_exceptions.py diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 000000000..c2700e70f --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,88 @@ +import io +from contextlib import redirect_stdout + +import jax +import jax.numpy as jnp +import jaxlib.xla_extension +import pytest + +from jaxsim import exceptions + + +def test_exceptions_in_jit_functions(): + + msg_during_jit = "Compiling jit_compiled_function" + + @jax.jit + def jit_compiled_function(data: jax.Array) -> jax.Array: + + # This message is compiled only during JIT compilation. + print(msg_during_jit) + + # Condition that will trigger the exception. + failed_if_42_plus = jnp.allclose(data, 42) + + # Raise a ValueError if the condition is met. + # The fmt string is built from kwargs. + exceptions.raise_value_error_if( + condition=failed_if_42_plus, + msg="Raising ValueError since data={num}", + num=data, + ) + + # Condition that will trigger the exception. + failed_if_42_minus = jnp.allclose(data, -42) + + # Raise a RuntimeError if the condition is met. + # The fmt string is built from args. + exceptions.raise_runtime_error_if( + failed_if_42_minus, + "Raising RuntimeError since data={}", + data, + ) + + return data + + # In the first call, the function will be compiled and print the message. + with jax.log_compiles(): + with io.StringIO() as buf, redirect_stdout(buf): + + data = 40 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data + + assert msg_during_jit in stdout + assert jit_compiled_function._cache_size() == 1 + + # In the second call, the function won't be compiled and won't print the message. + with jax.log_compiles(): + with io.StringIO() as buf, redirect_stdout(buf): + + data = 41 + out = jit_compiled_function(data=data) + stdout = buf.getvalue() + assert out == data + + assert msg_during_jit not in stdout + assert jit_compiled_function._cache_size() == 1 + + # Let's trigger a ValueError exception by passing 42. + # Note: the real ValueError is printed in a stream that I couldn't figure out + # how to capture in pytest. + with pytest.raises(jaxlib.xla_extension.XlaRuntimeError): + + data = 42 + _ = jit_compiled_function(data=data) + + assert jit_compiled_function._cache_size() == 1 + + # Let's trigger a RuntimeError exception by passing -42. + # Note: the real RuntimeError is printed in a stream that I couldn't figure out + # how to capture in pytest. + with pytest.raises(jaxlib.xla_extension.XlaRuntimeError): + + data = -42 + _ = jit_compiled_function(data=data) + + assert jit_compiled_function._cache_size() == 1 From 03fa9341917621f8aa8fe27b0209b8cd689ee7a9 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 18 Jun 2024 09:45:48 +0200 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Filippo Luca Ferretti --- src/jaxsim/exceptions.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py index 382f493ab..8b127494d 100644 --- a/src/jaxsim/exceptions.py +++ b/src/jaxsim/exceptions.py @@ -13,18 +13,26 @@ def raise_if( the exception during runtime. exception: The type of exception to raise. msg: - The message to display when the exception is raised. It can be a fmt string, - and users can pass additional arguments to format the string. + The message to display when the exception is raised. The message can be a + format string (fmt), whose fields are filled with the args and kwargs. """ - # Check early that the fmt string is well-formed. - _ = msg.format(*args, **kwargs) + # Check early that the format string is well-formed. + try: + _ = msg.format(*args, **kwargs) + except Exception as e: + msg = "Error in formatting exception message with args={} and kwargs={}" + raise ValueError(msg.format(args, kwargs)) from e def _raise_exception(condition: bool, *args, **kwargs) -> None: + """The function called by the JAX callback.""" + if condition: raise exception(msg.format(*args, **kwargs)) def _callback(args, kwargs) -> None: + """The function that calls the JAX callback, executed only when needed.""" + jax.debug.callback(_raise_exception, condition, *args, **kwargs) # Since running a callable on the host is expensive, we prevent its execution