-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #181 from ami-iit/raise_exceptions_from_jit_compil…
…ed_functions Raise exceptions from jit-compiled functions
- Loading branch information
Showing
2 changed files
with
151 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import jax | ||
|
||
|
||
def raise_if( | ||
condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs | ||
) -> None: | ||
""" | ||
Raise a host-side exception if a condition is met. Useful in jit-compiled functions. | ||
Args: | ||
condition: | ||
The boolean condition of the evaluated expression that triggers | ||
the exception during runtime. | ||
exception: The type of exception to raise. | ||
msg: | ||
The message to display when the exception is raised. The message can be a | ||
format string (fmt), whose fields are filled with the args and kwargs. | ||
""" | ||
|
||
# Check early that the format string is well-formed. | ||
try: | ||
_ = msg.format(*args, **kwargs) | ||
except Exception as e: | ||
msg = "Error in formatting exception message with args={} and kwargs={}" | ||
raise ValueError(msg.format(args, kwargs)) from e | ||
|
||
def _raise_exception(condition: bool, *args, **kwargs) -> None: | ||
"""The function called by the JAX callback.""" | ||
|
||
if condition: | ||
raise exception(msg.format(*args, **kwargs)) | ||
|
||
def _callback(args, kwargs) -> None: | ||
"""The function that calls the JAX callback, executed only when needed.""" | ||
|
||
jax.debug.callback(_raise_exception, condition, *args, **kwargs) | ||
|
||
# Since running a callable on the host is expensive, we prevent its execution | ||
# if the condition is False with a low-level conditional expression. | ||
def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None: | ||
return jax.lax.cond( | ||
condition, | ||
_callback, | ||
lambda args, kwargs: None, | ||
args, | ||
kwargs, | ||
) | ||
|
||
return _run_callback_only_if_condition_is_true(*args, **kwargs) | ||
|
||
|
||
def raise_runtime_error_if( | ||
condition: bool | jax.Array, msg: str, *args, **kwargs | ||
) -> None: | ||
|
||
return raise_if(condition, RuntimeError, msg, *args, **kwargs) | ||
|
||
|
||
def raise_value_error_if( | ||
condition: bool | jax.Array, msg: str, *args, **kwargs | ||
) -> None: | ||
|
||
return raise_if(condition, ValueError, msg, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import io | ||
from contextlib import redirect_stdout | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import jaxlib.xla_extension | ||
import pytest | ||
|
||
from jaxsim import exceptions | ||
|
||
|
||
def test_exceptions_in_jit_functions(): | ||
|
||
msg_during_jit = "Compiling jit_compiled_function" | ||
|
||
@jax.jit | ||
def jit_compiled_function(data: jax.Array) -> jax.Array: | ||
|
||
# This message is compiled only during JIT compilation. | ||
print(msg_during_jit) | ||
|
||
# Condition that will trigger the exception. | ||
failed_if_42_plus = jnp.allclose(data, 42) | ||
|
||
# Raise a ValueError if the condition is met. | ||
# The fmt string is built from kwargs. | ||
exceptions.raise_value_error_if( | ||
condition=failed_if_42_plus, | ||
msg="Raising ValueError since data={num}", | ||
num=data, | ||
) | ||
|
||
# Condition that will trigger the exception. | ||
failed_if_42_minus = jnp.allclose(data, -42) | ||
|
||
# Raise a RuntimeError if the condition is met. | ||
# The fmt string is built from args. | ||
exceptions.raise_runtime_error_if( | ||
failed_if_42_minus, | ||
"Raising RuntimeError since data={}", | ||
data, | ||
) | ||
|
||
return data | ||
|
||
# In the first call, the function will be compiled and print the message. | ||
with jax.log_compiles(): | ||
with io.StringIO() as buf, redirect_stdout(buf): | ||
|
||
data = 40 | ||
out = jit_compiled_function(data=data) | ||
stdout = buf.getvalue() | ||
assert out == data | ||
|
||
assert msg_during_jit in stdout | ||
assert jit_compiled_function._cache_size() == 1 | ||
|
||
# In the second call, the function won't be compiled and won't print the message. | ||
with jax.log_compiles(): | ||
with io.StringIO() as buf, redirect_stdout(buf): | ||
|
||
data = 41 | ||
out = jit_compiled_function(data=data) | ||
stdout = buf.getvalue() | ||
assert out == data | ||
|
||
assert msg_during_jit not in stdout | ||
assert jit_compiled_function._cache_size() == 1 | ||
|
||
# Let's trigger a ValueError exception by passing 42. | ||
# Note: the real ValueError is printed in a stream that I couldn't figure out | ||
# how to capture in pytest. | ||
with pytest.raises(jaxlib.xla_extension.XlaRuntimeError): | ||
|
||
data = 42 | ||
_ = jit_compiled_function(data=data) | ||
|
||
assert jit_compiled_function._cache_size() == 1 | ||
|
||
# Let's trigger a RuntimeError exception by passing -42. | ||
# Note: the real RuntimeError is printed in a stream that I couldn't figure out | ||
# how to capture in pytest. | ||
with pytest.raises(jaxlib.xla_extension.XlaRuntimeError): | ||
|
||
data = -42 | ||
_ = jit_compiled_function(data=data) | ||
|
||
assert jit_compiled_function._cache_size() == 1 |