From 6114e6a0d3e46ba2ff89ff69942e9366e194cf41 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Thu, 5 Aug 2021 11:57:29 -0700
Subject: [PATCH] test_util: add decorator to set config values in test cases

---
 jax/test_util.py                  | 22 +++++++++++++-
 tests/lax_numpy_einsum_test.py    | 10 +------
 tests/lax_numpy_indexing_test.py  | 20 ++-----------
 tests/lax_numpy_test.py           | 50 ++++---------------------
 tests/lax_numpy_vectorize_test.py | 10 +------
 tests/lax_scipy_sparse_test.py    | 10 +------
 tests/lax_scipy_test.py           | 10 +------
 tests/random_test.py              | 10 +------
 tests/scipy_ndimage_test.py       | 10 +------
 tests/scipy_optimize_test.py      | 20 ++-----------
 tests/scipy_signal_test.py        | 10 +------
 11 files changed, 37 insertions(+), 145 deletions(-)

diff --git a/jax/test_util.py b/jax/test_util.py
index b9ce2008fbdd..2c7e349af60b 100644
--- a/jax/test_util.py
+++ b/jax/test_util.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from contextlib import contextmanager
+import inspect
 import functools
 import re
 import os
@@ -870,8 +871,18 @@ def getTestCaseNames(self, testCaseClass):
     return names
 
 
+def with_config(**kwds):
+  """Test case decorator for subclasses of JaxTestCase"""
+  def decorator(cls):
+    assert inspect.isclass(cls) and issubclass(cls, JaxTestCase), "@with_config can only wrap JaxTestCase class definitions."
+    cls._default_config = {**JaxTestCase._default_config, **kwds}
+    return cls
+  return decorator
+
+
 class JaxTestCase(parameterized.TestCase):
   """Base class for JAX tests including numerical checks and boilerplate."""
+  _default_config = {'jax_enable_checks': True}
 
   # TODO(mattjj): this obscures the error messages from failures, figure out how
   # to re-enable it
@@ -880,12 +891,21 @@ class JaxTestCase(parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
-    config.update('jax_enable_checks', True)
+    self._original_config = {}
+    for key, value in self._default_config.items():
+      self._original_config[key] = getattr(config, key)
+      config.update(key, value)
+
     # We use the adler32 hash for two reasons.
     # a) it is deterministic run to run, unlike hash() which is randomized.
     # b) it returns values in int32 range, which RandomState requires.
     self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
 
+  def tearDown(self):
+    for key, value in self._original_config.items():
+      config.update(key, value)
+    super().tearDown()
+
   def rng(self):
     return self._rng
 
diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py
index a75b6d549d34..c5a3fc18eef5 100644
--- a/tests/lax_numpy_einsum_test.py
+++ b/tests/lax_numpy_einsum_test.py
@@ -30,17 +30,9 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class EinsumTest(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   def _check(self, s, *ops):
     a = np.einsum(s, *ops)
     b = jnp.einsum(s, *ops, precision=lax.Precision.HIGHEST)
diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py
index 13c1b40c8517..cda788f78fe4 100644
--- a/tests/lax_numpy_indexing_test.py
+++ b/tests/lax_numpy_indexing_test.py
@@ -425,18 +425,10 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
       np.array([[1, 0], [1, 0]]))),
   ]),
 ]
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class IndexingTest(jtu.JaxTestCase):
   """Tests for Numpy indexing translation rules."""
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(jtu.cases_from_list({
       "testcase_name": "{}_inshape={}_indexer={}".format(
           name, jtu.format_shape_dtype_string(shape, dtype), indexer),
@@ -947,17 +939,9 @@ def dtypes(op):
   else:
     return default_dtypes
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class IndexedUpdateTest(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
       "testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
           name, jtu.format_shape_dtype_string(shape, dtype), indexer,
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index ffa71c87ca52..385745a3a7d2 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -501,18 +501,10 @@ def wrapper(*args, **kw):
     return wrapper
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedNumpyTests(jtu.JaxTestCase):
   """Tests for LAX-backed Numpy implementation."""
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
     def f():
       out = [rng(shape, dtype or jnp.float_)
@@ -5492,17 +5484,9 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
     GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
 ]
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpyGradTests(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(itertools.chain.from_iterable(
       jtu.cases_from_list(
         {"testcase_name": jtu.format_test_name_suffix(
@@ -5605,17 +5589,9 @@ def testGradLogaddexp2Complex(self, shapes, dtype):
       tol = 3e-2
     check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpySignaturesTest(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   def testWrappedSignaturesMatch(self):
     """Test that jax.numpy function signatures match numpy."""
     jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)}
@@ -5732,17 +5708,9 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
       yield arg_dtypes
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpyUfuncTests(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(
       {"testcase_name": f"_{name}_{','.join(arg_dtypes)}",
        "name": name, "arg_dtypes": arg_dtypes}
@@ -5774,17 +5742,9 @@ def testUfuncInputTypes(self, name, arg_dtypes):
     # that jnp returns float32. e.g. np.cos(np.uint8(0))
     self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpyDocTests(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   def test_lax_numpy_docstrings(self):
     # Test that docstring wrapping & transformation didn't fail.
diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py
index 883acad01492..d0f115052d81 100644
--- a/tests/lax_numpy_vectorize_test.py
+++ b/tests/lax_numpy_vectorize_test.py
@@ -25,17 +25,9 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class VectorizeTest(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py
index 2a325eebd33c..5f083ad34ed0 100644
--- a/tests/lax_scipy_sparse_test.py
+++ b/tests/lax_scipy_sparse_test.py
@@ -63,17 +63,9 @@ def rand_sym_pos_def(rng, shape, dtype):
   return matrix @ matrix.T.conj()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedScipyTests(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   def _fetch_preconditioner(self, preconditioner, A, rng=None):
     """
     Returns one of various preconditioning matrices depending on the identifier
diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index 16c3e1acf7bb..d7a22e536824 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -141,18 +141,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
 ]
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedScipyTests(jtu.JaxTestCase):
   """Tests for LAX-backed Scipy implementation."""
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   def _GetArgsMaker(self, rng, shapes, dtypes):
     return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
diff --git a/tests/random_test.py b/tests/random_test.py
index 78bc0370537c..c87ba5e39989 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -44,17 +44,9 @@
 int_dtypes = jtu.dtypes.all_integer
 uint_dtypes = jtu.dtypes.all_unsigned
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxRandomTest(jtu.JaxTestCase):
 
-  def setUp(self):
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-    super().setUp()
-
-  def tearDown(self):
-    super().tearDown()
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-
   def _CheckCollisions(self, samples, nbits):
     fail_prob = 0.01  # conservative bound on statistical fail prob by Chebyshev
     nitems = len(samples)
diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py
index b5af02199362..7d9990c3bfaf 100644
--- a/tests/scipy_ndimage_test.py
+++ b/tests/scipy_ndimage_test.py
@@ -57,17 +57,9 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
   return result
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NdimageTest(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py
index d0292d0ef28c..5ede71e1c873 100644
--- a/tests/scipy_optimize_test.py
+++ b/tests/scipy_optimize_test.py
@@ -64,17 +64,9 @@ def zakharovFromIndices(x, ii):
   return answer
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class TestBFGS(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
        "maxiter": maxiter, "func_and_init": func_and_init}
@@ -149,17 +141,9 @@ def f(x):
     jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class TestLBFGS(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
        "maxiter": maxiter, "func_and_init": func_and_init}
diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py
index 5d25fbeb2610..54e55e904409 100644
--- a/tests/scipy_signal_test.py
+++ b/tests/scipy_signal_test.py
@@ -35,18 +35,10 @@
 default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedScipySignalTests(jtu.JaxTestCase):
   """Tests for LAX-backed scipy.stats implementations"""
 
-  def setUp(self):
-    super().setUp()
-    self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
-    config.update("jax_numpy_rank_promotion", "raise")
-
-  def tearDown(self):
-    config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
-    super().tearDown()
-
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,
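
Usage sketch (illustrative, not part of the patch): after this change, a test
module opts into config overrides by decorating the test class instead of
hand-writing setUp/tearDown. JaxTestCase.setUp applies the class's merged
_default_config and records the previous values, and tearDown restores them,
so the override is scoped to the tests in the decorated class. The module,
class, and test names below are hypothetical; the imports assume the usual
JAX test preamble of this era (jax.test_util as jtu, absl.testing.absltest).

  import jax.numpy as jnp
  import jax.test_util as jtu

  from jax.config import config
  from absl.testing import absltest

  config.parse_flags_with_absl()


  # The decorator records the override on the class; setUp/tearDown in
  # JaxTestCase apply and restore it around every test method.
  @jtu.with_config(jax_numpy_rank_promotion="raise")
  class MyConfigTest(jtu.JaxTestCase):

    def test_override_is_active(self):
      # setUp has already applied the override, so it reads back here.
      self.assertEqual(config.jax_numpy_rank_promotion, "raise")

    def test_rank_promotion_raises(self):
      # With rank promotion set to "raise", implicitly promoting a rank-1
      # operand to rank 2 is an error (assumed to surface as a ValueError).
      with self.assertRaises(ValueError):
        jnp.ones((3, 4)) + jnp.ones(4)


  if __name__ == "__main__":
    absltest.main()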