diff --git a/src/cuda/tile/_ir/ops_utils.py b/src/cuda/tile/_ir/ops_utils.py index b25716c..c96cedd 100644 --- a/src/cuda/tile/_ir/ops_utils.py +++ b/src/cuda/tile/_ir/ops_utils.py @@ -154,7 +154,7 @@ def check_rd_and_ftz(fn: str, rounding_mode: Optional[RoundingMode], flush_to_ze f'Rounding mode {rounding_mode.value} can only be used for float32 type, ' f'but got {dtype}') if flush_to_zero: - if flush_to_zero and not math_op_def.support_flush_to_zero: + if not math_op_def.support_flush_to_zero: raise TileTypeError(f'Flush to zero is not supported for {fn}') if dtype != datatype.float32: raise TileTypeError( @@ -188,14 +188,13 @@ def memory_order_has_release(memory_order: MemoryOrder): def get_dtype(ty: TileTy | datatype.DType | LooselyTypedScalar) -> datatype.DType | PointerTy: if isinstance(ty, TileTy): return ty.dtype - elif isinstance(ty, datatype.DType): + if isinstance(ty, datatype.DType): return ty - elif isinstance(ty, PointerTy): + if isinstance(ty, PointerTy): return ty - elif isinstance(ty, LooselyTypedScalar): + if isinstance(ty, LooselyTypedScalar): return typeof_pyval(ty.value) - else: - raise TypeError(f"Cannot get dtype from {ty}") + raise TypeError(f"Cannot get dtype from {ty}") def change_dtype(ty: TileTy | datatype.DType | PointerTy,