diff --git a/src/flag_gems/runtime/backend/_iluvatar/__init__.py b/src/flag_gems/runtime/backend/_iluvatar/__init__.py
index 9b33bc874..51a862810 100644
--- a/src/flag_gems/runtime/backend/_iluvatar/__init__.py
+++ b/src/flag_gems/runtime/backend/_iluvatar/__init__.py
@@ -4,6 +4,6 @@
     vendor_name="iluvatar", device_name="cuda", device_query_cmd="ixsmi"
 )
 
-CUSTOMIZED_UNUSED_OPS = ("randperm", "topk", "sort", "multinomial")
+CUSTOMIZED_UNUSED_OPS = ("scatter", "quantile", "randperm", "mv")
 
 __all__ = ["*"]
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py b/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
index 0af3df6ec..943990908 100644
--- a/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
@@ -1,4 +1,5 @@
 from .bmm import bmm
+from .div import div_mode, floor_divide, remainder, true_divide
 from .mm import mm
 
-__all__ = ["bmm", "mm"]
+__all__ = ["bmm", "mm", "div_mode", "floor_divide", "remainder", "true_divide"]
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/div.py b/src/flag_gems/runtime/backend/_iluvatar/ops/div.py
new file mode 100644
index 000000000..ad9c048cb
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/div.py
@@ -0,0 +1,218 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import pointwise_dynamic, tl_extra_shim
+
+div_rn = tl_extra_shim.div_rn
+div_rz = tl_extra_shim.div_rz
+fmod = tl_extra_shim.fmod
+trunc = tl_extra_shim.trunc
+
+
+@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
+@triton.jit
+def true_div_func(x, y):
+    return x / y
+
+
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
+@triton.jit
+def true_div_func_tensor_scalar(x, y):
+    return x / y
+
+
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
+@triton.jit
+def true_div_func_scalar_tensor(x, y):
+    return x / y
+
+
+def true_divide(A, B):
+    logging.debug("GEMS TRUE_DIVIDE")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return true_div_func(A, B)
+    elif isinstance(A, torch.Tensor):
+        return true_div_func_tensor_scalar(A, B)
+    elif isinstance(B, torch.Tensor):
+        return true_div_func_scalar_tensor(A, B)
+    else:
+        # Both scalar
+        return torch.tensor(A / B)
+
+
+@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def trunc_div_func(x, y):
+    return trunc(div_rz(x, y))
+
+
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def trunc_div_func_tensor_scalar(x, y):
+    return trunc(div_rz(x, y))
+
+
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def trunc_div_func_scalar_tensor(x, y):
+    return trunc(div_rz(x, y))
+
+
+def trunc_divide(A, B):
+    logging.debug("GEMS TRUNC_DIVIDE iluvatar")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return trunc_div_func(A, B)
+    elif isinstance(A, torch.Tensor):
+        return trunc_div_func_tensor_scalar(A, B)
+    elif isinstance(B, torch.Tensor):
+        return trunc_div_func_scalar_tensor(A, B)
+    else:
+        # Both scalar
+        return torch.tensor(A / B)
+
+
+@triton.jit
+def _int_floordiv(x, y):
+    # TODO: request Triton to add an integer remainder builtin
+    # The semantics of Triton floordiv differ from PyTorch/NumPy:
+    # Triton floordiv equates to
+    #     (x - np.fmod(x, y)) / y
+    # whereas PyTorch floordiv is
+    #     (x - np.remainder(x, y)) / y
+    # The results show an off-by-one difference when
+    #     C1) x and y have opposite signs
+    # and C2) x is not a multiple of y.
+    # Apart from the above, there is an erroneous case: Triton's x // 0
+    # returns -1, whereas in PyTorch x // 0 returns -1 if x >= 0 and -2
+    # if x < 0; this special case is covered by the C1 and C2 checks,
+    # so no extra handling is needed.
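+    # Worked example (illustrative): for x = -7, y = 2 the trunc-style result
+    # is (-7 - fmod(-7, 2)) / 2 = (-7 + 1) / 2 = -3, while PyTorch's floor
+    # division is (-7 - (-7 % 2)) / 2 = (-7 - 1) / 2 = -4; C1 and C2 both
+    # hold here, so the correction below subtracts 1 to reproduce -4.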
+    r = x % y
+    c1 = r != 0
+    c2 = (x < 0) ^ (y < 0)
+    return tl.where(c1 & c2, x // y - 1, x // y)
+
+
+# To be consistent with Python, NumPy, and Torch, we have to implement it in
+# the following way.
+# CPython
+#   https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
+# numpy
+#   https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
+# torch
+#   https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
+@triton.jit
+def _float_floordiv(x, y):
+    # NOTE: fmod's sign is the same as the dividend's
+    remainder = fmod(x, y)
+    imperfect = remainder != 0.0
+    different_sign = (x < 0) ^ (y < 0)
+
+    # NOTE: we have to use div_rn explicitly here
+    q = div_rn(x - remainder, y)
+    q = tl.where(imperfect & different_sign, q - 1, q)
+
+    floor_q = tl.math.floor(q)
+    c = q - floor_q > 0.5
+    floor_q = tl.where(c, floor_q + 1.0, floor_q)
+
+    q_is_zeros = q == 0.0
+    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)
+
+    is_div_by_zero = y == 0.0
+    float_division = x / y
+    out = tl.where(is_div_by_zero, float_division, floor_q)
+    return out
+
+
+@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def floor_div_func(x, y):
+    if x.type.scalar.is_int() & y.type.scalar.is_int():
+        return _int_floordiv(x, y)
+    else:
+        return _float_floordiv(x, y)
+
+
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def floor_div_func_tensor_scalar(x, y):
+    if x.type.scalar.is_int() & y.type.scalar.is_int():
+        return _int_floordiv(x, y)
+    else:
+        return _float_floordiv(x, y)
+
+
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def floor_div_func_scalar_tensor(x, y):
+    if x.type.scalar.is_int() & y.type.scalar.is_int():
+        return _int_floordiv(x, y)
+    else:
+        return _float_floordiv(x, y)
+
+
+def floor_divide(A, B):
+    logging.debug("GEMS FLOOR_DIVIDE")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return floor_div_func(A, B)
+    elif isinstance(A, torch.Tensor):
+        return floor_div_func_tensor_scalar(A, B)
+    elif isinstance(B, torch.Tensor):
+        return floor_div_func_scalar_tensor(A, B)
+    else:
+        # Both scalar
+        return torch.tensor(A // B)
+
+
+def div_mode(A, B, rounding_mode=None):
+    if rounding_mode is None:
+        return true_divide(A, B)
+    elif rounding_mode == "trunc":
+        return trunc_divide(A, B)
+    elif rounding_mode == "floor":
+        return floor_divide(A, B)
+    else:
+        msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
+        raise ValueError(msg)
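+
+
+# Illustrative dispatch (assumed to mirror torch.div's rounding_mode
+# semantics), e.g. with x = torch.tensor([7.0]) and y = torch.tensor([-2.0]):
+#     div_mode(x, y)                         -> [-3.5]  (true division)
+#     div_mode(x, y, rounding_mode="trunc")  -> [-3.0]  (round toward zero)
+#     div_mode(x, y, rounding_mode="floor")  -> [-4.0]  (round toward -inf)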
+
+
+@triton.jit
+def _remainder(x, y):
+    r = x % y
+    c1 = r != 0
+    c2 = (x < 0) ^ (y < 0)
+    return tl.where(c1 & c2, r + y, r)
+
+
+@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def rem_tt(x, y):
+    return _remainder(x, y)
+
+
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def rem_ts(x, y):
+    return _remainder(x, y)
+
+
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def rem_st(x, y):
+    return _remainder(x, y)
+
+
+def remainder(A, B):
+    logging.debug("GEMS REMAINDER")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return rem_tt(A, B)
+    elif isinstance(A, torch.Tensor):
+        return rem_ts(A, B)
+    elif isinstance(B, torch.Tensor):
+        return rem_st(A, B)
+    else:
+        # Both scalar
+        return torch.tensor(A % B)
diff --git a/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml b/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
index 09cdc1b53..1e47896ab 100644
--- a/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
+++ b/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
@@ -3235,7 +3235,6 @@ var_mean:
       - 8
       block_n:
       - 1024
-      - 2048
       warps:
       - 4
       - 8
diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py
index 52a93d56a..1f84243e2 100644
--- a/tests/test_binary_pointwise_ops.py
+++ b/tests/test_binary_pointwise_ops.py
@@ -370,7 +370,7 @@ def test_accuracy_trunc_div(shape, dtype):
     inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
     inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
-    upcast = True if flag_gems.vendor_name != "cambricon" else False
+    upcast = True if flag_gems.vendor_name not in ["cambricon", "iluvatar"] else False
 
     ref_inp1 = to_reference(inp1, upcast)
     ref_inp2 = to_reference(inp2, upcast)
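+    # NOTE (illustrative; assumes to_reference's usual behavior of promoting
+    # inputs to float64 when upcast=True): iluvatar is excluded above, so its
+    # reference result is computed in the original dtype and compared at
+    # matching precision rather than against an upcast reference.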