Skip to content

Commit

Permalink
More robust check for multiple integer indices in numba ravel_multidi…
Browse files Browse the repository at this point in the history
…mensional_idx rewrites
  • Loading branch information
ricardoV94 committed Dec 31, 2024
1 parent 4e85676 commit 8267d0e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
inc_subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable

Expand Down Expand Up @@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node):

if any(
(
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int"))
(isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
Expand Down Expand Up @@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node):
int_idxs = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int"))
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]

if len(int_idxs) != 1:
Expand Down

0 comments on commit 8267d0e

Please sign in to comment.