Skip to content

Commit 7cee800

Browse files
committed
Add assert_not_both_not_none (fixes #393)
1 parent 8bceebc commit 7cee800

2 files changed

Lines changed: 23 additions & 0 deletions

File tree

chex/_src/asserts.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from chex._src import asserts_internal as _ai
2727
from chex._src import pytypes
28+
from chex._src import asserts
2829
import jax
2930
from jax.experimental import checkify
3031
import jax.numpy as jnp
@@ -248,6 +249,20 @@ def assert_not_both_none(first: Any, second: Any) -> None:
248249
raise AssertionError(
249250
"At least one of the arguments must be different from `None`.")
250251

252+
@_static_assertion
253+
def assert_not_both_not_none(first: Any, second: Any) -> None:
254+
"""Checks that not both arguments are non-None.
255+
256+
Args:
257+
first: A first object.
258+
second: A second object.
259+
260+
Raises:
261+
AssertionError: If both ``first`` and ``second`` are not None.
262+
"""
263+
if first is not None and second is not None:
264+
raise AssertionError(
265+
"At most one of the arguments may be different from `None`.")
251266

252267
@_static_assertion
253268
def assert_exactly_one_is_none(first: Any, second: Any) -> None:

chex/_src/asserts_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,7 +1860,15 @@ def test_assert_equal_fail(self, first, second):
18601860
with self.assertRaises(AssertionError):
18611861
asserts.assert_equal(first, second)
18621862

1863+
class NotNoneAssertionsTest(parameterized.TestCase):
18631864

1865+
def test_assert_not_both_not_none(self):
1866+
asserts.assert_not_both_not_none(None, None)
1867+
asserts.assert_not_both_not_none(1, None)
1868+
asserts.assert_not_both_not_none(None, 1)
1869+
1870+
with self.assertRaises(AssertionError):
1871+
asserts.assert_not_both_not_none(1, 2)
18641872
class IsDivisibleTest(parameterized.TestCase):
18651873

18661874
def test_assert_is_divisible(self):

0 commit comments

Comments
 (0)