Skip to content

Commit

Permalink
Add a strict argument to all zips
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica authored and ricardoV94 committed Nov 19, 2024
1 parent 6de3151 commit 19dafe4
Show file tree
Hide file tree
Showing 106 changed files with 769 additions and 481 deletions.
32 changes: 21 additions & 11 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def infer_shape(outs, inputs, input_shapes):
# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature

for inp, inp_shp in zip(inputs, input_shapes):
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim

shape_feature = ShapeFeature()
shape_feature.on_attach(FunctionGraph([], []))

# Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes):
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
shape_feature.set_shape(inp, inp_shp)

def local_traverse(out):
Expand Down Expand Up @@ -108,7 +108,9 @@ def construct_nominal_fgraph(

replacements = dict(
zip(
inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
inputs + implicit_shared_inputs,
dummy_inputs + dummy_implicit_shared_inputs,
strict=True,
)
)

Expand Down Expand Up @@ -138,7 +140,7 @@ def construct_nominal_fgraph(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)

fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True))

for i, inp in enumerate(fgraph.inputs):
nom_inp = nominal_local_inputs[i]
Expand Down Expand Up @@ -562,7 +564,9 @@ def lop_overrides(inps, grads):
# compute non-overriding downsteam grads from upstreams grads
# it's normal some input may be disconnected, thus the 'ignore'
wrt = [
lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
lin
for lin, gov in zip(inner_inputs, custom_input_grads, strict=True)
if gov is None
]
default_input_grads = fn_grad(wrt=wrt) if wrt else []
input_grads = self._combine_list_overrides(
Expand Down Expand Up @@ -653,7 +657,7 @@ def _build_and_cache_rop_op(self):
f = [
output
for output, custom_output_grad in zip(
inner_outputs, custom_output_grads
inner_outputs, custom_output_grads, strict=True
)
if custom_output_grad is None
]
Expand Down Expand Up @@ -733,18 +737,24 @@ def make_node(self, *inputs):

non_shared_inputs = [
inp_t.filter_variable(inp)
for inp, inp_t in zip(non_shared_inputs, self.input_types)
for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True)
]

new_shared_inputs = inputs[num_expected_inps:]
inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
inner_and_input_shareds = list(
zip(self.shared_inputs, new_shared_inputs, strict=True)
)

if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
# The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared
# variables instead.
replace = dict(
zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
zip(
self.inner_inputs[num_expected_inps:],
new_shared_inputs,
strict=True,
)
)

# If the new shared variables are inconsistent with the inner-graph,
Expand Down Expand Up @@ -811,7 +821,7 @@ def infer_shape(self, fgraph, node, shapes):
# each shape call. PyTensor optimizer will clean this up later, but this
# will make extra work for the optimizer.

repl = dict(zip(self.inner_inputs, node.inputs))
repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
ret = []
Expand Down Expand Up @@ -853,5 +863,5 @@ def clone(self):
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
for output, variable in zip(outputs, variables, strict=True):
output[0] = variable
16 changes: 9 additions & 7 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def _get_preallocated_maps(
# except if broadcastable, or for dimensions above
# config.DebugMode__check_preallocated_output_ndim
buf_shape = []
for s, b in zip(r_vals[r].shape, r.broadcastable):
for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True):
if b or ((r.ndim - len(buf_shape)) > check_ndim):
buf_shape.append(s)
else:
Expand Down Expand Up @@ -943,7 +943,7 @@ def _get_preallocated_maps(
r_shape_diff = shape_diff[: r.ndim]
new_buf_shape = [
max((s + sd), 0)
for s, sd in zip(r_vals[r].shape, r_shape_diff)
for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True)
]
new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
Expand Down Expand Up @@ -1575,7 +1575,7 @@ def f():
# try:
# compute the value of all variables
for i, (thunk_py, thunk_c, node) in enumerate(
zip(thunks_py, thunks_c, order)
zip(thunks_py, thunks_c, order, strict=True)
):
_logger.debug(f"{i} - starting node {i} {node}")

Expand Down Expand Up @@ -1855,7 +1855,7 @@ def thunk():
assert s[0] is None

# store our output variables to their respective storage lists
for output, storage in zip(fgraph.outputs, output_storage):
for output, storage in zip(fgraph.outputs, output_storage, strict=True):
storage[0] = r_vals[output]

# transfer all inputs back to their respective storage lists
Expand Down Expand Up @@ -1931,11 +1931,11 @@ def deco():
f,
[
Container(input, storage, readonly=False)
for input, storage in zip(fgraph.inputs, input_storage)
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
],
[
Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
],
thunks_py,
order,
Expand Down Expand Up @@ -2122,7 +2122,9 @@ def __init__(

no_borrow = [
output
for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
for output, spec in zip(
fgraph.outputs, outputs + additional_outputs, strict=True
)
if not spec.borrow
]
if no_borrow:
Expand Down
6 changes: 3 additions & 3 deletions pytensor/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs(

new_inputs = []

for i, iv in zip(inputs, input_variables):
for i, iv in zip(inputs, input_variables, strict=True):
new_i = copy(i)
new_i.variable = iv

Expand Down Expand Up @@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs(
assert len(fgraph.inputs) == len(inputs)
assert len(fgraph.outputs) == len(outputs)

for fg_inp, inp in zip(fgraph.inputs, inputs):
for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True):
if fg_inp != getattr(inp, "variable", inp):
raise ValueError(
f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
)

for fg_out, out in zip(fgraph.outputs, outputs):
for fg_out, out in zip(fgraph.outputs, outputs, strict=True):
if fg_out != getattr(out, "variable", out):
raise ValueError(
f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
Expand Down
34 changes: 20 additions & 14 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def std_fgraph(
fgraph.attach_feature(
Supervisor(
input
for spec, input in zip(input_specs, fgraph.inputs)
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
Expand Down Expand Up @@ -442,7 +442,7 @@ def __init__(
# this loop works by modifying the elements (as variable c) of
# self.input_storage inplace.
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(
zip(self.indices, defaults)
zip(self.indices, defaults, strict=True)
):
if indices is None:
# containers is being used as a stack. Here we pop off
Expand Down Expand Up @@ -671,7 +671,7 @@ def checkSV(sv_ori, sv_rpl):
else:
outs = list(map(SymbolicOutput, fg_cpy.outputs))

for out_ori, out_cpy in zip(maker.outputs, outs):
for out_ori, out_cpy in zip(maker.outputs, outs, strict=False):
out_cpy.borrow = out_ori.borrow

# swap SharedVariable
Expand All @@ -684,7 +684,7 @@ def checkSV(sv_ori, sv_rpl):
raise ValueError(f"SharedVariable: {sv.name} not found")

# Swap SharedVariable in fgraph and In instances
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
# Variables in maker.inputs are defined by user, therefore we
# use them to make comparison and do the mapping.
# Otherwise we don't touch them.
Expand All @@ -708,7 +708,7 @@ def checkSV(sv_ori, sv_rpl):

# Delete update if needed
rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
inp.variable = in_var
if not delete_updates and inp.update is not None:
out_idx = rev_update_mapping[n]
Expand Down Expand Up @@ -768,7 +768,11 @@ def checkSV(sv_ori, sv_rpl):
).create(input_storage, storage_map=new_storage_map)

for in_ori, in_cpy, ori, cpy in zip(
maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
maker.inputs,
f_cpy.maker.inputs,
self.input_storage,
f_cpy.input_storage,
strict=True,
):
# Share immutable ShareVariable and constant input's storage
swapped = swap is not None and in_ori.variable in swap
Expand Down Expand Up @@ -999,7 +1003,7 @@ def __call__(self, *args, **kwargs):
# output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False):
for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs
self.output_storage, self.maker.fgraph.outputs, strict=True
):
if o_variable.owner is not None:
# this node is the variable of computation
Expand All @@ -1009,7 +1013,7 @@ def __call__(self, *args, **kwargs):
if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, input_storage))
list(zip(self.maker.expanded_inputs, input_storage, strict=True))
):
if input.update is not None:
storage.data = outputs.pop()
Expand Down Expand Up @@ -1040,7 +1044,7 @@ def __call__(self, *args, **kwargs):
assert len(self.output_keys) == len(outputs)

if output_subset is None:
return dict(zip(self.output_keys, outputs))
return dict(zip(self.output_keys, outputs, strict=True))
else:
return {
self.output_keys[index]: outputs[index]
Expand Down Expand Up @@ -1108,7 +1112,7 @@ def _pickle_Function(f):
input_storage = []

for (input, indices, inputs), (required, refeed, default) in zip(
f.indices, f.defaults
f.indices, f.defaults, strict=True
):
input_storage.append(ins[0])
del ins[0]
Expand Down Expand Up @@ -1150,7 +1154,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):

f = maker.create(input_storage)
assert len(f.input_storage) == len(inputs_data)
for container, x in zip(f.input_storage, inputs_data):
for container, x in zip(f.input_storage, inputs_data, strict=True):
assert (
(container.data is x)
or (isinstance(x, np.ndarray) and (container.data == x).all())
Expand Down Expand Up @@ -1184,7 +1188,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
reason = "insert_deepcopy"
updated_fgraph_inputs = {
fgraph_i
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
if getattr(i, "update", False)
}

Expand Down Expand Up @@ -1521,7 +1525,9 @@ def __init__(
# return the internal storage pointer.
no_borrow = [
output
for output, spec in zip(fgraph.outputs, outputs + found_updates)
for output, spec in zip(
fgraph.outputs, outputs + found_updates, strict=True
)
if not spec.borrow
]

Expand Down Expand Up @@ -1590,7 +1596,7 @@ def create(self, input_storage=None, storage_map=None):
# defaults lists.
assert len(self.indices) == len(input_storage)
for i, ((input, indices, subinputs), input_storage_i) in enumerate(
zip(self.indices, input_storage)
zip(self.indices, input_storage, strict=True)
):
# Replace any default value given as a variable by its
# container. Note that this makes sense only in the
Expand Down
4 changes: 2 additions & 2 deletions pytensor/d3viz/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,14 @@ def format_map(m):
ext_inputs = [self.__node_id(x) for x in node.inputs]
int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
assert len(ext_inputs) == len(int_inputs)
h = format_map(zip(ext_inputs, int_inputs))
h = format_map(zip(ext_inputs, int_inputs, strict=True))
pd_node.get_attributes()["subg_map_inputs"] = h

# Outputs mapping
ext_outputs = [self.__node_id(x) for x in node.outputs]
int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
assert len(ext_outputs) == len(int_outputs)
h = format_map(zip(int_outputs, ext_outputs))
h = format_map(zip(int_outputs, ext_outputs, strict=True))
pd_node.get_attributes()["subg_map_outputs"] = h

return graph
Expand Down
Loading

0 comments on commit 19dafe4

Please sign in to comment.