Skip to content

Commit

Permalink
Merge pull request #181 from ami-iit/raise_exceptions_from_jit_compil…
Browse files Browse the repository at this point in the history
…ed_functions

Raise exceptions from jit-compiled functions
  • Loading branch information
diegoferigo committed Jun 18, 2024
2 parents 35f25b1 + 03fa934 commit 954c5c6
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 0 deletions.
63 changes: 63 additions & 0 deletions src/jaxsim/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import jax


def raise_if(
condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
) -> None:
"""
Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
Args:
condition:
The boolean condition of the evaluated expression that triggers
the exception during runtime.
exception: The type of exception to raise.
msg:
The message to display when the exception is raised. The message can be a
format string (fmt), whose fields are filled with the args and kwargs.
"""

# Check early that the format string is well-formed.
try:
_ = msg.format(*args, **kwargs)
except Exception as e:
msg = "Error in formatting exception message with args={} and kwargs={}"
raise ValueError(msg.format(args, kwargs)) from e

def _raise_exception(condition: bool, *args, **kwargs) -> None:
"""The function called by the JAX callback."""

if condition:
raise exception(msg.format(*args, **kwargs))

def _callback(args, kwargs) -> None:
"""The function that calls the JAX callback, executed only when needed."""

jax.debug.callback(_raise_exception, condition, *args, **kwargs)

# Since running a callable on the host is expensive, we prevent its execution
# if the condition is False with a low-level conditional expression.
def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
return jax.lax.cond(
condition,
_callback,
lambda args, kwargs: None,
args,
kwargs,
)

return _run_callback_only_if_condition_is_true(*args, **kwargs)


def raise_runtime_error_if(
condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:

return raise_if(condition, RuntimeError, msg, *args, **kwargs)


def raise_value_error_if(
condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:

return raise_if(condition, ValueError, msg, *args, **kwargs)
88 changes: 88 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import io
from contextlib import redirect_stdout

import jax
import jax.numpy as jnp
import jaxlib.xla_extension
import pytest

from jaxsim import exceptions


def test_exceptions_in_jit_functions():

msg_during_jit = "Compiling jit_compiled_function"

@jax.jit
def jit_compiled_function(data: jax.Array) -> jax.Array:

# This message is compiled only during JIT compilation.
print(msg_during_jit)

# Condition that will trigger the exception.
failed_if_42_plus = jnp.allclose(data, 42)

# Raise a ValueError if the condition is met.
# The fmt string is built from kwargs.
exceptions.raise_value_error_if(
condition=failed_if_42_plus,
msg="Raising ValueError since data={num}",
num=data,
)

# Condition that will trigger the exception.
failed_if_42_minus = jnp.allclose(data, -42)

# Raise a RuntimeError if the condition is met.
# The fmt string is built from args.
exceptions.raise_runtime_error_if(
failed_if_42_minus,
"Raising RuntimeError since data={}",
data,
)

return data

# In the first call, the function will be compiled and print the message.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):

data = 40
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data

assert msg_during_jit in stdout
assert jit_compiled_function._cache_size() == 1

# In the second call, the function won't be compiled and won't print the message.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):

data = 41
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data

assert msg_during_jit not in stdout
assert jit_compiled_function._cache_size() == 1

# Let's trigger a ValueError exception by passing 42.
# Note: the real ValueError is printed in a stream that I couldn't figure out
# how to capture in pytest.
with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):

data = 42
_ = jit_compiled_function(data=data)

assert jit_compiled_function._cache_size() == 1

# Let's trigger a RuntimeError exception by passing -42.
# Note: the real RuntimeError is printed in a stream that I couldn't figure out
# how to capture in pytest.
with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):

data = -42
_ = jit_compiled_function(data=data)

assert jit_compiled_function._cache_size() == 1

0 comments on commit 954c5c6

Please sign in to comment.