Add exceptions for hot loops
Armavica authored and ricardoV94 committed Nov 19, 2024
Parent: 54fba94 · Commit: 4b41e09
Showing 19 changed files with 77 additions and 48 deletions.
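The pattern throughout is the same: zip(..., strict=True) is relaxed to strict=False wherever the zip sits on a hot path (Op.perform implementations, thunks, Function.__call__) and equal lengths are already guaranteed by an assert or an earlier length check. As a rough illustration of the cost being avoided, here is a micro-benchmark sketch; it is not part of the commit, the numbers depend on the interpreter, and the difference is typically small, which is why the check is only dropped where it is provably redundant.

# Illustrative micro-benchmark (not part of the commit): cost of the length
# check in a perform()-style inner loop where equal lengths are already known.
import timeit

outputs = [[None] for _ in range(8)]
variables = list(range(8))


def copy_strict():
    for out, var in zip(outputs, variables, strict=True):
        out[0] = var


def copy_relaxed():
    # Safe only because len(outputs) == len(variables) is guaranteed upstream.
    for out, var in zip(outputs, variables, strict=False):
        out[0] = var


print("strict=True ", timeit.timeit(copy_strict, number=1_000_000))
print("strict=False", timeit.timeit(copy_relaxed, number=1_000_000))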
3 changes: 2 additions & 1 deletion pytensor/compile/builders.py

@@ -863,5 +863,6 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables, strict=True):
+        # strict=False because asserted above
+        for output, variable in zip(outputs, variables, strict=False):
             output[0] = variable
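For context on the hunk above: strict=True raises ValueError on a length mismatch, while strict=False silently truncates to the shorter sequence, so the relaxation is only safe behind a guarantee like the assert. A small sketch (not from the repository):

# strict=True turns a silent truncation into an error ...
outputs = [[None], [None], [None]]
variables = (1, 2)  # one element short
try:
    for out, var in zip(outputs, variables, strict=True):
        out[0] = var
except ValueError as err:
    print("caught:", err)

# ... strict=False would silently leave the last output untouched ...
for out, var in zip(outputs, variables, strict=False):
    out[0] = var
print(outputs)  # [[1], [2], [None]]

# ... which is why the commit keeps an assert in front before dropping the check.
variables = (1, 2, 3)
assert len(variables) == len(outputs)
for out, var in zip(outputs, variables, strict=False):
    out[0] = var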
12 changes: 8 additions & 4 deletions pytensor/compile/function/types.py

@@ -1002,8 +1002,9 @@ def __call__(self, *args, **kwargs):
         # if we are allowing garbage collection, remove the
         # output reference from the internal storage cells
         if getattr(self.vm, "allow_gc", False):
+            # strict=False because we are in a hot loop
             for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=True
+                self.output_storage, self.maker.fgraph.outputs, strict=False
             ):
                 if o_variable.owner is not None:
                     # this node is the variable of computation
@@ -1012,8 +1013,9 @@

         if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
+            # strict=False because we are in a hot loop
             for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage, strict=True))
+                list(zip(self.maker.expanded_inputs, input_storage, strict=False))
             ):
                 if input.update is not None:
                     storage.data = outputs.pop()
@@ -1044,7 +1046,8 @@ def __call__(self, *args, **kwargs):
             assert len(self.output_keys) == len(outputs)

             if output_subset is None:
-                return dict(zip(self.output_keys, outputs, strict=True))
+                # strict=False because we are in a hot loop
+                return dict(zip(self.output_keys, outputs, strict=False))
             else:
                 return {
                     self.output_keys[index]: outputs[index]
@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
     ins = list(f.input_storage)
     input_storage = []

+    # strict=False because we are in a hot loop
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults, strict=True
+        f.indices, f.defaults, strict=False
     ):
         input_storage.append(ins[0])
         del ins[0]
6 changes: 4 additions & 2 deletions pytensor/ifelse.py

@@ -305,7 +305,8 @@ def thunk():
                     if len(ls) > 0:
                         return ls
                     else:
-                        for out, t in zip(outputs, input_true_branch, strict=True):
+                        # strict=False because we are in a hot loop
+                        for out, t in zip(outputs, input_true_branch, strict=False):
                             compute_map[out][0] = 1
                             val = storage_map[t][0]
                             if self.as_view:
@@ -325,7 +326,8 @@ def thunk():
                     if len(ls) > 0:
                         return ls
                     else:
-                        for out, f in zip(outputs, inputs_false_branch, strict=True):
+                        # strict=False because we are in a hot loop
+                        for out, f in zip(outputs, inputs_false_branch, strict=False):
                             compute_map[out][0] = 1
                             # can't view both outputs unless destroyhandler
                             # improves
9 changes: 6 additions & 3 deletions pytensor/link/basic.py

@@ -539,12 +539,14 @@ def make_thunk(self, **kwargs):

         def f():
             for inputs in input_lists[1:]:
-                for input1, input2 in zip(inputs0, inputs, strict=True):
+                # strict=False because we are in a hot loop
+                for input1, input2 in zip(inputs0, inputs, strict=False):
                     input2.storage[0] = copy(input1.storage[0])
             for x in to_reset:
                 x[0] = None
             pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
+            # strict=False because we are in a hot loop
+            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
                 try:
                     wrapper(self.fgraph, i, node, *thunks)
                 except Exception:
@@ -666,8 +668,9 @@ def thunk(
         ):
             outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

+            # strict=False because we are in a hot loop
             for o_var, o_storage, o_val in zip(
-                fgraph.outputs, thunk_outputs, outputs, strict=True
+                fgraph.outputs, thunk_outputs, outputs, strict=False
             ):
                 compute_map[o_var][0] = True
                 o_storage[0] = self.output_filter(o_var, o_val)
11 changes: 6 additions & 5 deletions pytensor/link/c/basic.py

@@ -1993,25 +1993,26 @@ def make_thunk(self, **kwargs):
         )

         def f():
-            for input1, input2 in zip(i1, i2, strict=True):
+            # strict=False because we are in a hot loop
+            for input1, input2 in zip(i1, i2, strict=False):
                 # Set the inputs to be the same in both branches.
                 # The copy is necessary in order for inplace ops not to
                 # interfere.
                 input2.storage[0] = copy(input1.storage[0])
             for thunk1, thunk2, node1, node2 in zip(
-                thunks1, thunks2, order1, order2, strict=True
+                thunks1, thunks2, order1, order2, strict=False
             ):
-                for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
+                for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
-                for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
+                for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
                 try:
                     thunk1()
                     thunk2()
                     for output1, output2 in zip(
-                        thunk1.outputs, thunk2.outputs, strict=True
+                        thunk1.outputs, thunk2.outputs, strict=False
                     ):
                         self.checker(output1, output2)
                 except Exception:
3 changes: 2 additions & 1 deletion pytensor/link/numba/dispatch/basic.py

@@ -401,9 +401,10 @@ def py_perform_return(inputs):
     else:

         def py_perform_return(inputs):
+            # strict=False because we are in a hot loop
             return tuple(
                 out_type.filter(out[0])
-                for out_type, out in zip(output_types, py_perform(inputs), strict=True)
+                for out_type, out in zip(output_types, py_perform(inputs), strict=False)
             )

     @numba_njit
3 changes: 2 additions & 1 deletion pytensor/link/pytorch/dispatch/shape.py

@@ -34,7 +34,8 @@ def shape_i(x):
 def pytorch_funcify_SpecifyShape(op, node, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        for actual, expected in zip(x.shape, shape, strict=True):
+        # strict=False because asserted above
+        for actual, expected in zip(x.shape, shape, strict=False):
             if expected is None:
                 continue
             if actual != expected:
6 changes: 4 additions & 2 deletions pytensor/link/utils.py

@@ -190,8 +190,9 @@ def streamline_default_f():
             for x in no_recycling:
                 x[0] = None
             try:
+                # strict=False because we are in a hot loop
                 for thunk, node, old_storage in zip(
-                    thunks, order, post_thunk_old_storage, strict=True
+                    thunks, order, post_thunk_old_storage, strict=False
                 ):
                     thunk()
                     for old_s in old_storage:
@@ -206,7 +207,8 @@ def streamline_nice_errors_f():
             for x in no_recycling:
                 x[0] = None
             try:
-                for thunk, node in zip(thunks, order, strict=True):
+                # strict=False because we are in a hot loop
+                for thunk, node in zip(thunks, order, strict=False):
                     thunk()
             except Exception:
                 raise_with_op(fgraph, node, thunk)
6 changes: 4 additions & 2 deletions pytensor/scalar/basic.py

@@ -1150,8 +1150,9 @@ def perform(self, node, inputs, output_storage):
         else:
             variables = from_return_values(self.impl(*inputs))
             assert len(variables) == len(output_storage)
+            # strict=False because we are in a hot loop
             for out, storage, variable in zip(
-                node.outputs, output_storage, variables, strict=True
+                node.outputs, output_storage, variables, strict=False
             ):
                 dtype = out.dtype
                 storage[0] = self._cast_scalar(variable, dtype)
@@ -4328,7 +4329,8 @@ def make_node(self, *inputs):

     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        for storage, out_val in zip(output_storage, outputs, strict=True):
+        # strict=False because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs, strict=False):
             storage[0] = out_val

     def grad(self, inputs, output_grads):
5 changes: 3 additions & 2 deletions pytensor/scalar/loop.py

@@ -93,7 +93,7 @@ def _validate_updates(
                 )
         else:
             update = outputs
-        for i, u in zip(init[: len(update)], update, strict=True):
+        for i, u in zip(init, update, strict=False):
             if i.type != u.type:
                 raise TypeError(
                     "Init and update types must be the same: "
@@ -207,7+207,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(n_steps):
             carry = inner_fn(*carry, *constant)

-        for storage, out_val in zip(output_storage, carry, strict=True):
+        # strict=False because we are in a hot loop
+        for storage, out_val in zip(output_storage, carry, strict=False):
             storage[0] = out_val

     @property
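The first scalar/loop.py hunk above (like the rewriting/subtensor.py hunk further down) does slightly more than toggle the flag: it also drops an explicit slice of the longer sequence and leans on zip stopping at the shorter one. A sketch (not from the repository) of why the two spellings produce the same pairs:

# Slicing the longer sequence before a strict zip and letting a non-strict zip
# truncate are equivalent when `update` is at most as long as `init`.
init = ["i0", "i1", "i2", "i3"]
update = ["u0", "u1", "u2"]

pairs_sliced = list(zip(init[: len(update)], update, strict=True))
pairs_truncated = list(zip(init, update, strict=False))
assert pairs_sliced == pairs_truncated == [("i0", "u0"), ("i1", "u1"), ("i2", "u2")]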
3 changes: 2 additions & 1 deletion pytensor/scan/op.py

@@ -1278,8 +1278,9 @@ def __eq__(self, other):
         if len(self.inner_outputs) != len(other.inner_outputs):
             return False

+        # strict=False because length already compared above
         for self_in, other_in in zip(
-            self.inner_inputs, other.inner_inputs, strict=True
+            self.inner_inputs, other.inner_inputs, strict=False
         ):
             if self_in.type != other_in.type:
                 return False
3 changes: 2 additions & 1 deletion pytensor/tensor/basic.py

@@ -3463,7 +3463,8 @@ def perform(self, node, inp, out):

         # Make sure the output is big enough
         out_s = []
-        for xdim, ydim in zip(x_s, y_s, strict=True):
+        # strict=False because we are in a hot loop
+        for xdim, ydim in zip(x_s, y_s, strict=False):
             if xdim == ydim:
                 outdim = xdim
             elif xdim == 1:
10 changes: 6 additions & 4 deletions pytensor/tensor/blockwise.py

@@ -342,16 +342,17 @@ def core_func(
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)

+        # strict=False because we are in a hot loop
         for dims_and_bcast in zip(
             *[
                 zip(
                     input.shape[:batch_ndim],
                     sinput.type.broadcastable[:batch_ndim],
-                    strict=True,
+                    strict=False,
                 )
-                for input, sinput in zip(inputs, node.inputs, strict=True)
+                for input, sinput in zip(inputs, node.inputs, strict=False)
             ],
-            strict=True,
+            strict=False,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
@@ -374,8 +375,9 @@ def perform(self, node, inputs, output_storage):
         if not isinstance(res, tuple):
             res = (res,)

+        # strict=False because we are in a hot loop
         for node_out, out_storage, r in zip(
-            node.outputs, output_storage, res, strict=True
+            node.outputs, output_storage, res, strict=False
        ):
             out_dtype = getattr(node_out, "dtype", None)
             if out_dtype and out_dtype != r.dtype:
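For reference, a simplified standalone sketch (not the pytensor implementation; the function and argument names are made up) of what the nested zip in _check_runtime_broadcast above computes: for every batch dimension it pairs each input's runtime size with that input's static broadcastable flag, and it rejects the case where a size-1 dimension that was not declared broadcastable meets a dimension of size greater than 1.

def check_runtime_broadcast(shapes, broadcastable_patterns):
    # shapes: runtime batch shapes, e.g. [(5, 1), (5, 3)]
    # broadcastable_patterns: static flags per input, e.g. [(False, False), (False, False)]
    for dims_and_bcast in zip(
        *[
            zip(shape, bcast, strict=False)
            for shape, bcast in zip(shapes, broadcastable_patterns, strict=False)
        ],
        strict=False,
    ):
        if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
            raise ValueError("Runtime broadcasting not allowed: declare the dimension broadcastable instead")


check_runtime_broadcast([(5, 3), (5, 3)], [(False, False), (False, False)])  # passes
try:
    check_runtime_broadcast([(5, 1), (5, 3)], [(False, False), (False, False)])
except ValueError as err:
    print(err)  # the undeclared size-1 dimension would be broadcast at runtime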
8 changes: 5 additions & 3 deletions pytensor/tensor/elemwise.py

@@ -737,8 +737,9 @@ def perform(self, node, inputs, output_storage):
         if nout == 1:
             variables = [variables]

+        # strict=False because we are in a hot loop
         for i, (variable, storage, nout) in enumerate(
-            zip(variables, output_storage, node.outputs, strict=True)
+            zip(variables, output_storage, node.outputs, strict=False)
         ):
             storage[0] = variable = np.asarray(variable, dtype=nout.dtype)

@@ -753,12 +754,13 @@ def perform(self, node, inputs, output_storage):

     @staticmethod
     def _check_runtime_broadcast(node, inputs):
+        # strict=False because we are in a hot loop
         for dims_and_bcast in zip(
             *[
                 zip(input.shape, sinput.type.broadcastable, strict=False)
-                for input, sinput in zip(inputs, node.inputs, strict=True)
+                for input, sinput in zip(inputs, node.inputs, strict=False)
             ],
-            strict=True,
+            strict=False,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
3 changes: 2 additions & 1 deletion pytensor/tensor/random/basic.py

@@ -1862,7 +1862,8 @@ def rng_fn(cls, rng, p, size):
             # to `p.shape[:-1]` in the call to `vsearchsorted` below.
             if len(size) < (p.ndim - 1):
                 raise ValueError("`size` is incompatible with the shape of `p`")
-            for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=True):
+            # strict=False because we are in a hot loop
+            for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
                 if s == 1 and ps != 1:
                     raise ValueError("`size` is incompatible with the shape of `p`")

12 changes: 8 additions & 4 deletions pytensor/tensor/random/utils.py

@@ -44,7 +44,8 @@ def params_broadcast_shapes(
     max_fn = maximum if use_pytensor else max

     rev_extra_dims: list[int] = []
-    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True):
+    # strict=False because we are in a hot loop
+    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -63,11 +64,12 @@ def max_bcast(x, y):

     extra_dims = tuple(reversed(rev_extra_dims))

+    # strict=False because we are in a hot loop
     bcast_shapes = [
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
     ]

     return bcast_shapes
@@ -110,10 +112,11 @@ def broadcast_params(
     use_pytensor = False
     param_shapes = []
     for p in params:
+        # strict=False because we are in a hot loop
         param_shape = tuple(
             1 if bcast else s
             for s, bcast in zip(
-                p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True
+                p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=False
             )
         )
         use_pytensor |= isinstance(p, Variable)
@@ -124,9 +127,10 @@
     )
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to

+    # strict=False because we are in a hot loop
     bcast_params = [
         broadcast_to_fn(param, shape)
-        for shape, param in zip(shapes, params, strict=True)
+        for shape, param in zip(shapes, params, strict=False)
     ]

     return bcast_params
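As a reading aid for the two functions above: params_broadcast_shapes and broadcast_params broadcast only the leading batch dimensions of each parameter against one another, leaving the trailing ndims_params core dimensions alone. A numpy-only sketch of that idea (not the library API; the helper name is made up):

import numpy as np


def broadcast_params_sketch(params, ndims_params):
    # Batch shape of each parameter = everything except its core dims.
    batch_shapes = [
        np.shape(p)[: np.ndim(p) - nd]
        for p, nd in zip(params, ndims_params, strict=False)
    ]
    extra_dims = np.broadcast_shapes(*batch_shapes)
    return [
        np.broadcast_to(p, extra_dims + np.shape(p)[np.ndim(p) - nd :])
        for p, nd in zip(params, ndims_params, strict=False)
    ]


mu = np.zeros((5, 1))  # batch shape (5, 1), core ndim 0
cov = np.eye(3)        # batch shape (),     core ndim 2
bcast_mu, bcast_cov = broadcast_params_sketch([mu, cov], [0, 2])
print(bcast_mu.shape, bcast_cov.shape)  # (5, 1) (5, 1, 3, 3)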
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/subtensor.py

@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
     # Slices to take from val
     val_slices = []

-    for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
+    for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
         # If val was not copied over that dim,
         # we need to take the appropriate subtensor on it.
         if i >= n_added_dims:
14 changes: 6 additions & 8 deletions pytensor/tensor/shape.py

@@ -448,8 +448,9 @@ def perform(self, node, inp, out_):
             raise AssertionError(
                 f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
             )
+        # strict=False because we are in a hot loop
         if not all(
-            xs == s for xs, s in zip(x.shape, shape, strict=True) if s is not None
+            xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
         ):
             raise AssertionError(
                 f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
@@ -578,15 +579,12 @@ def specify_shape(
     x = ptb.as_tensor_variable(x)  # type: ignore[arg-type,unused-ignore]
     # The above is a type error in Python 3.9 but not 3.12.
     # Thus we need to ignore unused-ignore on 3.12.
+    new_shape_info = any(
+        s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
+    )

     # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
-    if len(shape) != x.type.ndim:
-        return _specify_shape(x, *shape)
-
-    new_shape_matches = all(
-        s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
-    )
-    if new_shape_matches:
+    if not new_shape_info and len(shape) == x.type.ndim:
         return x

     return _specify_shape(x, *shape)
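The second shape.py hunk above folds the early ndim check and the all(...) comparison into a single any(...) computed up front; by De Morgan's laws the two predicates agree, so specify_shape still returns x unchanged exactly when the requested shape adds no information. A small sketch (not from the repository) checking the equivalence:

cases = [
    ((5, None, 3), (5, 2, 3)),  # no new info: None entries are skipped
    ((5, 4, 3), (5, None, 3)),  # 4 is new info for the second dimension
]
for shape, type_shape in cases:
    old = all(s == xts for s, xts in zip(shape, type_shape, strict=False) if s is not None)
    new = not any(s != xts for s, xts in zip(shape, type_shape, strict=False) if s is not None)
    assert old == new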
(The diff for the remaining changed file did not load in this view.)