diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index 29d993893..2676e5bb2 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -10,6 +10,7 @@
 from .performance_utils import (
     Benchmark,
     Config,
+    GenericBenchmark,
     GenericBenchmark2DOnly,
     SkipVersion,
     generate_tensor_input,
@@ -202,3 +203,19 @@ def count_nonzero_input_fn(shape, dtype, device):
         dtypes=FLOAT_DTYPES,
     )
     bench.run()
+
+
+@pytest.mark.diff
+def test_perf_diff():
+    def diff_input_fn(shape, cur_dtype, device):
+        inp = generate_tensor_input(shape, cur_dtype, device)
+        n = 1
+        yield inp, n, 0
+
+    bench = GenericBenchmark(
+        input_fn=diff_input_fn,
+        op_name="diff",
+        torch_op=torch.diff,
+        dtypes=FLOAT_DTYPES + INT_DTYPES,
+    )
+    bench.run()
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 35fa30634..878c6d610 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -41,6 +41,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
             ("constant_pad_nd", constant_pad_nd, Autograd.disable),
             ("cumsum", cumsum, Autograd.disable),
             ("cummin", cummin, Autograd.disable),
+            ("diff", diff, Autograd.disable),
             ("div.Tensor", true_divide, Autograd.disable),
             ("div.Scalar", true_divide, Autograd.disable),
             ("div.Tensor_mode", div_mode, Autograd.disable),
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 55db76e6a..a2c5d2624 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -30,6 +30,7 @@
 from .diag import diag
 from .diag_embed import diag_embed
 from .diagonal import diagonal_backward
+from .diff import diff
 from .div import div_mode, floor_divide, remainder, true_divide
 from .dropout import native_dropout
 from .embedding import embedding
@@ -296,6 +297,7 @@
     "logical_xor",
     "logical_not",
     "sort",
+    "diff",
     "nll_loss_forward",
     "nll_loss_backward",
     "nll_loss2d_forward",
diff --git a/src/flag_gems/ops/diff.py b/src/flag_gems/ops/diff.py
new file mode 100644
index 000000000..5752245f5
--- /dev/null
+++ b/src/flag_gems/ops/diff.py
@@ -0,0 +1,96 @@
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor, tensor
+
+from .. import runtime
+from ..runtime import torch_device_fn
+from ..utils import dim_compress, libentry, libtuner
+from ..utils import triton_lang_extension as tle
+
+
+@libentry()
+@libtuner(
+    configs=runtime.get_tuned_config("diff_1d"),
+    key=["N"],
+)
+@triton.jit
+def diff_kernel_1d(in_ptr, out_ptr, N, N_bound, BLOCK_DIFF: tl.constexpr):
+    pid = tle.program_id(0)
+
+    in_offsets = pid * BLOCK_DIFF + tl.arange(0, BLOCK_DIFF)
+    mask_in = in_offsets < N_bound - 1
+    in_block = tl.load(in_ptr + in_offsets, mask_in)
+    next_block = tl.load(in_ptr + in_offsets + 1, mask_in)
+    tl.store(out_ptr + in_offsets, next_block - in_block, mask_in)
+
+
+@libentry()
+@libtuner(
+    configs=runtime.get_tuned_config("diff"),
+    key=["M", "N"],
+)
+@triton.jit
+def diff_kernel_2d(
+    in_ptr, out_ptr, M, N, N_bound, BLOCK_M: tl.constexpr, BLOCK_DIFF: tl.constexpr
+):
+    pid_M = tle.program_id(0)
+    pid_diff = tle.program_id(1)
+
+    M_offsets = pid_M * BLOCK_M + tl.arange(0, BLOCK_M)
+    mask_M = M_offsets < M
+
+    in_offsets_diff = pid_diff * BLOCK_DIFF + tl.arange(0, BLOCK_DIFF)
+    mask_in_diff = in_offsets_diff < N_bound - 1
+    in_offsets = M_offsets[:, None] * N + in_offsets_diff
+    mask_in = mask_M[:, None] & mask_in_diff
+    out_offsets = M_offsets[:, None] * N + in_offsets_diff
+
+    in_block = tl.load(in_ptr + in_offsets, mask_in)
+    next_block = tl.load(in_ptr + in_offsets + 1, mask_in)
+    tl.store(out_ptr + out_offsets, next_block - in_block, mask_in)
+
+
+def diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor:
+    if prepend is not None:
+        input = torch.cat([prepend, input], dim=dim)
+    if append is not None:
+        input = torch.cat([input, append], dim=dim)
+
+    if n <= 0:
+        return input
+
+    shape = list(input.shape)
+    dim = dim % input.ndim
+    reduce_len = shape[dim]
+
+    if n >= reduce_len:
+        empty_tensor = tensor([], dtype=input.dtype, device=input.device)
+        return torch.reshape(empty_tensor, shape[:dim] + [0] + shape[(dim + 1) :])
+
+    input = dim_compress(input, dim)
+    N = reduce_len
+    M = input.numel() // N
+
+    output = torch.zeros(input.shape, device=input.device, dtype=input.dtype)
+
+    n_steps = n
+    while n_steps > 0:
+        cur_in_diff_len = N - (n - n_steps)
+        if len(shape) == 1:
+            grid = lambda meta: (triton.cdiv(cur_in_diff_len, meta["BLOCK_DIFF"]),)
+            with torch_device_fn.device(input.device):
+                diff_kernel_1d[grid](input, output, N, cur_in_diff_len)
+        else:
+            grid = lambda meta: (
+                triton.cdiv(M, meta["BLOCK_M"]),
+                triton.cdiv(cur_in_diff_len, meta["BLOCK_DIFF"]),
+            )
+            with torch_device_fn.device(input.device):
+                diff_kernel_2d[grid](input, output, M, N, cur_in_diff_len)
+        n_steps -= 1
+        input.copy_(output)
+
+    output = output[..., : (N - n)].contiguous()
+    output = torch.moveaxis(output, -1, dim)
+    return output
diff --git a/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml b/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
index c6cdcf47b..ba4e42632 100644
--- a/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
+++ b/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
@@ -2762,6 +2762,43 @@ index_select:
   - 1024
   - 2048
   - 4096
+diff_1d:
+- gen: true
+  param_map:
+    META:
+      BLOCK_DIFF: block_diff
+    num_warps: warps
+  warps:
+  - 4
+  - 8
+  - 16
+  block_diff:
+  - 1
+  - 16
+  - 256
+  - 1024
+diff:
+- gen: true
+  param_map:
+    META:
+      BLOCK_M: block_m
+      BLOCK_DIFF: block_diff
+    num_warps: warps
+  warps:
+  - 4
+  - 8
+  - 16
+  block_m:
+  - 1
+  - 2
+  - 4
+  - 8
+  - 32
+  block_diff:
+  - 1
+  - 16
+  - 256
+  - 1024
 layer_norm_persistent:
 - gen: true
   param_map:
diff --git a/src/flag_gems/runtime/backend/_metax/tune_configs.yaml b/src/flag_gems/runtime/backend/_metax/tune_configs.yaml
index c3d5b3951..37c8a22c3 100644
--- a/src/flag_gems/runtime/backend/_metax/tune_configs.yaml
+++ b/src/flag_gems/runtime/backend/_metax/tune_configs.yaml
@@ -510,6 +510,43 @@ index_select:
   - 1024
   - 2048
   - 4096
+diff_1d:
+- gen: true
+  param_map:
+    META:
+      BLOCK_DIFF: block_diff
+    num_warps: warps
+  warps:
+  - 4
+  - 8
+  - 16
+  block_diff:
+  - 1
+  - 16
+  - 256
+  - 1024
+diff:
+- gen: true
+  param_map:
+    META:
+      BLOCK_M: block_m
+      BLOCK_DIFF: block_diff
+    num_warps: warps
+  warps:
+  - 4
+  - 8
+  - 16
+  block_m:
+  - 1
+  - 2
+  - 4
+  - 8
+  - 32
+  block_diff:
+  - 1
+  - 16
+  - 256
+  - 1024
 layer_norm_persistent:
 - gen: true
   param_map:
diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
index 879d9474a..77ad685f4 100644
--- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
+++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -493,6 +493,43 @@ index_select:
   - 1024
   - 2048
   - 4096
+diff_1d:
+- gen: true
+  param_map:
+    META:
+      BLOCK_DIFF: block_diff
+    num_warps: warps
+  warps:
+  - 4
+  - 8
+  - 16
+  block_diff:
+  - 1
+  - 16
+  - 256
+  - 1024
+diff:
+- gen: true
+  param_map:
+    META:
+      BLOCK_M: block_m
+      BLOCK_DIFF: block_diff
+    num_warps: warps
+  warps:
+  - 4
+  - 8
+  - 16
+  block_m:
+  - 1
+  - 2
+  - 4
+  - 8
+  - 32
+  block_diff:
+  - 1
+  - 16
+  - 256
+  - 1024
 layer_norm_persistent:
 - gen: true
   param_map:
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 5847d1eb4..25f9ea037 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -927,3 +927,28 @@ def test_accuracy_depthwise2d(
             inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1
         )
     gems_assert_close(res_out, ref_out, dtype)
+
+
+DIFF_N_VALUES = list(range(0, 10))
+
+
+@pytest.mark.diff
+@pytest.mark.parametrize("shape", [(1024**3,)] + REDUCTION_SHAPES)
+@pytest.mark.parametrize("dim", DIM_LIST)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
+@pytest.mark.parametrize("n", DIFF_N_VALUES)
+def test_accuracy_diff(shape, dim, dtype, n):
+    if dtype in INT_DTYPES:
+        inp = torch.randint(
+            low=-10, high=11, size=shape, dtype=dtype, device=flag_gems.device
+        )
+    else:
+        inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    ref_inp = to_reference(inp)
+
+    ref_out = torch.diff(ref_inp, n, dim % inp.ndim)
+    with flag_gems.use_gems():
+        res_out = torch.diff(inp, n, dim)
+
+    reduce_dim = shape[dim % inp.ndim]
+    gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim, equal_nan=True)
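
Reviewer note: the op computes an n-th order difference as n iterated first-order passes over the compressed last dimension, with `cur_in_diff_len` shrinking by one each pass before the result is truncated to length `N - n` and moved back to `dim`. The sketch below mirrors what `test_accuracy_diff` does, but outside the test suite: run `torch.diff` once in eager mode as a reference and once under `flag_gems.use_gems()`, which routes the call to the new Triton kernels. It is not part of the patch; the helper name `check_diff` and the sampled shapes are illustrative, and it assumes this patch is installed and a Triton-capable device is visible as `flag_gems.device`.

```python
# Hypothetical standalone sanity check for the diff op added in this PR.
import torch

import flag_gems


def check_diff(shape, n, dim=-1, dtype=torch.float32):
    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    # Eager PyTorch on CPU as the reference.
    ref = torch.diff(inp.cpu(), n=n, dim=dim)
    # With use_gems() active, torch.diff dispatches to flag_gems.ops.diff,
    # which lowers to diff_kernel_1d / diff_kernel_2d.
    with flag_gems.use_gems():
        res = torch.diff(inp, n=n, dim=dim)
    torch.testing.assert_close(res.cpu(), ref, equal_nan=True)


for shape in [(1024,), (64, 1023)]:
    for n in (1, 2, 3):
        check_diff(shape, n)
print("flag_gems diff matches torch.diff on the sampled cases")
```

The (64, 1023) case is chosen to hit the 2D kernel with a non-power-of-two row length, so the `mask_in_diff` boundary handling is exercised as well as the 1D path.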