Skip to content

Commit b5ae862

Browse files
committed
fix tests and code
1 parent d4ca4c1 commit b5ae862

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

tests/xnp/update_on_condition_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def test_update_on_condition_condition_shape_mismatch_raises():
277277
indices = jnp.array([0, 1, 2], dtype=jnp.int32)
278278
condition = jnp.array([[True, False, True]], dtype=jnp.bool_)
279279

280-
with pytest.raises(ValueError, match="`condition` shape .* must match `true_values` shape"):
280+
with pytest.raises(ValueError, match="`condition` shape .* must match `indices` shape"):
281281
xnp.update_on_condition(original, indices, condition, 5.0)
282282

283283

xtructure/core/xtructure_numpy/array_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def _update_array_on_condition(
3939
return original_array
4040

4141
indices_array = jnp.asarray(indices)
42+
if condition.shape != indices_array.shape:
43+
raise ValueError(
44+
f"`condition` shape {condition.shape} must match `indices` shape {indices_array.shape}."
45+
)
46+
4247
indices_array = jnp.reshape(indices_array, (num_updates,))
4348
index_dtype = indices_array.dtype
4449
invalid_index = jnp.array(original_array.shape[0], dtype=index_dtype)

0 commit comments

Comments
 (0)