
Commit

Merge pull request #7522 from jakevdp:super
PiperOrigin-RevId: 389018519
jax authors committed Aug 5, 2021
2 parents 3a469d5 + 63a788b commit df69062
Showing 6 changed files with 11 additions and 11 deletions.
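Every hunk below makes the same mechanical change: the Python 2-compatible two-argument call super(ClassName, self) is replaced by the zero-argument super() form available since Python 3, which resolves the target class through the implicit __class__ cell and therefore keeps working if the class is renamed. A minimal sketch of the equivalence (class names here are illustrative, not taken from the diff):

  class Base:
    def __init__(self, dtype, weak_type=False):
      self.dtype = dtype
      self.weak_type = weak_type

  class Derived(Base):
    def __init__(self, shape, dtype):
      # Old style: names the class and instance explicitly.
      #   super(Derived, self).__init__(dtype)
      # New style: equivalent in Python 3, and robust to renaming Derived.
      super().__init__(dtype)
      self.shape = shape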
6 changes: 3 additions & 3 deletions jax/core.py
@@ -1062,7 +1062,7 @@ class ShapedArray(UnshapedArray):
   array_abstraction_level = 1

   def __init__(self, shape, dtype, weak_type=False, named_shape={}):
-    super(ShapedArray, self).__init__(dtype, weak_type=weak_type)
+    super().__init__(dtype, weak_type=weak_type)
     self.shape = canonicalize_shape(shape)
     self.named_shape = dict(named_shape)
@@ -1141,8 +1141,8 @@ class ConcreteArray(ShapedArray):
   array_abstraction_level = 0

   def __init__(self, val, weak_type=False):
-    super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val),
-                                        weak_type=weak_type)
+    super().__init__(np.shape(val), np.result_type(val),
+                     weak_type=weak_type)
     # Note: canonicalized self.dtype doesn't necessarily match self.val
     self.val = val
     assert self.dtype != np.dtype('O'), val
2 changes: 1 addition & 1 deletion jax/linear_util.py
@@ -228,7 +228,7 @@ def wrap_init(f, params={}) -> WrappedFun:
 class _CacheLocalContext(threading.local):

   def __init__(self):
-    super(_CacheLocalContext, self).__init__()
+    super().__init__()
     self.most_recent_entry = None


2 changes: 1 addition & 1 deletion jax/test_util.py
@@ -879,7 +879,7 @@ class JaxTestCase(parameterized.TestCase):
   # assert core.reset_trace_state()

   def setUp(self):
-    super(JaxTestCase, self).setUp()
+    super().setUp()
     config.update('jax_enable_checks', True)
     # We use the adler32 hash for two reasons.
     # a) it is deterministic run to run, unlike hash() which is randomized.
4 changes: 2 additions & 2 deletions tests/array_interoperability_test.py
@@ -61,7 +61,7 @@

 class DLPackTest(jtu.JaxTestCase):
   def setUp(self):
-    super(DLPackTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() == "tpu":
       self.skipTest("DLPack not supported on TPU")

@@ -194,7 +194,7 @@ def testJaxToTorch(self, shape, dtype):
 class CudaArrayInterfaceTest(jtu.JaxTestCase):

   def setUp(self):
-    super(CudaArrayInterfaceTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() != "gpu":
       self.skipTest("__cuda_array_interface__ is only supported on GPU")

2 changes: 1 addition & 1 deletion tests/custom_object_test.py
@@ -62,7 +62,7 @@ class AbstractSparseArray(core.ShapedArray):

   def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
                named_shape={}):
-    super(AbstractSparseArray, self).__init__(shape, dtype)
+    super().__init__(shape, dtype)
     self.index_dtype = index_dtype
     self.nnz = nnz
     self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
6 changes: 3 additions & 3 deletions tests/sharded_jit_test.py
@@ -42,7 +42,7 @@
 class ShardedJitTest(jtu.JaxTestCase):

   def setUp(self):
-    super(ShardedJitTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
       raise SkipTest
     if jtu.device_under_test() == "gpu":
@@ -279,7 +279,7 @@ def testCompilationCache(self):
 class ShardedJitErrorsTest(jtu.JaxTestCase):

   def setUp(self):
-    super(ShardedJitErrorsTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
       raise SkipTest

@@ -329,7 +329,7 @@ def f(x):
 class PmapOfShardedJitTest(jtu.JaxTestCase):

   def setUp(self):
-    super(PmapOfShardedJitTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
       raise SkipTest
     if jtu.device_under_test() == "gpu":
