Add utility verify_infer_shape #1088

Open
ricardoV94 opened this issue Nov 13, 2024 · 0 comments

ricardoV94 commented Nov 13, 2024

Description

There is a test helper class that can be used for this, but it requires running pytest, which is not ideal for someone implementing their own Op. We could expose that functionality in a verify_infer_shape function, analogous to verify_grad (a rough sketch of such a function is included at the end).

import logging

import numpy as np

import pytensor

_logger = logging.getLogger(__name__)


class InferShapeTester:
    def setup_method(self):
        # Take into account any mode that may be defined in a child class;
        # it can also be None.
        mode = getattr(self, "mode", None)
        if mode is None:
            mode = pytensor.compile.get_default_mode()
        # This mode seems to be the minimal one that includes the shape_i
        # optimizations, if we don't want to enumerate them explicitly.
        self.mode = mode.including("canonicalize")

    def _compile_and_check(
        self,
        inputs,
        outputs,
        numeric_inputs,
        cls,
        excluding=None,
        warn=True,
        check_topo=True,
    ):
        """Test the ``infer_shape`` method only.

        When testing with input values whose shapes repeat the same size
        across different dimensions (for instance, a square matrix, or a
        tensor3 with shape (n, n, n) or (m, n, m)), it is not possible to
        detect whether the output shape was computed correctly, or whether
        dimensions of the same size were mixed up. For instance, if
        ``infer_shape`` uses the width of a matrix instead of its height,
        testing with only square matrices will not detect the problem.
        If ``warn=True``, we emit a warning when testing with such values.

        :param check_topo: If True, check that the Op was removed from the
            shape graph. False is useful for testing the not-implemented
            case.
        """
        mode = self.mode
        if excluding:
            mode = mode.excluding(*excluding)
        if warn:
            for var, inp in zip(inputs, numeric_inputs):
                if isinstance(inp, int | float | list | tuple):
                    inp = var.type.filter(inp)
                if not hasattr(inp, "shape"):
                    continue
                # Skip broadcastable dims: their size is guaranteed to stay
                # fixed, so they cannot contribute to the same-dim problem.
                if hasattr(var.type, "broadcastable"):
                    shp = [
                        inp.shape[i]
                        for i in range(inp.ndim)
                        if not var.type.broadcastable[i]
                    ]
                else:
                    shp = inp.shape
                if len(set(shp)) != len(shp):
                    _logger.warning(
                        "While testing shape inference for %r, we received an"
                        " input with a shape that has some repeated values: %r"
                        ", like a square matrix. This makes it impossible to"
                        " check if the values for these dimensions have been"
                        " correctly used, or if they have been mixed up.",
                        cls,
                        inp.shape,
                    )
                    break

        outputs_function = pytensor.function(inputs, outputs, mode=mode)
        # Now that we have full shape information at the type level, it's
        # possible/more likely that shape-computing graphs will not need the
        # inputs of the graph for which the shape is computed.
        shapes_function = pytensor.function(
            inputs, [o.shape for o in outputs], mode=mode, on_unused_input="ignore"
        )
        # Check that the Op is removed from the compiled shape graph...
        if check_topo:
            topo_shape = shapes_function.maker.fgraph.toposort()
            assert not any(isinstance(t.op, cls) for t in topo_shape)
        # ...but still present in the compiled output graph.
        topo_out = outputs_function.maker.fgraph.toposort()
        assert any(isinstance(t.op, cls) for t in topo_out)
        # Check that the inferred shapes agree with the actual shapes.
        numeric_outputs = outputs_function(*numeric_inputs)
        numeric_shapes = shapes_function(*numeric_inputs)
        for out, shape in zip(numeric_outputs, numeric_shapes):
            assert np.all(out.shape == shape), (out.shape, shape)
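
For reference, the class-based workflow currently looks something like this under pytest (a minimal sketch; the ReverseRows Op is a toy example invented for illustration, not part of PyTensor):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class ReverseRows(Op):
    """Toy Op: reverses the rows of its input; output shape == input shape."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0][::-1].copy()

    def infer_shape(self, fgraph, node, input_shapes):
        # The output has exactly the input's shape.
        return [input_shapes[0]]


class TestReverseRowsInferShape(InferShapeTester):
    def test_infer_shape(self):
        x = pt.matrix("x")
        # Distinct sizes (3, 5), per the docstring's advice, so a mixed-up
        # height/width in infer_shape would be caught.
        x_val = np.random.random((3, 5)).astype(x.dtype)
        self._compile_and_check([x], [ReverseRows()(x)], [x_val], ReverseRows)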

It could probably do with some cleaning as well.
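
A rough sketch of what the standalone utility could look like, assuming it just lifts the core checks out of _compile_and_check (the name, signature, and defaults here are only illustrative, not a settled API):

import numpy as np

import pytensor


def verify_infer_shape(inputs, outputs, numeric_inputs, op_class, mode=None):
    """Check ``infer_shape`` for the Op class ``op_class``, like ``verify_grad``.

    Uses plain asserts, so it can be called from any script without pytest.
    """
    if mode is None:
        mode = pytensor.compile.get_default_mode()
    mode = mode.including("canonicalize")

    outputs_function = pytensor.function(inputs, outputs, mode=mode)
    shapes_function = pytensor.function(
        inputs, [o.shape for o in outputs], mode=mode, on_unused_input="ignore"
    )

    # The Op should survive in the output graph, but be rewritten out of
    # the shape graph once its infer_shape is used.
    topo_out = outputs_function.maker.fgraph.toposort()
    assert any(isinstance(node.op, op_class) for node in topo_out)
    topo_shape = shapes_function.maker.fgraph.toposort()
    assert not any(isinstance(node.op, op_class) for node in topo_shape)

    # The inferred shapes must agree with the actual output shapes.
    numeric_outputs = outputs_function(*numeric_inputs)
    numeric_shapes = shapes_function(*numeric_inputs)
    for out, shape in zip(numeric_outputs, numeric_shapes):
        assert np.all(out.shape == shape), (out.shape, shape)

With the toy Op from the sketch above, this would be called as verify_infer_shape([x], [ReverseRows()(x)], [x_val], ReverseRows) from a plain script.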
