diff --git a/benchmark/core_shapes.yaml b/benchmark/core_shapes.yaml
index 616eed0d5..63e0cfd1c 100644
--- a/benchmark/core_shapes.yaml
+++ b/benchmark/core_shapes.yaml
@@ -178,3 +178,15 @@ AttentionBenchmark:
     - [4, 8, 2048, 128]
     - [4, 8, 3072, 128]
     - [4, 8, 4096, 128]
+
+Conv3dBenchmark:
+  shapes:
+    # [n, c, d, h, w, out_c, k_d, k_h, k_w, stride, padding, groups]
+    - [104, 16, 32, 32, 32, 32, 4, 4, 4, 1, 0, 1]
+    - [64, 32, 18, 180, 18, 32, 5, 5, 5, 2, 1, 1]
+    - [4, 32, 110, 110, 10, 64, 5, 5, 5, 2, 1, 1]
+    - [4, 64, 110, 110, 10, 16, 5, 5, 5, 2, 1, 1]
+    - [16, 32, 120, 12, 12, 24, 3, 3, 3, 2, 1, 1]
+    - [16, 32, 240, 24, 24, 24, 3, 3, 3, 1, 1, 2]
+    - [16, 32, 24, 24, 24, 24, 3, 3, 3, 2, 2, 2]
+    - [16, 32, 24, 24, 24, 24, 3, 3, 3, 1, 2, 2]
diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index 6cc50495b..c63513313 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -10,6 +10,7 @@
 from .performance_utils import (
     Benchmark,
     Config,
+    GenericBenchmark,
     GenericBenchmark2DOnly,
     SkipVersion,
     generate_tensor_input,
@@ -186,3 +187,67 @@ def count_nonzero_input_fn(shape, dtype, device):
         dtypes=FLOAT_DTYPES,
     )
     bench.run()
+
+
+class Conv3dBenchmark(GenericBenchmark):
+    def set_more_shapes(self):
+        # Sweep (n, c, d, h, w, o_c, k_d, k_h, k_w, stride, padding, groups).
+        self.shapes = []
+        shapes = [
+            (n, c, d, h, w, o_c, k_d, k_h, k_w, stride, padding, groups)
+            for n in [2 * i for i in range(1, 2)]
+            for c in [16 * i for i in range(8, 9)]
+            for d in [128 * i for i in range(5, 9)]
+            for h in [128 * i for i in range(5, 9)]
+            for w in [128 * i for i in range(5, 9)]
+            for o_c in [128 * i for i in range(8, 9)]
+            for k_d in [32 * i for i in range(1, 4)]
+            for k_h in [32 * i for i in range(1, 4)]
+            for k_w in [32 * i for i in range(1, 4)]
+            for stride in [1, (2, 2, 2), (3, 3, 3)]
+            for padding in [0, (1, 1, 1), (0, 1, 2)]
+            for groups in [1, 2, 4, 8]
+        ]
+        return shapes
+
+
+@pytest.mark.conv3d
+def test_perf_conv3d():
+    def conv3d_input_fn(shape, dtype, device):
+        (
+            batch,
+            input_c,
+            input_d,
+            input_h,
+            input_w,
+            out_c,
+            kernel_d,
+            kernel_h,
+            kernel_w,
+            stride,
+            padding,
+            groups,
+        ) = shape
+        input_shape = (batch, input_c, input_d, input_h, input_w)
+        weight_shape = (out_c, input_c // groups, kernel_d, kernel_h, kernel_w)
+        input = torch.randn(size=input_shape, device=device, dtype=dtype)
+        weight = torch.randn(size=weight_shape, device=device, dtype=dtype)
+
+        # Trailing comma: yield a 1-tuple so the dict is treated as kwargs.
+        yield {
+            "input": input,
+            "weight": weight,
+            "bias": None,
+            "groups": groups,
+            "stride": stride,
+            "padding": padding,
+        },
+
+    torch.backends.cudnn.allow_tf32 = False
+    bench = Conv3dBenchmark(
+        input_fn=conv3d_input_fn,
+        op_name="conv3d",
+        torch_op=torch.nn.functional.conv3d,
+        dtypes=FLOAT_DTYPES,
+    )
+    bench.run()
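For orientation, each 12-element row under `Conv3dBenchmark` follows the layout that `conv3d_input_fn` above unpacks. A minimal sketch of how one of the yaml rows expands into a `torch.nn.functional.conv3d` call (illustrative only; assumes a CUDA device):

```python
import torch
import torch.nn.functional as F

# (n, c, d, h, w, out_c, k_d, k_h, k_w, stride, padding, groups)
n, c, d, h, w, out_c, k_d, k_h, k_w, stride, padding, groups = (
    16, 32, 24, 24, 24, 24, 3, 3, 3, 2, 2, 2,
)

inp = torch.randn(n, c, d, h, w, device="cuda", dtype=torch.float16)
# Each filter only sees C // groups input channels.
weight = torch.randn(
    out_c, c // groups, k_d, k_h, k_w, device="cuda", dtype=torch.float16
)

out = F.conv3d(inp, weight, bias=None, stride=stride, padding=padding, groups=groups)
print(out.shape)  # torch.Size([16, 24, 13, 13, 13])
```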
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 354a14cca..01d47e732 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -22,6 +22,7 @@
 from .conv1d import conv1d
 from .conv2d import conv2d
+from .conv3d import conv3d
 from .conv_depthwise2d import _conv_depthwise2d
 from .cos import cos
 from .count_nonzero import count_nonzero
 from .cross_entropy_loss import cross_entropy_loss
@@ -154,6 +155,7 @@
     "clamp_tensor",
+    "conv3d",
     "cos",
     "count_nonzero",
     "diag",
     "diag_embed",
     "diagonal_backward",
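The kernel below assigns one program to each (block of flattened output positions, batch index, output channel) triple and loops naively over the receptive field. Its output extents come from the standard convolution size formula; a standalone sketch for checking a configuration by hand (hypothetical helper, not part of the patch):

```python
def conv_out_size(in_size, pad, dilation, kernel, stride):
    # Same expression as in conv3d_kernel and the conv3d wrapper below.
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

# D=16 with padding=1, dilation=2, kernel=5, stride=2: the dilated kernel
# spans 2 * (5 - 1) + 1 = 9 input positions, so
# out_d = (16 + 2 - 9) // 2 + 1 = 5.
assert conv_out_size(16, 1, 2, 5, 2) == 5
```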
diff --git a/src/flag_gems/ops/conv3d.py b/src/flag_gems/ops/conv3d.py
new file mode 100644
index 000000000..0e7eff411
--- /dev/null
+++ b/src/flag_gems/ops/conv3d.py
@@ -0,0 +1,206 @@
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=4),
+        triton.Config({"BLOCK_SIZE": 32}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=8),
+        triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=8),
+    ],
+    key=["N", "C", "D", "H", "W", "K"],
+)
+@triton.jit
+def conv3d_kernel(
+    input_ptr,
+    weight_ptr,
+    output_ptr,
+    N,
+    C,
+    D,
+    H,
+    W,
+    K,
+    T,
+    R,
+    S,
+    stride_d,
+    stride_h,
+    stride_w,
+    pad_d,
+    pad_h,
+    pad_w,
+    dilation_d,
+    dilation_h,
+    dilation_w,
+    groups,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    n_idx = tl.program_id(1)
+    k_idx = tl.program_id(2)
+
+    # Compute output spatial dimensions
+    out_d = (D + 2 * pad_d - dilation_d * (T - 1) - 1) // stride_d + 1
+    out_h = (H + 2 * pad_h - dilation_h * (R - 1) - 1) // stride_h + 1
+    out_w = (W + 2 * pad_w - dilation_w * (S - 1) - 1) // stride_w + 1
+
+    # Range of flattened output positions handled by this program
+    block_start = pid * BLOCK_SIZE
+    block_end = tl.minimum(block_start + BLOCK_SIZE, out_d * out_h * out_w)
+
+    # Number of channels per group
+    C_per_group = C // groups
+    K_per_group = K // groups
+
+    # Loop over the output positions in this block
+    for idx in range(block_start, block_end):
+        # Decompose the flattened index into output coordinates
+        od = idx // (out_h * out_w)
+        oh = (idx % (out_h * out_w)) // out_w
+        ow = idx % out_w
+
+        # Corner of the receptive field in the input
+        in_d = od * stride_d - pad_d
+        in_h = oh * stride_h - pad_h
+        in_w = ow * stride_w - pad_w
+
+        # Accumulate in float32
+        value = tl.zeros((), dtype=tl.float32)
+
+        # Perform the convolution for this output element
+        group_id = k_idx // K_per_group
+        for t in range(T):
+            for r in range(R):
+                for s in range(S):
+                    for c in range(C_per_group):
+                        input_d = in_d + t * dilation_d
+                        input_h = in_h + r * dilation_h
+                        input_w = in_w + s * dilation_w
+
+                        # Mask taps that fall outside the (unpadded) input
+                        in_bounds = (
+                            (0 <= input_d)
+                            & (input_d < D)
+                            & (0 <= input_h)
+                            & (input_h < H)
+                            & (0 <= input_w)
+                            & (input_w < W)
+                        )
+
+                        input_idx = (
+                            n_idx * C * D * H * W
+                            + (group_id * C_per_group + c) * D * H * W
+                            + input_d * H * W
+                            + input_h * W
+                            + input_w
+                        )
+                        weight_idx = (
+                            k_idx * C_per_group * T * R * S
+                            + c * T * R * S
+                            + t * R * S
+                            + r * S
+                            + s
+                        )
+                        input_val = tl.load(
+                            input_ptr + input_idx, mask=in_bounds, other=0.0
+                        ).to(tl.float32)
+                        weight_val = tl.load(weight_ptr + weight_idx).to(tl.float32)
+                        value += input_val * weight_val
+
+        # Store the result in the output tensor
+        output_idx = (
+            n_idx * K * out_d * out_h * out_w
+            + k_idx * out_d * out_h * out_w
+            + od * out_h * out_w
+            + oh * out_w
+            + ow
+        )
+        tl.store(output_ptr + output_idx, value)
+
+
+def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+    """
+    Implements a 3D convolution with groups support using Triton.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape (N, C, D, H, W).
+        weight (torch.Tensor): Weight tensor of shape (K, C // groups, T, R, S).
+        bias (torch.Tensor, optional): Bias tensor of shape (K,).
+        stride (int or tuple): Stride of the convolution. Default: 1.
+        padding (int or tuple): Padding added to all three spatial dimensions. Default: 0.
+        dilation (int or tuple): Spacing between kernel elements. Default: 1.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Default: 1.
+
+    Returns:
+        torch.Tensor: Output tensor of shape (N, K, out_D, out_H, out_W).
+    """
+    # Input shape: (N, C, D, H, W); weight shape: (K, C // groups, T, R, S)
+    N, C, D, H, W = input.shape
+    K, C_per_group, T, R, S = weight.shape
+
+    # Expand scalar parameters to 3-tuples
+    stride_d, stride_h, stride_w = (
+        (stride, stride, stride) if isinstance(stride, int) else stride
+    )
+    pad_d, pad_h, pad_w = (
+        (padding, padding, padding) if isinstance(padding, int) else padding
+    )
+    dilation_d, dilation_h, dilation_w = (
+        (dilation, dilation, dilation) if isinstance(dilation, int) else dilation
+    )
+
+    # Compute output spatial dimensions
+    out_d = (D + 2 * pad_d - dilation_d * (T - 1) - 1) // stride_d + 1
+    out_h = (H + 2 * pad_h - dilation_h * (R - 1) - 1) // stride_h + 1
+    out_w = (W + 2 * pad_w - dilation_w * (S - 1) - 1) // stride_w + 1
+
+    # Allocate the output tensor
+    output = torch.zeros(
+        (N, K, out_d, out_h, out_w), device=input.device, dtype=input.dtype
+    )
+
+    # One program per (block of output positions, batch, output channel)
+    grid = lambda META: (
+        triton.cdiv(out_d * out_h * out_w, META["BLOCK_SIZE"]),
+        N,
+        K,
+    )
+
+    # Launch the kernel
+    conv3d_kernel[grid](
+        input,
+        weight,
+        output,
+        N,
+        C,
+        D,
+        H,
+        W,
+        K,
+        T,
+        R,
+        S,
+        stride_d,
+        stride_h,
+        stride_w,
+        pad_d,
+        pad_h,
+        pad_w,
+        dilation_d,
+        dilation_h,
+        dilation_w,
+        groups,
+    )
+
+    # Add bias if provided
+    if bias is not None:
+        output += bias.view(1, -1, 1, 1, 1)
+
+    return output
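Since `conv3d` is exported from `flag_gems.ops`, a quick smoke check against the PyTorch reference can be run directly; a sketch with illustrative sizes and tolerances (assumes a device supported by flag_gems):

```python
import torch
import flag_gems

device = flag_gems.device
inp = torch.randn(2, 8, 10, 10, 10, device=device, dtype=torch.float32)
weight = torch.randn(12, 4, 3, 3, 3, device=device, dtype=torch.float32)  # groups=2
bias = torch.randn(12, device=device, dtype=torch.float32)

kwargs = dict(bias=bias, stride=(2, 2, 2), padding=1, dilation=1, groups=2)
res = flag_gems.conv3d(inp, weight, **kwargs)
ref = torch.nn.functional.conv3d(inp, weight, **kwargs)

torch.testing.assert_close(res, ref, rtol=1e-3, atol=1e-3)
print(res.shape)  # torch.Size([2, 12, 5, 5, 5])
```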
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 37f4ab2c2..31a73405d 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -874,3 +874,55 @@ def test_accuracy_depthwise2d(
         inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1
     )
     gems_assert_close(res_out, ref_out, dtype)
+
+
+SHAPE_CONV3D = [
+    # (input shape, weight shape, groups)
+    ((1, 3, 8, 8, 8), (8, 3, 3, 3, 3), 1),
+    ((2, 32, 16, 16, 16), (16, 16, 5, 5, 5), 2),
+    ((1, 12, 8, 8, 8), (8, 3, 3, 3, 3), 4),
+    # ((4, 32, 32, 32, 32), (32, 32, 7, 7, 7), 1),
+    # ((8, 64, 64, 64, 64), (64, 64, 9, 9, 9), 1),
+]
+
+
+@pytest.mark.conv3d
+@pytest.mark.parametrize("shape_input, shape_weight, groups", SHAPE_CONV3D)
+@pytest.mark.parametrize("strides", [1, (2, 2, 2), (3, 3, 3)])
+@pytest.mark.parametrize("paddings", [0, (1, 1, 1), (0, 1, 2)])
+@pytest.mark.parametrize("dilations", [1, (2, 2, 2), (1, 2, 3)])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_accuracy_conv3d(
+    shape_input, shape_weight, groups, strides, paddings, dilations, dtype
+):
+    inp = torch.randn(
+        shape_input, dtype=dtype, device=flag_gems.device, requires_grad=True
+    )
+    ref_inp = to_reference(inp, True)
+    torch.backends.cudnn.allow_tf32 = True
+    weight = torch.randn(shape_weight, dtype=dtype, device=flag_gems.device)
+    ref_weight = to_reference(weight, True)
+
+    ref_out = torch.nn.functional.conv3d(
+        ref_inp,
+        ref_weight,
+        bias=None,
+        groups=groups,
+        stride=strides,
+        padding=paddings,
+        dilation=dilations,
+    ).to(dtype)
+
+    res_out = flag_gems.conv3d(
+        inp,
+        weight,
+        bias=None,
+        groups=groups,
+        stride=strides,
+        padding=paddings,
+        dilation=dilations,
+    )
+
+    # Each output element accumulates (C // groups) * k_d * k_h * k_w products.
+    reduce_dim = shape_weight[1] * shape_weight[2] * shape_weight[3] * shape_weight[4]
+    gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim)
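A note on the tolerance above: the `reduce_dim` argument to `gems_assert_close` is meant to track the accumulation length, so the bound loosens for long reductions. Assuming that semantics, the value for the second `SHAPE_CONV3D` entry works out as:

```python
# input (2, 32, 16, 16, 16), weight (16, 16, 5, 5, 5), groups=2:
shape_weight = (16, 16, 5, 5, 5)
# (C // groups) * k_d * k_h * k_w products per output element
reduce_dim = shape_weight[1] * shape_weight[2] * shape_weight[3] * shape_weight[4]
print(reduce_dim)  # 16 * 5 * 5 * 5 = 2000
```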