Skip to content

Commit

Permalink
Add test for exceptions raised by jit-compiled functions
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 17, 2024
1 parent e8ebf15 commit bbc54a3
Showing 1 changed file with 88 additions and 0 deletions.
88 changes: 88 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import io
from contextlib import redirect_stdout

import jax
import jax.numpy as jnp
import jaxlib.xla_extension
import pytest

from jaxsim import exceptions


def test_exceptions_in_jit_functions():

msg_during_jit = "Compiling jit_compiled_function"

@jax.jit
def jit_compiled_function(data: jax.Array) -> jax.Array:

# This message is compiled only during JIT compilation.
print(msg_during_jit)

# Condition that will trigger the exception.
failed_if_42_plus = jnp.allclose(data, 42)

# Raise a ValueError if the condition is met.
# The fmt string is built from kwargs.
exceptions.raise_value_error_if(
condition=failed_if_42_plus,
msg="Raising ValueError since data={num}",
num=data,
)

# Condition that will trigger the exception.
failed_if_42_minus = jnp.allclose(data, -42)

# Raise a RuntimeError if the condition is met.
# The fmt string is built from args.
exceptions.raise_runtime_error_if(
failed_if_42_minus,
"Raising RuntimeError since data={}",
data,
)

return data

# In the first call, the function will be compiled and print the message.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):

data = 40
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data

assert msg_during_jit in stdout
assert jit_compiled_function._cache_size() == 1

# In the second call, the function won't be compiled and won't print the message.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):

data = 41
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data

assert msg_during_jit not in stdout
assert jit_compiled_function._cache_size() == 1

# Let's trigger a ValueError exception by passing 42.
# Note: the real ValueError is printed in a stream that I couldn't figure out
# how to capture in pytest.
with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):

data = 42
_ = jit_compiled_function(data=data)

assert jit_compiled_function._cache_size() == 1

# Let's trigger a RuntimeError exception by passing -42.
# Note: the real RuntimeError is printed in a stream that I couldn't figure out
# how to capture in pytest.
with pytest.raises(jaxlib.xla_extension.XlaRuntimeError):

data = -42
_ = jit_compiled_function(data=data)

assert jit_compiled_function._cache_size() == 1

0 comments on commit bbc54a3

Please sign in to comment.