Bugs in loop shape inference #163

Open
MatejUrbanQC opened this issue Jul 19, 2024 · 13 comments

@MatejUrbanQC

MatejUrbanQC commented Jul 19, 2024

(v17) Custom inference fails when inputs and outputs of the body have mismatched shapes.

The following code should generate a loop that produces a tensor [1 1 1 1 1] by appending 1 in each iteration, but the inferred shape is (1,).

def test_loop():
    import spox.opset.ai.onnx.v17 as op
    import numpy as np
    
    num_iters = op.const(1)
    v = op.const([], dtype=np.int64)

    result = op.loop(num_iters, v_initial=[v], body=lambda i, c, x: (c, op.concat([x,op.const([1])], axis=0)))[0]
    assert result.type.shape == (5,)

Result:

FAILED bug.py::test_loop - assert (1,) == (5,)

(v19+) Custom inference is missing completely

The fix from v17 is completely missing in all newer versions. The error can be reproduced using the following code:

def test_loop2():
    import spox.opset.ai.onnx.v19 as op

    num_iters = op.const(1)
    v = op.const(1)

    result = op.loop(num_iters, v_initial=[v], body=lambda i, cond, x: (cond, x))[0]
    assert result.type.shape == ()

Result:

FAILED bug.py::test_loop2 - assert None == ()

Note: this test succeeds with v17

@cbourjau
Collaborator

Thanks for reporting the issue! The point around version 19 is certainly a bug. We have custom shape and type inference logic for a few operators, but try to upstream them into onnx. It seems that Loop was not followed up on properly. We have custom logic for Loop@v16, but not v19, nor v21 (btw, you can see the versions of each operator here). Looking at the onnx source code, it appears that the semantics didn't change, except that Loop gained support for more types. We will duplicate the inference logic from the earlier version to make this work straight away. This time around, we should push this upstream, though.

I'm afraid I don't see the issue with the first example, though. Where would that 5 be coming from?

@MatejUrbanQC
Author

> I'm afraid I don't see the issue with the first example, though. Where would that 5 be coming from?

Sorry, made a typo in the test. It should be

num_iters = op.const(5)

@jbachurski
Collaborator

I would avoid trying to resolve overly complex type inference problems by assuming that the result should be (5,). We can't reason about these things symbolically anyway - what Spox sees is just that Loop started with an accumulator with shape (0,) and after the iteration got (1,). The very special case in which we know the number of iterations would require actually running the loop that many times, which is a Bad Idea - we prefer to keep complexity amortized to the number of nodes built in the entire graph [this is not quite true I recall, but we do our best].

The bug here (that it actually says 1 and not something more sensible; I seem to have introduced it back in #64) is that it always takes the type post-iteration, but if that mismatches the one pre-iteration, then we should make the dimension dynamic instead.

I think the 'more correct' behaviour could be a more robust unification between the known input and output types:

  • (0,) and (1,) unify into (None,)
  • (3, 5) and (3, 7) unify into (3, None)
  • ('N', 'M') and ('N',) unify into None

I think there may have been some utilities for this at some point in the codebase, but I don't remember anymore - they were probably removed. Please note there are also sequence, map and optional types.
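
The unification rules listed above can be sketched in plain Python. This is only an illustration and not Spox's implementation: the names `Dim`, `Shape` and `unify_shapes` are hypothetical, a shape is modelled as a tuple whose entries are an `int`, a symbolic `str`, or `None` for a dynamic dimension, and a shape of `None` stands for an unknown rank. Sequence, map and optional types are not covered.

```python
from typing import Optional, Tuple, Union

# Hypothetical aliases for this sketch only.
Dim = Union[int, str, None]        # None marks a dynamic dimension
Shape = Optional[Tuple[Dim, ...]]  # None marks an unknown rank


def unify_shapes(first: Shape, second: Shape) -> Shape:
    """Return the most specific shape that is compatible with both inputs."""
    # Unknown or mismatched ranks cannot be described by a single tuple.
    if first is None or second is None or len(first) != len(second):
        return None
    # Keep a dimension only where both shapes agree on it.
    return tuple(a if a == b else None for a, b in zip(first, second))
```

Under these assumptions, `unify_shapes((3, 5), (3, 7))` keeps the leading `3` and makes the second dimension dynamic, while a rank mismatch discards the shape entirely.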

Preferably this approach should be upstreamed, but ONNX's control flow has always been lacking in such things. I think it does have something for unification implemented, though.

Workaround: where a loop accumulator is known to have dynamic shape across iterations, use _future.unsafe_cast to amend it to be dynamic (None). Then the type inference will carry that information through.
Note there are three different pieces of type information each in a different set of Vars, and without a complex type inference Spox is conservative about using and populating them: 1) input to the loop operator, which populates 2) the subgraph inputs for accumulators, which are used to infer 3) the subgraph outputs. What you want to do here is to give (1) a dynamic shape that will get carried through to (2) and (3).

Hope this helps!

@cbourjau
Collaborator

I think there may also be an implicit assumption on the semantics of Loop. Namely, that the outputs of the subgraph have to have the same type and shape in each iteration. The documentation for Scan spells this requirement out explicitly:

> All the output tensors (state_variables as well as scan_output_element tensors) are required to have the same shape in each iteration of the loop [...].

While Loop does not state so, we may have assumed so in the past.

@MatejUrbanQC
Author

I found a comment in the inference code in onnx that says it can change https://github.com/onnx/onnx/blob/ea4afc74f8f84646973e132ed471a82586a8aee1/onnx/defs/controlflow/utils.cc#L96

@MatejUrbanQC
Author

Another thing I realised. Since the loop constructor takes a function, you could pass a function, which makes a different graph depending on the static shape of the inputs. This would not have the effect you would expect intuitively (to be like repeatedly calling the function in a loop), because only the first iteration determines the graph.

Is this not a problem? Should the user just be careful about the function they pass in?

I think ideally the user should specify what shape info the function should be called with (or just provide the actual subgraph).
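
To make the concern concrete, here is a toy model in plain Python. It does not use Spox: `ToyVar`, `body`, `python_semantics` and `traced_semantics` are hypothetical names, and the "graph" is reduced to the name of a single recorded op. It only illustrates that a body function which branches on the static shape is evaluated once, so the branch taken for the initial value is the one baked into the loop.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyVar:
    """Hypothetical stand-in for a graph variable with a static 1-D length."""

    length: int
    op: str = "input"  # name of the op that produced this value


def body(x: ToyVar) -> ToyVar:
    # This branch runs in Python, at graph construction time.
    if x.length == 0:
        return ToyVar(x.length + 1, op="append_one")
    return ToyVar(x.length, op="identity")


def python_semantics(n: int) -> int:
    """Call `body` afresh on every iteration, as plain Python would."""
    v = ToyVar(0)
    for _ in range(n):
        v = body(v)
    return v.length


def traced_semantics(n: int) -> int:
    """Call `body` once with the initial value, then replay the recorded op."""
    recorded = body(ToyVar(0)).op  # only this single call decides the graph
    length = 0
    for _ in range(n):
        if recorded == "append_one":
            length += 1
    return length
```

In this model, three iterations give a length of 1 when the function is re-evaluated each time, but a length of 3 when the op recorded on the first call is replayed.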

@jbachurski
Collaborator

Yes. Spox exposes the type information, which effectively allows ad-hoc polymorphism in the DSL, driven by the metalanguage (here, Python), but the user has to be careful about what is done with that information.
Similarly, Spox makes use of no global or mutable state (which was something I strived for, hoping it would be as predictable as possible). The user using it could also break something.

> Is this not a problem? Should the user just be careful about the function they pass in?

It is a problem, but what can be done about it? We can't stop this at the language level in Python, it's always up to the user in the end.

> I think ideally the user should specify what shape info should the function be called with (or just provide the actual subgraph).

Specifying shape information is getting rid of type inference. If you like, you can implement a wrapper around unsafe_cast which would essentially do this. You only need to make sure that the types passed into Loop are as general as possible.
What do you mean by 'provide the actual subgraph'? As protobuf? What about Vars which are outside the scope of the subgraph? If you insisted on this sort of exactness, you could use inline for the loop iteration.

@MatejUrbanQC
Author

I don't see how you would get rid of type inference. The difference is only in how you construct the subgraph. Instead of taking the input shapes from the initial inputs (which is like assuming they will be unchanged), you take them from the user. You would then unify the initial inputs, body inputs and body outputs to infer the loop output shapes.

By the actual subgraph I mean any representation of a fixed subgraph: something that has a list of arguments and a list of outputs.

@cbourjau
Collaborator

We could also be stricter than the standard and immediately error if the type/shape information of the initial values differs from those we got after running the body Python function. This wouldn't be ideal, but in my humble opinion, the standard is broken in this aspect. If you want to concat outputs as in your example, you may simply put them into the scan_outputs (and get the concatenation for free) rather than the loop-carried values.

@MatejUrbanQC
Author

@cbourjau you can't use scan_outputs if you are trying to concatenate arrays of different lengths.

@neNasko1
Contributor

neNasko1 commented Jan 2, 2025

I have spent some time trying to implement different ideas for propagating the shape info; however, most of them are either inconsistent or too slow.

Currently we have the problem that we cannot reason about the shapes of the state variables. ONNX explicitly gives up when propagating the state variable types. Spox has a similar limitation; however, up to now we were inferring the wrong types when a shape-changing operation took place.

@jbachurski suggested we unite the initial types with the output types. This seems correct on a first pass; however, the following example shows that we need to repeat this inference procedure until we hit a fixed point:

initial_a: Int[1]
initial_b: Int[1]
initial_c: Int[1]

for ...:
    a, b, c = concat(a, a, axis=0), a, b

# propagation iter=0
# a: Int[1]
# b: Int[1]
# c: Int[1]

# propagation iter=1
# a: Int[1] U Int[2] = Int[None]
# b: Int[1] U Int[1] = Int[1]
# c: Int[1] U Int[1] = Int[1]

# propagation iter=2
# a: Int[None] U Int[None] = Int[None]
# b: Int[1] U Int[None] = Int[None]
# c: Int[1] U Int[1] = Int[1]

# propagation iter=3
# a: Int[None] U Int[None] = Int[None]
# b: Int[None] U Int[None] = Int[None]
# c: Int[1] U Int[None] = Int[None]

# propagation iter=4
# a: Int[None] U Int[None] = Int[None]
# b: Int[None] U Int[None] = Int[None]
# c: Int[None] U Int[None] = Int[None]
# Fixed point -> we can stop
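
The propagation trace above can be reproduced with a small plain-Python sketch. It is only a model of the idea, not code from Spox or ONNX: each state variable is reduced to its 1-D length (`None` meaning dynamic), and `concat_len`, `example_body`, `unify_dim` and `propagate_to_fixed_point` are hypothetical helper names.

```python
from typing import Callable, Optional, Tuple

Length = Optional[int]  # length of a 1-D tensor; None marks a dynamic length


def concat_len(a: Length, b: Length) -> Length:
    """Length of concat(a, b, axis=0) for 1-D inputs."""
    return a + b if a is not None and b is not None else None


def example_body(a: Length, b: Length, c: Length) -> Tuple[Length, ...]:
    # Models: a, b, c = concat(a, a, axis=0), a, b
    return (concat_len(a, a), a, b)


def unify_dim(x: Length, y: Length) -> Length:
    """Keep a length only if both sides agree on it."""
    return x if x == y else None


def propagate_to_fixed_point(
    initial: Tuple[Length, ...],
    body: Callable[..., Tuple[Length, ...]],
) -> Tuple[Tuple[Length, ...], int]:
    """Re-run body inference until unifying inputs and outputs changes nothing."""
    state = tuple(initial)
    passes = 0
    while True:
        outputs = body(*state)
        unified = tuple(unify_dim(s, o) for s, o in zip(state, outputs))
        passes += 1
        if unified == state:
            return state, passes
        state = unified
```

Calling `propagate_to_fixed_point((1, 1, 1), example_body)` needs four passes over the body before the state stops changing, which matches the trace; chaining more variables in the same way would lengthen it further.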

One can see that this procedure can be extended to be arbitrarily long, and when nested, the loops may produce arbitrarily slow inference. This directly contradicts the earlier remark:

> running the loop that many times, which is a Bad Idea - we prefer to keep complexity amortized to the number of nodes built in the entire graph [this is not quite true I recall, but we do our best].

In my opinion we should strive to hit the middle ground as mentioned in #198, where we try to do the inference once. If we are already at the fixed point (i.e. input and output types of the subgraph are the same), we use those; if, however, some type changes, we give up and do not annotate any of the state variables.
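
A minimal sketch of this single-pass policy, under the same simplifying assumptions as before: shapes are plain tuples, `None` stands for "no shape annotation", and `infer_state_shapes_once` is a hypothetical name rather than anything in Spox.

```python
from typing import Callable, Optional, Sequence, Tuple

Shape = Optional[Tuple[Optional[int], ...]]  # None means "not annotated"


def infer_state_shapes_once(
    initial: Sequence[Shape],
    body: Callable[..., Sequence[Shape]],
) -> Tuple[Shape, ...]:
    """Run body inference exactly once and keep the result only if it is stable."""
    inputs = tuple(initial)
    outputs = tuple(body(*inputs))
    if outputs == inputs:
        # Already at the fixed point: the shapes hold for every iteration.
        return outputs
    # Some state variable changed shape: give up on all of them.
    return tuple(None for _ in inputs)
```

The cost is one evaluation of the body regardless of how the loop is written, at the price of losing all state-variable shapes whenever any of them changes.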

@adityagoel4512
Member

adityagoel4512 commented Jan 2, 2025

Thanks for writing this out. I definitely agree that predictable shape inference performance is too important to give up without a really compelling use case. Did you already do the work to construct a bound on the number of iterations with respect to the loop properties (I assume unite always yields the same shape or strictly "smaller")?

@neNasko1
Contributor

neNasko1 commented Jan 2, 2025

> Thanks for writing this out. I definitely agree that predictable shape inference performance is too important to give up without a really compelling use case. Did you already do the work to construct a bound on the number of iterations with respect to the loop properties (I assume unite always yields the same shape or strictly "smaller")?

The process strictly loses information, so it must terminate. I suspect an argument could be made for the number of iterations to be O(S+n), where n is the number of input variables and S is the sum of the ranks of the inputs. Another useful example to consider may be the loop:

a: Int[][][][]

for ...:
    a = sum(a, axis=0)
