-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_iter.py
43 lines (32 loc) · 1001 Bytes
/
test_iter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause
import pytest
import ndonnx as ndx
def test_iter_for_loop():
    """Iterating a static 1-D array in a ``for`` loop yields exactly ``n`` elements.

    Each yielded element must be an ``ndx.Array`` with the leading axis
    removed, i.e. a 0-d array for a 1-D input.
    """
    n = 5
    a = ndx.array(shape=(n,), dtype=ndx.int64)
    count = 0
    for i, el in enumerate(a):
        assert isinstance(el, ndx.Array)
        assert el.shape == ()
        assert i < n, "Iterated past the number of elements"
        count += 1
    # The in-loop bound only catches overrun; an empty or truncated
    # iterator would pass it vacuously, so also pin the total count.
    assert count == n, f"Expected {n} elements, got {count}"
@pytest.mark.parametrize(
    "arr",
    [
        ndx.asarray([1]),
        ndx.asarray([[1], [2]]),
        ndx.array(shape=(2,), dtype=ndx.int64),
        ndx.array(shape=(2, 3), dtype=ndx.int64),
        ndx.array(shape=(2, "N"), dtype=ndx.int64),
    ],
)
def test_create_iterators(arr):
    """Every element produced by ``iter(arr)`` drops the leading axis of ``arr``.

    All parametrized arrays have a static leading dimension (trailing
    dimensions may be dynamic), so the iterator must also yield exactly
    ``arr.shape[0]`` elements.
    """
    it = iter(arr)
    # Consume the whole iterator rather than only the first element so that
    # a wrong element count or a shape drift in later elements is caught.
    elements = list(it)
    assert len(elements) == arr.shape[0]
    for el in elements:
        assert el.ndim == arr.ndim - 1
        assert el.shape == arr.shape[1:]
def test_0d_not_iterable():
    """A 0-d array has no leading axis, so iterating it must raise ``ValueError``."""
    zero_dim = ndx.array(shape=(), dtype=ndx.int64)
    with pytest.raises(ValueError):
        # Either creating the iterator or advancing it may raise; both are
        # covered by the context manager.
        iterator = iter(zero_dim)
        next(iterator)
def test_raises_dynamic_dim():
    """Requesting an iterator over an array with a dynamic leading dimension fails eagerly."""
    with pytest.raises(ValueError):
        dynamic = ndx.array(shape=("N",), dtype=ndx.int64)
        iter(dynamic)