-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[wip] Add test for exceptions raised by jit-compiled functions
- Loading branch information
1 parent
e8ebf15
commit dcd6fff
Showing
1 changed file
with
66 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import io | ||
from contextlib import redirect_stdout, redirect_stderr | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import jaxlib.xla_extension | ||
import pytest | ||
|
||
from jaxsim import exceptions | ||
|
||
|
||
def test_exceptions_in_jit_functions(capsys): | ||
|
||
msg_during_jit = "Compiling jit_compiled_function" | ||
|
||
@jax.jit | ||
def jit_compiled_function(data: jax.Array) -> jax.Array: | ||
|
||
# This message is compiled only during JIT compilation. | ||
print(msg_during_jit) | ||
|
||
# Condition that will trigger the exception. | ||
failed_if_42 = jnp.allclose(data, 42) | ||
|
||
# Raise a ValueError if the condition is met. | ||
exceptions.raise_value_error_if( | ||
condition=failed_if_42, | ||
msg="Raising ValueError since data={num}", | ||
num=data, | ||
) | ||
|
||
# Condition that will trigger the exception. | ||
failed_if_43 = jnp.allclose(data, 43) | ||
|
||
# Raise a RuntimeError if the condition is met. | ||
exceptions.raise_runtime_error_if( | ||
condition=failed_if_43, | ||
msg="Raising RuntimeError since data={num}", | ||
num=data, | ||
) | ||
|
||
return data | ||
|
||
# In the first call, the function will be compiled and print the message. | ||
with jax.log_compiles(): | ||
with io.StringIO() as buf, redirect_stdout(buf): | ||
|
||
data = 40 | ||
out = jit_compiled_function(data=data) | ||
stdout = buf.getvalue() | ||
assert out == data | ||
|
||
assert msg_during_jit in stdout | ||
assert jit_compiled_function._cache_size() == 1 | ||
|
||
# In the second call, the function won't be compiled and won't print the message. | ||
with jax.log_compiles(): | ||
with io.StringIO() as buf, redirect_stdout(buf): | ||
|
||
data = 41 | ||
out = jit_compiled_function(data=data) | ||
stdout = buf.getvalue() | ||
assert out == data | ||
|
||
assert msg_during_jit not in stdout | ||
assert jit_compiled_function._cache_size() == 1 |