Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/cuda/tile/_ir/ops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def check_rd_and_ftz(fn: str, rounding_mode: Optional[RoundingMode], flush_to_ze
f'Rounding mode {rounding_mode.value} can only be used for float32 type, '
f'but got {dtype}')
if flush_to_zero:
if flush_to_zero and not math_op_def.support_flush_to_zero:
if not math_op_def.support_flush_to_zero:
raise TileTypeError(f'Flush to zero is not supported for {fn}')
if dtype != datatype.float32:
raise TileTypeError(
Expand Down Expand Up @@ -188,14 +188,13 @@ def memory_order_has_release(memory_order: MemoryOrder):
def get_dtype(ty: TileTy | datatype.DType | LooselyTypedScalar) -> datatype.DType | PointerTy:
if isinstance(ty, TileTy):
return ty.dtype
elif isinstance(ty, datatype.DType):
if isinstance(ty, datatype.DType):
return ty
elif isinstance(ty, PointerTy):
if isinstance(ty, PointerTy):
return ty
elif isinstance(ty, LooselyTypedScalar):
if isinstance(ty, LooselyTypedScalar):
return typeof_pyval(ty.value)
else:
raise TypeError(f"Cannot get dtype from {ty}")
raise TypeError(f"Cannot get dtype from {ty}")


def change_dtype(ty: TileTy | datatype.DType | PointerTy,
Expand Down