Numba dispatch of Elemwise of ScalarLoop is broken #1130

Closed
aseyboldt opened this issue Dec 17, 2024 · 2 comments · Fixed by #1137

Comments

@aseyboldt

Description

Originally reported here: pymc-devs/nutpie#163

Here is a smaller reproducer:

import pytensor
import pytensor.tensor as pt
import numpy as np

a = pytensor.scalar.get_scalar_type("float64")()
loop = pytensor.scalar.ScalarLoop([a], [pytensor.scalar.add(a, a)])
x = pt.tensor("x", shape=(3,))
elem = pt.elemwise.Elemwise(loop)(3, x)
elem.eval({x: np.ones(3)})
# Returns array([8., 8., 8.])

# But compiling fails:
func = pytensor.function([x], elem, mode="NUMBA")
File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/numba/dispatch/elemwise.py:357, in numba_funcify_Elemwise(op, node, **kwargs)
    355 if not isinstance(op.scalar_op, Composite):
    356     scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
--> 357     scalar_node = op.scalar_op.make_node(*scalar_inputs)
    359 scalar_op_fn = numba_funcify(
    360     op.scalar_op,
    361     node=scalar_node,
   (...)
    364     **kwargs,
    365 )
    367 nin = len(node.inputs)

File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/loop.py:180, in ScalarLoop.make_node(self, n_steps, *inputs)
    178 cloned_constant = cloned_inputs[len(cloned_update) :]
    179 # This will fail if the cloned init have a different dtype than the cloned_update
--> 180 op = ScalarLoop(
    181     init=cloned_init,
    182     update=cloned_update,
    183     constant=cloned_constant,
    184     until=cloned_until,
    185     name=self.name,
    186 )
    187 node = op.make_node(n_steps, *inputs)
    188 return node

File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/loop.py:69, in ScalarLoop.__init__(self, init, update, constant, until, name)
     66     inputs, outputs = clone([*init, *constant], update)
     68 self.is_while = bool(until)
---> 69 self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
     70 self._validate_updates(self.inputs, self.outputs)
     72 self.inputs_type = tuple(input.type for input in self.inputs)

File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/basic.py:3992, in ScalarInnerGraphOp._cleanup_graph(self, inputs, outputs)
   3990 for node in fgraph.apply_nodes:
   3991     if not isinstance(node.op, ScalarOp):
-> 3992         raise TypeError(
   3993             f"The fgraph of {self.__class__.__name__} must be exclusively "
   3994             "composed of scalar operations."
   3995         )
   3997 # Run MergeOptimization to avoid duplicated nodes
   3998 MergeOptimizer().rewrite(fgraph)

TypeError: The fgraph of ScalarLoop must be exclusively composed of scalar operations.
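
For reference, the value returned by `eval` in the reproducer can be checked without PyTensor. The sketch below is plain Python and only mirrors the reproducer's semantics as I read them (apply the update `a + a` to each element `n_steps` times); `scalar_loop_double` is a hypothetical helper name, not a PyTensor function.

```python
def scalar_loop_double(n_steps, values):
    """Apply the update a -> a + a to each element, n_steps times."""
    out = []
    for a in values:
        for _ in range(n_steps):
            a = a + a
        out.append(a)
    return out


# Same inputs as the reproducer: 3 steps over a vector of ones.
result = scalar_loop_double(3, [1.0, 1.0, 1.0])
print(result)  # [8.0, 8.0, 8.0]
```

This matches the `array([8., 8., 8.])` from the default backend, so the graph itself is fine; only the Numba compilation path fails.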

The debugger shows that it tries to create a ScalarLoop node whose inputs are rank-0 tensors (TensorType with shape=()) instead of ScalarType variables:

> /home/adr/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/numba/dispatch/elemwise.py(357)numba_funcify_Elemwise()
    355     if not isinstance(op.scalar_op, Composite):
    356         scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
--> 357         scalar_node = op.scalar_op.make_node(*scalar_inputs)
    358 
    359     scalar_op_fn = numba_funcify(

ipdb>  p scalar_inputs
[<Scalar(int8, shape=())>, <Scalar(float64, shape=())>]
ipdb>  p scalar_inputs[0].type
TensorType(int8, shape=())
ipdb>  exit
@ricardoV94

That `if` branch is strange. What happens if you add ScalarLoop to the isinstance check?

Also, I'm not sure we have a Numba dispatch for ScalarLoop (it should be simple to add); I think it will fall back to Python mode?

@ricardoV94

#299
