diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 7ba1908e60..572d2bcab6 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -85,7 +85,7 @@ inc_subtensor, indices_from_subtensor, ) -from pytensor.tensor.type import TensorType +from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node): if any( ( - (isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")) + (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes) or isinstance(idx.type, NoneTypeT) ) for idx in idxs @@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node): int_idxs = [ (i, idx) for i, idx in enumerate(idxs) - if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int")) + if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes) ] if len(int_idxs) != 1: