From 4b41e092e9dd06113e7d7d8af4e26d90a7f1962f Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Tue, 19 Nov 2024 09:31:24 +0100
Subject: [PATCH] Add exceptions for hot loops

---
 pytensor/compile/builders.py            |  3 ++-
 pytensor/compile/function/types.py      | 12 ++++++++----
 pytensor/ifelse.py                      |  6 ++++--
 pytensor/link/basic.py                  |  9 ++++++---
 pytensor/link/c/basic.py                | 11 ++++++-----
 pytensor/link/numba/dispatch/basic.py   |  3 ++-
 pytensor/link/pytorch/dispatch/shape.py |  3 ++-
 pytensor/link/utils.py                  |  6 ++++--
 pytensor/scalar/basic.py                |  6 ++++--
 pytensor/scalar/loop.py                 |  5 +++--
 pytensor/scan/op.py                     |  3 ++-
 pytensor/tensor/basic.py                |  3 ++-
 pytensor/tensor/blockwise.py            | 10 ++++++----
 pytensor/tensor/elemwise.py             |  8 +++++---
 pytensor/tensor/random/basic.py         |  3 ++-
 pytensor/tensor/random/utils.py         | 12 ++++++++----
 pytensor/tensor/rewriting/subtensor.py  |  2 +-
 pytensor/tensor/shape.py                | 14 ++++++--------
 pytensor/tensor/type.py                 |  6 ++++--
 19 files changed, 77 insertions(+), 48 deletions(-)

diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index 9f3994c864..49baa3bb26 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -863,5 +863,6 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables, strict=True):
+        # strict=False because asserted above
+        for output, variable in zip(outputs, variables, strict=False):
             output[0] = variable
diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py
index eafb9eed5c..53306d52dc 100644
--- a/pytensor/compile/function/types.py
+++ b/pytensor/compile/function/types.py
@@ -1002,8 +1002,9 @@ def __call__(self, *args, **kwargs):
             # if we are allowing garbage collection, remove the
             # output reference from the internal storage cells
             if getattr(self.vm, "allow_gc", False):
+                # strict=False because we are in a hot loop
                 for o_container, o_variable in zip(
-                    self.output_storage, self.maker.fgraph.outputs, strict=True
+                    self.output_storage, self.maker.fgraph.outputs, strict=False
                 ):
                     if o_variable.owner is not None:
                         # this node is the variable of computation
@@ -1012,8 +1013,9 @@ def __call__(self, *args, **kwargs):

             if getattr(self.vm, "need_update_inputs", True):
                 # Update the inputs that have an update function
+                # strict=False because we are in a hot loop
                 for input, storage in reversed(
-                    list(zip(self.maker.expanded_inputs, input_storage, strict=True))
+                    list(zip(self.maker.expanded_inputs, input_storage, strict=False))
                 ):
                     if input.update is not None:
                         storage.data = outputs.pop()
@@ -1044,7 +1046,8 @@ def __call__(self, *args, **kwargs):
             assert len(self.output_keys) == len(outputs)

             if output_subset is None:
-                return dict(zip(self.output_keys, outputs, strict=True))
+                # strict=False because we are in a hot loop
+                return dict(zip(self.output_keys, outputs, strict=False))
             else:
                 return {
                     self.output_keys[index]: outputs[index]
@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
     ins = list(f.input_storage)
     input_storage = []

+    # strict=False because we are in a hot loop
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults, strict=True
+        f.indices, f.defaults, strict=False
     ):
         input_storage.append(ins[0])
         del ins[0]
diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py
index c15477a8e0..c458e5b296 100644
--- a/pytensor/ifelse.py
+++ b/pytensor/ifelse.py
@@ -305,7 +305,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                for out, t in zip(outputs, input_true_branch, strict=True):
+                # strict=False because we are in a hot loop
+                for out, t in zip(outputs, input_true_branch, strict=False):
                     compute_map[out][0] = 1
                     val = storage_map[t][0]
                     if self.as_view:
@@ -325,7 +326,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                for out, f in zip(outputs, inputs_false_branch, strict=True):
+                # strict=False because we are in a hot loop
+                for out, f in zip(outputs, inputs_false_branch, strict=False):
                     compute_map[out][0] = 1
                     # can't view both outputs unless destroyhandler
                     # improves
diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py
index ea069c51cf..daeaa5740f 100644
--- a/pytensor/link/basic.py
+++ b/pytensor/link/basic.py
@@ -539,12 +539,14 @@ def make_thunk(self, **kwargs):

         def f():
             for inputs in input_lists[1:]:
-                for input1, input2 in zip(inputs0, inputs, strict=True):
+                # strict=False because we are in a hot loop
+                for input1, input2 in zip(inputs0, inputs, strict=False):
                     input2.storage[0] = copy(input1.storage[0])
             for x in to_reset:
                 x[0] = None
             pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
+            # strict=False because we are in a hot loop
+            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
                 try:
                     wrapper(self.fgraph, i, node, *thunks)
                 except Exception:
@@ -666,8 +668,9 @@ def thunk(
         ):
             outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

+            # strict=False because we are in a hot loop
             for o_var, o_storage, o_val in zip(
-                fgraph.outputs, thunk_outputs, outputs, strict=True
+                fgraph.outputs, thunk_outputs, outputs, strict=False
             ):
                 compute_map[o_var][0] = True
                 o_storage[0] = self.output_filter(o_var, o_val)
diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py
index 6fb4c8378e..0b717c74a6 100644
--- a/pytensor/link/c/basic.py
+++ b/pytensor/link/c/basic.py
@@ -1993,25 +1993,26 @@ def make_thunk(self, **kwargs):
         )

         def f():
-            for input1, input2 in zip(i1, i2, strict=True):
+            # strict=False because we are in a hot loop
+            for input1, input2 in zip(i1, i2, strict=False):
                 # Set the inputs to be the same in both branches.
                 # The copy is necessary in order for inplace ops not to
                 # interfere.
                 input2.storage[0] = copy(input1.storage[0])
             for thunk1, thunk2, node1, node2 in zip(
-                thunks1, thunks2, order1, order2, strict=True
+                thunks1, thunks2, order1, order2, strict=False
             ):
-                for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
+                for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
-                for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
+                for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
                 try:
                     thunk1()
                     thunk2()
                     for output1, output2 in zip(
-                        thunk1.outputs, thunk2.outputs, strict=True
+                        thunk1.outputs, thunk2.outputs, strict=False
                     ):
                         self.checker(output1, output2)
                 except Exception:
diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index f30cf2cc80..8bf827b52f 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -401,9 +401,10 @@ def py_perform_return(inputs):
     else:

         def py_perform_return(inputs):
+            # strict=False because we are in a hot loop
             return tuple(
                 out_type.filter(out[0])
-                for out_type, out in zip(output_types, py_perform(inputs), strict=True)
+                for out_type, out in zip(output_types, py_perform(inputs), strict=False)
             )

     @numba_njit
diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py
index bb06656c7b..f771ac7211 100644
--- a/pytensor/link/pytorch/dispatch/shape.py
+++ b/pytensor/link/pytorch/dispatch/shape.py
@@ -34,7 +34,8 @@ def shape_i(x):
 def pytorch_funcify_SpecifyShape(op, node, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        for actual, expected in zip(x.shape, shape, strict=True):
+        # strict=False because asserted above
+        for actual, expected in zip(x.shape, shape, strict=False):
             if expected is None:
                 continue
             if actual != expected:
diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py
index 7f48edcfb6..69c36f160d 100644
--- a/pytensor/link/utils.py
+++ b/pytensor/link/utils.py
@@ -190,8 +190,9 @@ def streamline_default_f():
         for x in no_recycling:
             x[0] = None
         try:
+            # strict=False because we are in a hot loop
             for thunk, node, old_storage in zip(
-                thunks, order, post_thunk_old_storage, strict=True
+                thunks, order, post_thunk_old_storage, strict=False
             ):
                 thunk()
                 for old_s in old_storage:
@@ -206,7 +207,8 @@ def streamline_nice_errors_f():
         for x in no_recycling:
             x[0] = None
         try:
-            for thunk, node in zip(thunks, order, strict=True):
+            # strict=False because we are in a hot loop
+            for thunk, node in zip(thunks, order, strict=False):
                 thunk()
         except Exception:
             raise_with_op(fgraph, node, thunk)
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 1b87c8bf25..bb2baf0636 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -1150,8 +1150,9 @@ def perform(self, node, inputs, output_storage):
         else:
             variables = from_return_values(self.impl(*inputs))
             assert len(variables) == len(output_storage)
+            # strict=False because we are in a hot loop
             for out, storage, variable in zip(
-                node.outputs, output_storage, variables, strict=True
+                node.outputs, output_storage, variables, strict=False
             ):
                 dtype = out.dtype
                 storage[0] = self._cast_scalar(variable, dtype)
@@ -4328,7 +4329,8 @@ def make_node(self, *inputs):

     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        for storage, out_val in zip(output_storage, outputs, strict=True):
+        # strict=False because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs, strict=False):
             storage[0] = out_val

     def grad(self, inputs, output_grads):
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index 8d87b4a06f..0b59195722 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -93,7 +93,7 @@ def _validate_updates(
             )
         else:
             update = outputs
-        for i, u in zip(init[: len(update)], update, strict=True):
+        for i, u in zip(init, update, strict=False):
             if i.type != u.type:
                 raise TypeError(
                     "Init and update types must be the same: "
@@ -207,7 +208,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(n_steps):
             carry = inner_fn(*carry, *constant)

-        for storage, out_val in zip(output_storage, carry, strict=True):
+        # strict=False because we are in a hot loop
+        for storage, out_val in zip(output_storage, carry, strict=False):
             storage[0] = out_val

     @property
diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 3b80b04ec3..bfe04a94d7 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -1278,8 +1278,9 @@ def __eq__(self, other):
         if len(self.inner_outputs) != len(other.inner_outputs):
             return False

+        # strict=False because length already compared above
         for self_in, other_in in zip(
-            self.inner_inputs, other.inner_inputs, strict=True
+            self.inner_inputs, other.inner_inputs, strict=False
         ):
             if self_in.type != other_in.type:
                 return False
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 253d2f5b7d..cd874a2cc6 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3463,7 +3463,8 @@ def perform(self, node, inp, out):

         # Make sure the output is big enough
         out_s = []
-        for xdim, ydim in zip(x_s, y_s, strict=True):
+        # strict=False because we are in a hot loop
+        for xdim, ydim in zip(x_s, y_s, strict=False):
             if xdim == ydim:
                 outdim = xdim
             elif xdim == 1:
diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 4c136dac91..1c3a221642 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -342,16 +342,17 @@ def core_func(
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)

+        # strict=False because we are in a hot loop
         for dims_and_bcast in zip(
             *[
                 zip(
                     input.shape[:batch_ndim],
                     sinput.type.broadcastable[:batch_ndim],
-                    strict=True,
+                    strict=False,
                 )
-                for input, sinput in zip(inputs, node.inputs, strict=True)
+                for input, sinput in zip(inputs, node.inputs, strict=False)
             ],
-            strict=True,
+            strict=False,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
@@ -374,8 +375,9 @@ def perform(self, node, inputs, output_storage):
         if not isinstance(res, tuple):
             res = (res,)

+        # strict=False because we are in a hot loop
         for node_out, out_storage, r in zip(
-            node.outputs, output_storage, res, strict=True
+            node.outputs, output_storage, res, strict=False
         ):
             out_dtype = getattr(node_out, "dtype", None)
             if out_dtype and out_dtype != r.dtype:
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 55c80c40cb..cb60427ba0 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -737,8 +737,9 @@ def perform(self, node, inputs, output_storage):
         if nout == 1:
             variables = [variables]

+        # strict=False because we are in a hot loop
         for i, (variable, storage, nout) in enumerate(
-            zip(variables, output_storage, node.outputs, strict=True)
+            zip(variables, output_storage, node.outputs, strict=False)
         ):
             storage[0] = variable = np.asarray(variable, dtype=nout.dtype)

@@ -753,12 +754,13 @@ def perform(self, node, inputs, output_storage):

     @staticmethod
     def _check_runtime_broadcast(node, inputs):
+        # strict=False because we are in a hot loop
         for dims_and_bcast in zip(
             *[
                 zip(input.shape, sinput.type.broadcastable, strict=False)
-                for input, sinput in zip(inputs, node.inputs, strict=True)
+                for input, sinput in zip(inputs, node.inputs, strict=False)
             ],
-            strict=True,
+            strict=False,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py
index d5e346a5bf..bebcad55be 100644
--- a/pytensor/tensor/random/basic.py
+++ b/pytensor/tensor/random/basic.py
@@ -1862,7 +1862,8 @@ def rng_fn(cls, rng, p, size):
             # to `p.shape[:-1]` in the call to `vsearchsorted` below.
             if len(size) < (p.ndim - 1):
                 raise ValueError("`size` is incompatible with the shape of `p`")
-            for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=True):
+            # strict=False because we are in a hot loop
+            for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
                 if s == 1 and ps != 1:
                     raise ValueError("`size` is incompatible with the shape of `p`")

diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 1bdb936bdf..23b4b50265 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -44,7 +44,8 @@ def params_broadcast_shapes(
     max_fn = maximum if use_pytensor else max

     rev_extra_dims: list[int] = []
-    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True):
+    # strict=False because we are in a hot loop
+    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -63,11 +64,12 @@ def max_bcast(x, y):

     extra_dims = tuple(reversed(rev_extra_dims))

+    # strict=False because we are in a hot loop
     bcast_shapes = [
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
     ]

     return bcast_shapes
@@ -110,10 +112,11 @@ def broadcast_params(
     use_pytensor = False
     param_shapes = []
     for p in params:
+        # strict=False because we are in a hot loop
         param_shape = tuple(
             1 if bcast else s
             for s, bcast in zip(
-                p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True
+                p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=False
             )
         )
         use_pytensor |= isinstance(p, Variable)
@@ -124,9 +127,10 @@ def broadcast_params(
     )

     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to

+    # strict=False because we are in a hot loop
     bcast_params = [
         broadcast_to_fn(param, shape)
-        for shape, param in zip(shapes, params, strict=True)
+        for shape, param in zip(shapes, params, strict=False)
     ]

     return bcast_params
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 5263e4ee4b..fd98eaf718 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):

     # Slices to take from val
     val_slices = []
-    for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
+    for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
         # If val was not copied over that dim,
         # we need to take the appropriate subtensor on it.
if i >= n_added_dims: diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index fcc7915632..a357f25672 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -448,8 +448,9 @@ def perform(self, node, inp, out_): raise AssertionError( f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." ) + # strict=False because we are in a hot loop if not all( - xs == s for xs, s in zip(x.shape, shape, strict=True) if s is not None + xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None ): raise AssertionError( f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." @@ -578,15 +579,12 @@ def specify_shape( x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore] # The above is a type error in Python 3.9 but not 3.12. # Thus we need to ignore unused-ignore on 3.12. + new_shape_info = any( + s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None + ) # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError - if len(shape) != x.type.ndim: - return _specify_shape(x, *shape) - - new_shape_matches = all( - s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None - ) - if new_shape_matches: + if not new_shape_info and len(shape) == x.type.ndim: return x return _specify_shape(x, *shape) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index e5b81691e0..0f99fa48aa 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -248,9 +248,10 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: " PyTensor C code does not support that.", ) + # strict=False because we are in a hot loop if not all( ds == ts if ts is not None else True - for ds, ts in zip(data.shape, self.shape, strict=True) + for ds, ts in zip(data.shape, self.shape, strict=False) ): raise TypeError( f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" @@ -319,6 +320,7 @@ def in_same_class(self, otype): return False def is_super(self, otype): + # strict=False because we are in a hot loop if ( isinstance(otype, type(self)) and otype.dtype == self.dtype @@ -327,7 +329,7 @@ def is_super(self, otype): # but not less and all( sb == ob or sb is None - for sb, ob in zip(self.shape, otype.shape, strict=True) + for sb, ob in zip(self.shape, otype.shape, strict=False) ) ): return True
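
For readers wondering what the "hot loop" comments guard against: since Python 3.10, zip(..., strict=True) raises ValueError when its arguments have different lengths, and that bookkeeping is paid on every call. The patch keeps strict=True everywhere except where the lengths are already guaranteed equal (an assert or an earlier length comparison) or where the zip sits on a per-call hot path such as Function.__call__ or Op.perform. The snippet below is an illustrative sketch, not part of the patch and not PyTensor code (the names values, cells, and assign are made up); it shows one way to compare the two variants on the kind of tiny per-call loop the patch targets, assuming only the standard library and Python 3.10+:

# Illustrative micro-benchmark (not from the patch): compare equal-length
# iteration with zip(strict=True) vs. zip(strict=False).
import timeit

values = list(range(4))            # stand-in for freshly computed outputs
cells = [[None] for _ in values]   # stand-in for output storage cells


def assign(strict: bool) -> None:
    # Deliberately tiny loop body so the zip overhead is not drowned out.
    for value, cell in zip(values, cells, strict=strict):
        cell[0] = value


if __name__ == "__main__":
    n = 1_000_000
    for strict in (True, False):
        t = timeit.timeit(lambda: assign(strict), number=n)
        print(f"strict={strict}: {t:.3f}s for {n:,} calls")

Any per-iteration gap is small, so the trade-off only pays off where the loop runs once per function call on small sequences; where a mismatch is already ruled out by an assertion or length check, dropping strict=True loses nothing.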