-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path test_statistical_functions.py
102 lines (77 loc) · 3.1 KB
/
test_statistical_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
import pytest
from packaging.version import parse
import ndonnx as ndx
# ``axis`` arguments shared by every reduction test: reduce over everything
# (None), over each single axis of a 2-D input, over both axes in either
# order, and over no axis at all (the empty tuple).
AXES = [None, 0, 1, (0, 1), (1, 0), ()]

# ``correction`` values for std/var, spelled both as ints and as floats.
CORRECTIONS = [0, 1, 0.0, 1.0]

FLOAT_DTYPES = [np.float32, np.float64]

# Floating-point dtypes first, then signed and unsigned integers.
NUMERIC_DTYPES = [
    *FLOAT_DTYPES,
    np.int8,
    np.int16,
    np.int32,
    np.int64,
    np.uint8,
    np.uint16,
    np.uint32,
    np.uint64,
]

# 2-D test inputs; the first two contain negative values.
ARRAYS = [
    # 2x2 with mixed signs.
    np.array([[-3, -1], [2, 3]]),
    # 4x1 single column.
    np.array([[-3], [-1], [2], [3]]),
    # 4x0: four rows, no columns (float64, like np.array([[], [], [], []])).
    np.empty((4, 0)),
]
if parse(np.__version__).major < 2:
    # NumPy is the reference implementation here, and the tests forward
    # array-API keywords (e.g. ``correction`` for std/var) straight to it,
    # which NumPy 1 does not accept. Skip the whole module in that case.
    pytest.skip(
        "Statistical functions are not tested on NumPy 1 due to API incompatibilities",
        allow_module_level=True,
    )
def _compare_to_numpy(ndx_fun, np_fun, np_array, kwargs):
    """Assert that an ndonnx function reproduces its NumPy counterpart.

    ``ndx_fun`` is applied to ``ndx.asarray(np_array)``, and ``np_fun`` to
    ``np_array`` itself. Both receive the same keyword arguments.

    Parameters
    ----------
    ndx_fun
        ndonnx function under test (e.g. ``ndx.std``).
    np_fun
        NumPy reference implementation (e.g. ``np.std``).
    np_array
        NumPy input array.
    kwargs
        Keyword arguments forwarded unchanged to both functions.

    Raises
    ------
    AssertionError
        If the result shapes differ or the values are not close.
    """
    candidate = ndx_fun(ndx.asarray(np_array), **kwargs)
    expectation = np_fun(np_array, **kwargs)
    actual = candidate.unwrap_numpy()
    # ``assert_allclose`` treats a 0-d/scalar operand as broadcastable against
    # any shape. A wrong result shape (e.g. an ignored ``keepdims=True`` with
    # ``axis=None``) would therefore pass unnoticed, so compare shapes first.
    assert np.shape(actual) == np.shape(expectation), (
        f"shape mismatch: ndonnx {np.shape(actual)} vs numpy {np.shape(expectation)}"
    )
    np.testing.assert_allclose(actual, expectation)
# TODO(review): std only exercises ``correction=0``. The remaining CORRECTIONS
# values (1, 0.0, 1.0) were left commented out here without a recorded reason,
# whereas test_var covers all of them. Find out why and re-enable them.
@pytest.mark.parametrize("correction", [0])
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("axis", AXES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("np_arr", ARRAYS)
def test_std(dtype, correction, keepdims, axis, np_arr):
    """Check ``ndx.std`` against ``np.std`` on float inputs."""
    np_arr = np_arr.astype(dtype)
    kwargs = {"correction": correction, "keepdims": keepdims, "axis": axis}
    _compare_to_numpy(ndx.std, np.std, np_arr, kwargs)
# Use the shared CORRECTIONS constant instead of repeating its literal values.
@pytest.mark.parametrize("correction", CORRECTIONS)
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("axis", AXES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("np_arr", ARRAYS)
def test_var(dtype, correction, keepdims, axis, np_arr):
    """Check ``ndx.var`` against ``np.var`` on float inputs for every correction."""
    np_arr = np_arr.astype(dtype)
    kwargs = {"correction": correction, "keepdims": keepdims, "axis": axis}
    _compare_to_numpy(ndx.var, np.var, np_arr, kwargs)
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("axis", AXES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("np_arr", ARRAYS)
def test_mean(dtype, keepdims, axis, np_arr):
    """Check ``ndx.mean`` against ``np.mean`` on float inputs."""
    typed_input = np_arr.astype(dtype)
    reduction_kwargs = dict(keepdims=keepdims, axis=axis)
    _compare_to_numpy(ndx.mean, np.mean, typed_input, reduction_kwargs)
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("axis", AXES)
@pytest.mark.parametrize("dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("np_arr", ARRAYS)
def test_prod(dtype, keepdims, axis, np_arr):
    """Check ``ndx.prod`` against ``np.prod`` for every numeric dtype."""
    # Some ARRAYS entries are negative. Cast to an unsigned dtype, they would
    # wrap around to large values (and products could overflow), so only
    # their magnitudes are used.
    non_negative = np.abs(np_arr).astype(dtype)
    _compare_to_numpy(
        ndx.prod, np.prod, non_negative, dict(keepdims=keepdims, axis=axis)
    )
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("axis", AXES)
@pytest.mark.parametrize("dtype", NUMERIC_DTYPES)
@pytest.mark.parametrize("np_arr", ARRAYS)
def test_sum(dtype, keepdims, axis, np_arr):
    """Check ``ndx.sum`` against ``np.sum`` for every numeric dtype."""
    # Some ARRAYS entries are negative. Cast to an unsigned dtype, they would
    # wrap around to large values (and sums could overflow), so only their
    # magnitudes are used.
    non_negative = np.abs(np_arr).astype(dtype)
    _compare_to_numpy(
        ndx.sum, np.sum, non_negative, dict(keepdims=keepdims, axis=axis)
    )