diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 1643d53..0a2cf1d 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -625,6 +625,10 @@ def assert_shape( match; if ``expected_shapes`` has wrong type; if shape of ``input`` does not match ``expected_shapes``. """ + if expected_shapes is Ellipsis: + raise AssertionError( + "Error in shape compatibility check: `...` must be wrapped in a tuple, " + "e.g. `(...,)` not `...`.") if not isinstance(expected_shapes, (list, tuple)): raise AssertionError( "Error in shape compatibility check: expected shapes should be a list "