Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] Add Conv3d forward function #412

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions benchmark/core_shapes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,14 @@ AttentionBenchmark:
- [4, 8, 2048, 128]
- [4, 8, 3072, 128]
- [4, 8, 4096, 128]

Conv3dBenchmark:
shapes:
- [104, 16, 32, 32, 32, 32, 4, 4, 4, 1, 0, 1]
- [64, 32, 18, 180, 18, 32, 5, 5, 5, 2, 1, 1]
- [4, 32, 110, 110, 10, 64, 5, 5, 5, 2, 1, 1]
- [4, 64, 110, 110, 10, 16, 5, 5, 5, 2, 1, 1]
- [16, 32, 120, 12, 12, 24, 3, 3, 3, 2, 1, 1]
- [16, 32, 240, 24, 24, 24, 3, 3, 3, 1, 1, 2]
- [16, 32, 24, 24, 24, 24, 3, 3, 3, 2, 2, 2]
- [16, 32, 24, 24, 24, 24, 3, 3, 3, 1, 2, 2]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually we suggest setting 5 shapes for core mode.

65 changes: 65 additions & 0 deletions benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .performance_utils import (
Benchmark,
Config,
GenericBenchmark,
GenericBenchmark2DOnly,
SkipVersion,
generate_tensor_input,
Expand Down Expand Up @@ -186,3 +187,67 @@ def count_nonzero_input_fn(shape, dtype, device):
dtypes=FLOAT_DTYPES,
)
bench.run()


class Conv3dBenchmark(GenericBenchmark):
def set_more_shapes(self):
self.shapes = []
shapes = super().set_more_shapes()
shapes = [
(n, c, d, h, w, o_c, k_d, k_h, k_w, stride, padding, groups)
for n in [2 * i for i in range(1, 2)]
for c in [16 * i for i in range(8, 9)]
for d in [128 * i for i in range(5, 9)]
for h in [128 * i for i in range(5, 9)]
for w in [128 * i for i in range(5, 9)]
for o_c in [128 * i for i in range(8, 9)]
for k_d in [32 * i for i in range(1, 4)]
for k_h in [32 * i for i in range(1, 4)]
for k_w in [32 * i for i in range(1, 4)]
for stride in [1, (2, 2, 2), (3, 3, 3)]
for padding in [0, (1, 1, 1), (0, 1, 2)]
for groups in [1, 2, 4, 8]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

personally I think the number of test cases is too large. pick 20 shapes from classic networks and the performance data is convincing enough.

]
return shapes


@pytest.mark.conv3d
def test_perf_conv3d():
def conv3d_input_fn(shape, dtype, device):
(
batch,
input_c,
input_d,
input_h,
input_w,
out_c,
kernel_d,
kernel_h,
kernel_w,
stride,
padding,
groups,
) = shape
input_shape = (batch, input_c, input_d, input_h, input_w)
weight_shape = (out_c, input_c // groups, kernel_d, kernel_h, kernel_w)
input = torch.randn(size=input_shape, device=device, dtype=dtype)

weight = torch.randn(size=weight_shape, device=device, dtype=dtype)

yield {
"input": input,
"weight": weight,
"bias": None,
"groups": groups,
"stride": stride,
"padding": padding,
},

torch.backends.cudnn.allow_tf32 = False
bench = Conv3dBenchmark(
input_fn=conv3d_input_fn,
op_name="conv3d",
torch_op=torch.nn.functional.conv3d,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since your implementation of conv3d is not registered into aten library, the function called in benchmark is still torch op. please add registration in src/flag_gems/init.py and update the benchmark results.

dtypes=FLOAT_DTYPES,
)
bench.run()
8 changes: 4 additions & 4 deletions benchmark/test_special_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def set_more_shapes(self):

@pytest.mark.conv2d
def test_perf_conv2d():
def conv2d_input_fn(shape, dtype, device):
def conv3d_input_fn(shape, dtype, device):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not time to enable benchmark of conv2d right now. so does conv3d. I'll mark it as skip.

(
batch,
input_c,
Expand Down Expand Up @@ -294,9 +294,9 @@ def conv2d_input_fn(shape, dtype, device):

torch.backends.cudnn.allow_tf32 = False
bench = ConvBenchmark(
input_fn=conv2d_input_fn,
op_name="conv2d",
torch_op=torch.nn.functional.conv2d,
input_fn=conv3d_input_fn,
op_name="conv3d",
torch_op=torch.nn.functional.conv3d,
dtypes=FLOAT_DTYPES,
)
bench.run()
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .conv1d import conv1d
from .conv2d import conv2d
from .conv_depthwise2d import _conv_depthwise2d
from .conv3d import conv3d
from .cos import cos
from .count_nonzero import count_nonzero
from .cross_entropy_loss import cross_entropy_loss
Expand Down Expand Up @@ -154,6 +155,7 @@
"clamp_tensor",
"cos",
"count_nonzero",
"conv3d",
"diag",
"diag_embed",
"diagonal_backward",
Expand Down
209 changes: 209 additions & 0 deletions src/flag_gems/ops/conv3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE": 32}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=8),
],
key=["N", "C", "D", "H", "W", "K"],
)
@triton.jit
def conv3d_kernel(
input_ptr,
weight_ptr,
output_ptr,
N,
C,
D,
H,
W,
K,
T,
R,
S,
stride_d,
stride_h,
stride_w,
pad_d,
pad_h,
pad_w,
dilation_d,
dilation_h,
dilation_w,
groups,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
n_idx = tl.program_id(1)
k_idx = tl.program_id(2)

# Compute output dimensions
out_d = (D + 2 * pad_d - dilation_d * (T - 1) - 1) // stride_d + 1
out_h = (H + 2 * pad_h - dilation_h * (R - 1) - 1) // stride_h + 1
out_w = (W + 2 * pad_w - dilation_w * (S - 1) - 1) // stride_w + 1

# Compute start and end indices for the block
block_start = pid * BLOCK_SIZE
block_end = min(block_start + BLOCK_SIZE, out_d * out_h * out_w)

# Number of channels per group
C_per_group = C // groups
K_per_group = K // groups

# Loop over the block
for idx in range(block_start, block_end):
# Compute the output indices
od = idx // (out_h * out_w)
oh = (idx % (out_h * out_w)) // out_w
ow = idx % out_w

# Compute the corresponding input indices
id = od * stride_d - pad_d
ih = oh * stride_h - pad_h
iw = ow * stride_w - pad_w

# Initialize the output value
value = tl.zeros((), dtype=tl.float32)

# Perform convolution
group_id = k_idx // K_per_group
for t in range(T):
for r in range(R):
for s in range(S):
for c in range(C_per_group):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using multiple layers of loop might not be a good idea to compute convolution. try loading tensors with high-dimension indexes and using tl.dot primitive.

input_d = id + t * dilation_d
input_h = ih + r * dilation_h
input_w = iw + s * dilation_w

# Check if the input indices are valid
in_bounds = (
(0 <= input_d)
& (input_d < D)
& (0 <= input_h)
& (input_h < H)
& (0 <= input_w)
& (input_w < W)
)

if in_bounds:
input_idx = (
n_idx * C * D * H * W
+ (group_id * C_per_group + c) * D * H * W
+ input_d * H * W
+ input_h * W
+ input_w
)
weight_idx = (
k_idx * C_per_group * T * R * S
+ c * T * R * S
+ t * R * S
+ r * S
+ s
)
input_val = tl.load(
input_ptr + input_idx, mask=in_bounds, other=0.0
).to(tl.float32)
weight_val = tl.load(weight_ptr + weight_idx).to(tl.float32)
value += input_val * weight_val

# Store the result in the output tensor
output_idx = (
n_idx * K * out_d * out_h * out_w
+ k_idx * out_d * out_h * out_w
+ od * out_h * out_w
+ oh * out_w
+ ow
)
tl.store(output_ptr + output_idx, value)


def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
"""
Implements a 3D convolution with groups support using Triton.

Args:
input (torch.Tensor): Input tensor of shape (N, C, D, H, W).
weight (torch.Tensor): Weight tensor of shape (K, C/groups, T, R, S).
bias (torch.Tensor, optional): Bias tensor of shape (K,).
stride (int or tuple): Stride of the convolution. Default: 1.
padding (int or tuple): Padding added to all three dimensions. Default: 0.
dilation (int or tuple): Spacing between kernel elements. Default: 1.
groups (int): Number of blocked connections from input channels to output channels. Default: 1.

Returns:
torch.Tensor: Output tensor.
"""
# Input shape: (N, C, D, H, W)
# Weight shape: (K, C/groups, T, R, S)
N, C, D, H, W = input.shape
K, C_per_group, T, R, S = weight.shape

# Expand parameters to 3D
stride_d, stride_h, stride_w = (
(stride, stride, stride) if isinstance(stride, int) else stride
)
pad_d, pad_h, pad_w = (
(padding, padding, padding) if isinstance(padding, int) else padding
)
dilation_d, dilation_h, dilation_w = (
(dilation, dilation, dilation) if isinstance(dilation, int) else dilation
)

# Compute output dimensions
out_d = (D + 2 * pad_d - dilation_d * (T - 1) - 1) // stride_d + 1
out_h = (H + 2 * pad_h - dilation_h * (R - 1) - 1) // stride_h + 1
out_w = (W + 2 * pad_w - dilation_w * (S - 1) - 1) // stride_w + 1

# Allocate output tensor
output = torch.zeros(
(N, K, out_d, out_h, out_w), device=input.device, dtype=input.dtype
)

# Triton grid configuration
grid = lambda META: (
triton.cdiv(
out_d * out_h * out_w, META["BLOCK_SIZE"]
), # Number of output elements divided into blocks
N, # Number of batches
K, # Number of output channels
)

# Launch kernel
conv3d_kernel[grid](
input,
weight,
output,
N,
C,
D,
H,
W,
K,
T,
R,
S,
stride_d,
stride_h,
stride_w,
pad_d,
pad_h,
pad_w,
dilation_d,
dilation_h,
dilation_w,
groups,
)

# Add bias if provided
if bias is not None:
output += bias.view(1, -1, 1, 1, 1)

return output
50 changes: 50 additions & 0 deletions tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,3 +874,53 @@ def test_accuracy_depthwise2d(
inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1
)
gems_assert_close(res_out, ref_out, dtype)


SHAPE_CONV3D = [
((1, 3, 8, 8, 8), (8, 3, 3, 3, 3), 1),
((2, 32, 16, 16, 16), (16, 16, 5, 5, 5), 2),
((1, 12, 8, 8, 8), (8, 3, 3, 3, 3), 4),
# ((4, 32, 32, 32, 32),(32, 32, 7, 7, 7), 1),
# ((8, 64, 64, 64, 64),(64, 64, 9, 9, 9), 1),
]


@pytest.mark.conv3d
@pytest.mark.parametrize("shape_input, shape_weight, groups", SHAPE_CONV3D)
@pytest.mark.parametrize("strides", [1, (2, 2, 2), (3, 3, 3)])
@pytest.mark.parametrize("paddings", [0, (1, 1, 1), (0, 1, 2)])
@pytest.mark.parametrize("dilations", [1, (2, 2, 2), (1, 2, 3)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_accuracy_conv3d(
shape_input, shape_weight, groups, strides, paddings, dilations, dtype
):
inp = torch.randn(
shape_input, dtype=dtype, device=flag_gems.device, requires_grad=True
)
ref_inp = to_reference(inp, True)
torch.backends.cudnn.allow_tf32 = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually we set allow_tf32 as False since the precision of tf32 is not satisfying.

weight = torch.randn(shape_weight, dtype=dtype, device=flag_gems.device)
ref_weight = to_reference(weight, True)

ref_out = torch.nn.functional.conv3d(
ref_inp,
ref_weight,
bias=None,
groups=groups,
stride=strides,
padding=paddings,
dilation=dilations,
).to(dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gems_assert_close will cast the reference tensor to dtype. you don't need to do this again.


res_out = flag_gems.conv3d(
inp,
weight,
bias=None,
groups=groups,
stride=strides,
padding=paddings,
dilation=dilations,
)

reduce_dim = shape_weight[-1] * shape_weight[-2] * shape_weight[-3]
gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim)
Loading