[Operator] Add Conv3d forward function #412
base: master
The first file extends the benchmark suite. One import is added:

```diff
@@ -10,6 +10,7 @@
 from .performance_utils import (
     Benchmark,
     Config,
+    GenericBenchmark,
     GenericBenchmark2DOnly,
     SkipVersion,
     generate_tensor_input,
```

and a Conv3dBenchmark is appended at the end of the file:
```diff
@@ -186,3 +187,67 @@ def count_nonzero_input_fn(shape, dtype, device):
         dtypes=FLOAT_DTYPES,
     )
     bench.run()
+
+
+class Conv3dBenchmark(GenericBenchmark):
+    def set_more_shapes(self):
+        self.shapes = []
+        shapes = super().set_more_shapes()
+        shapes = [
+            (n, c, d, h, w, o_c, k_d, k_h, k_w, stride, padding, groups)
+            for n in [2 * i for i in range(1, 2)]
+            for c in [16 * i for i in range(8, 9)]
+            for d in [128 * i for i in range(5, 9)]
+            for h in [128 * i for i in range(5, 9)]
+            for w in [128 * i for i in range(5, 9)]
+            for o_c in [128 * i for i in range(8, 9)]
+            for k_d in [32 * i for i in range(1, 4)]
+            for k_h in [32 * i for i in range(1, 4)]
+            for k_w in [32 * i for i in range(1, 4)]
+            for stride in [1, (2, 2, 2), (3, 3, 3)]
+            for padding in [0, (1, 1, 1), (0, 1, 2)]
+            for groups in [1, 2, 4, 8]
+        ]
+        return shapes
```

Review comment (on the shape list): Personally I think the number of test cases is too large. Pick 20 shapes from classic networks; that performance data is convincing enough.
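As a hedged illustration of what "shapes from classic networks" could look like, the list below is loosely modeled on C3D- and I3D-style layer configurations; the concrete values are assumptions for illustration, not part of this PR:

```python
# Illustrative only: (N, C_in, D, H, W, C_out, kD, kH, kW, stride, padding, groups),
# loosely modeled on C3D and I3D layer shapes.
CONV3D_CLASSIC_SHAPES = [
    (1, 3, 16, 112, 112, 64, 3, 3, 3, 1, 1, 1),          # C3D conv1-like
    (1, 64, 16, 56, 56, 128, 3, 3, 3, 1, 1, 1),          # C3D conv2-like
    (1, 128, 8, 28, 28, 256, 3, 3, 3, 1, 1, 1),          # C3D conv3-like
    (1, 256, 4, 14, 14, 512, 3, 3, 3, 1, 1, 1),          # C3D conv4-like
    (1, 3, 64, 224, 224, 64, 7, 7, 7, (2, 2, 2), 3, 1),  # I3D-style stem
]
```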
The hunk continues with the benchmark entry point:

```diff
+
+
+@pytest.mark.conv3d
+def test_perf_conv3d():
+    def conv3d_input_fn(shape, dtype, device):
+        (
+            batch,
+            input_c,
+            input_d,
+            input_h,
+            input_w,
+            out_c,
+            kernel_d,
+            kernel_h,
+            kernel_w,
+            stride,
+            padding,
+            groups,
+        ) = shape
+        input_shape = (batch, input_c, input_d, input_h, input_w)
+        weight_shape = (out_c, input_c // groups, kernel_d, kernel_h, kernel_w)
+        input = torch.randn(size=input_shape, device=device, dtype=dtype)
+        weight = torch.randn(size=weight_shape, device=device, dtype=dtype)
+
+        yield {
+            "input": input,
+            "weight": weight,
+            "bias": None,
+            "groups": groups,
+            "stride": stride,
+            "padding": padding,
+        },
+
+    torch.backends.cudnn.allow_tf32 = False
+    bench = Conv3dBenchmark(
+        input_fn=conv3d_input_fn,
+        op_name="conv3d",
+        torch_op=torch.nn.functional.conv3d,
+        dtypes=FLOAT_DTYPES,
+    )
+    bench.run()
```

Review comment (on `torch_op=torch.nn.functional.conv3d`): Since your implementation of conv3d is not registered into the aten library, the function called in the benchmark is still the torch op. Please add the registration in `src/flag_gems/__init__.py` and update the benchmark results.
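For reference, a hedged sketch of what such a registration might look like, assuming the `torch.library` pattern used elsewhere in FlagGems; the exact op schema would need to be verified, since aten's `conv3d` is a composite op:

```python
import torch

from flag_gems.ops import conv3d  # the Triton implementation added in this PR

# A sketch, assuming the lib.impl(...) pattern of src/flag_gems/__init__.py.
# Registering under the name "conv3d" is an assumption: aten::conv3d is a
# composite op that may decompose to aten::convolution before reaching the
# CUDA dispatch key, in which case that op must be targeted instead.
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("conv3d", conv3d, "CUDA")
```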
In the existing conv2d benchmark file, the input function and op are switched to conv3d:
```diff
@@ -264,7 +264,7 @@ def set_more_shapes(self):

 @pytest.mark.conv2d
 def test_perf_conv2d():
-    def conv2d_input_fn(shape, dtype, device):
+    def conv3d_input_fn(shape, dtype, device):
         (
             batch,
             input_c,
@@ -294,9 +294,9 @@ def conv2d_input_fn(shape, dtype, device):

     torch.backends.cudnn.allow_tf32 = False
     bench = ConvBenchmark(
-        input_fn=conv2d_input_fn,
-        op_name="conv2d",
-        torch_op=torch.nn.functional.conv2d,
+        input_fn=conv3d_input_fn,
+        op_name="conv3d",
+        torch_op=torch.nn.functional.conv3d,
         dtypes=FLOAT_DTYPES,
     )
     bench.run()
```

Review comment (on `def conv3d_input_fn`): It's not time to enable the conv2d benchmark right now, and the same goes for conv3d. I'll mark it as skip.
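The reviewer's "mark it as skip" could be realized with pytest's standard skip marker, e.g. (a sketch; the reason string is illustrative):

```python
import pytest


@pytest.mark.skip(reason="conv2d/conv3d benchmarks are not enabled yet")
@pytest.mark.conv3d
def test_perf_conv3d():
    ...
```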
A new file (209 lines) adds the Triton conv3d kernel:

```python
import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=4),
        triton.Config({"BLOCK_SIZE": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=8),
    ],
    key=["N", "C", "D", "H", "W", "K"],
)
@triton.jit
def conv3d_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    N,
    C,
    D,
    H,
    W,
    K,
    T,
    R,
    S,
    stride_d,
    stride_h,
    stride_w,
    pad_d,
    pad_h,
    pad_w,
    dilation_d,
    dilation_h,
    dilation_w,
    groups,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    n_idx = tl.program_id(1)
    k_idx = tl.program_id(2)

    # Compute output dimensions
    out_d = (D + 2 * pad_d - dilation_d * (T - 1) - 1) // stride_d + 1
    out_h = (H + 2 * pad_h - dilation_h * (R - 1) - 1) // stride_h + 1
    out_w = (W + 2 * pad_w - dilation_w * (S - 1) - 1) // stride_w + 1

    # Compute start and end indices for the block
    block_start = pid * BLOCK_SIZE
    block_end = min(block_start + BLOCK_SIZE, out_d * out_h * out_w)

    # Number of channels per group
    C_per_group = C // groups
    K_per_group = K // groups

    # Loop over the block
    for idx in range(block_start, block_end):
        # Compute the output indices
        od = idx // (out_h * out_w)
        oh = (idx % (out_h * out_w)) // out_w
        ow = idx % out_w

        # Compute the corresponding input indices
        id = od * stride_d - pad_d
        ih = oh * stride_h - pad_h
        iw = ow * stride_w - pad_w

        # Initialize the output value
        value = tl.zeros((), dtype=tl.float32)

        # Perform convolution
        group_id = k_idx // K_per_group
        for t in range(T):
            for r in range(R):
                for s in range(S):
                    for c in range(C_per_group):
                        input_d = id + t * dilation_d
                        input_h = ih + r * dilation_h
                        input_w = iw + s * dilation_w

                        # Check if the input indices are valid
                        in_bounds = (
                            (0 <= input_d)
                            & (input_d < D)
                            & (0 <= input_h)
                            & (input_h < H)
                            & (0 <= input_w)
                            & (input_w < W)
                        )

                        if in_bounds:
                            input_idx = (
                                n_idx * C * D * H * W
                                + (group_id * C_per_group + c) * D * H * W
                                + input_d * H * W
                                + input_h * W
                                + input_w
                            )
                            weight_idx = (
                                k_idx * C_per_group * T * R * S
                                + c * T * R * S
                                + t * R * S
                                + r * S
                                + s
                            )
                            input_val = tl.load(
                                input_ptr + input_idx, mask=in_bounds, other=0.0
                            ).to(tl.float32)
                            weight_val = tl.load(weight_ptr + weight_idx).to(tl.float32)
                            value += input_val * weight_val

        # Store the result in the output tensor
        output_idx = (
            n_idx * K * out_d * out_h * out_w
            + k_idx * out_d * out_h * out_w
            + od * out_h * out_w
            + oh * out_w
            + ow
        )
        tl.store(output_ptr + output_idx, value)
```

Review comment (on the nested loops): Using multiple layers of loops might not be a good idea for computing a convolution. Try loading tensors with high-dimension indexes and using the tl.dot primitive.
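Picking up the reviewer's suggestion, here is a hedged sketch (not part of this PR) of how the scalar loops could be replaced by an implicit-GEMM formulation built on `tl.dot`. It assumes `groups=1` and `dilation=1` for brevity, and block sizes must be at least 16 for `tl.dot`; all names and block shapes are illustrative:

```python
import triton
import triton.language as tl


# Sketch: treat the output as a GEMM of shape (N*out_d*out_h*out_w, K) with
# reduction length C*T*R*S; each program accumulates one (BLOCK_M, BLOCK_N)
# tile with tl.dot instead of scalar multiply-adds.
@triton.jit
def conv3d_implicit_gemm_kernel(
    input_ptr, weight_ptr, output_ptr,
    N, C, D, H, W, K, T, R, S,
    out_d, out_h, out_w,
    stride_d, stride_h, stride_w,
    pad_d, pad_h, pad_w,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # flat output positions
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # output channels

    # Decompose the flat output index into (batch, od, oh, ow).
    spatial = out_d * out_h * out_w
    n = offs_m // spatial
    rem = offs_m % spatial
    od = rem // (out_h * out_w)
    oh = (rem % (out_h * out_w)) // out_w
    ow = rem % out_w

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    GEMM_K = C * T * R * S
    for k0 in range(0, GEMM_K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        # Decompose the reduction index into (c, t, r, s).
        c = offs_k // (T * R * S)
        rem_k = offs_k % (T * R * S)
        t = rem_k // (R * S)
        r = (rem_k % (R * S)) // S
        s = rem_k % S

        # Gather a (BLOCK_M, BLOCK_K) tile of the implicit input matrix;
        # out-of-image taps are masked to zero, so im2col never materializes.
        in_d = od[:, None] * stride_d - pad_d + t[None, :]
        in_h = oh[:, None] * stride_h - pad_h + r[None, :]
        in_w = ow[:, None] * stride_w - pad_w + s[None, :]
        a_mask = (
            (offs_m[:, None] < N * spatial) & (offs_k[None, :] < GEMM_K)
            & (in_d >= 0) & (in_d < D)
            & (in_h >= 0) & (in_h < H)
            & (in_w >= 0) & (in_w < W)
        )
        a_idx = (((n[:, None] * C + c[None, :]) * D + in_d) * H + in_h) * W + in_w
        a = tl.load(input_ptr + a_idx, mask=a_mask, other=0.0)

        # Weight tile of shape (BLOCK_K, BLOCK_N): the weight is (K, C, T, R, S),
        # so channel k of the implicit B matrix is a contiguous run of GEMM_K values.
        b_idx = offs_n[None, :] * GEMM_K + offs_k[:, None]
        b_mask = (offs_k[:, None] < GEMM_K) & (offs_n[None, :] < K)
        b = tl.load(weight_ptr + b_idx, mask=b_mask, other=0.0)

        acc += tl.dot(a, b)  # tensor-core matmul instead of scalar loops

    # Scatter the tile back into the (N, K, out_d, out_h, out_w) layout.
    out_idx = (
        ((n[:, None] * K + offs_n[None, :]) * out_d + od[:, None]) * out_h
        + oh[:, None]
    ) * out_w + ow[:, None]
    out_mask = (offs_m[:, None] < N * spatial) & (offs_n[None, :] < K)
    tl.store(output_ptr + out_idx, acc, mask=out_mask)
```

The key design point is that the im2col matrix is never materialized: each K-slice of the implicit input matrix is gathered on the fly with a masked load, and the reduction runs on tensor cores via `tl.dot`.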
The Python wrapper in the same file expands the scalar parameters, computes the output dimensions, and launches the kernel:

```python
def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """
    Implements a 3D convolution with groups support using Triton.

    Args:
        input (torch.Tensor): Input tensor of shape (N, C, D, H, W).
        weight (torch.Tensor): Weight tensor of shape (K, C/groups, T, R, S).
        bias (torch.Tensor, optional): Bias tensor of shape (K,).
        stride (int or tuple): Stride of the convolution. Default: 1.
        padding (int or tuple): Padding added to all three dimensions. Default: 0.
        dilation (int or tuple): Spacing between kernel elements. Default: 1.
        groups (int): Number of blocked connections from input channels to output channels. Default: 1.

    Returns:
        torch.Tensor: Output tensor.
    """
    # Input shape: (N, C, D, H, W)
    # Weight shape: (K, C/groups, T, R, S)
    N, C, D, H, W = input.shape
    K, C_per_group, T, R, S = weight.shape

    # Expand parameters to 3D
    stride_d, stride_h, stride_w = (
        (stride, stride, stride) if isinstance(stride, int) else stride
    )
    pad_d, pad_h, pad_w = (
        (padding, padding, padding) if isinstance(padding, int) else padding
    )
    dilation_d, dilation_h, dilation_w = (
        (dilation, dilation, dilation) if isinstance(dilation, int) else dilation
    )

    # Compute output dimensions
    out_d = (D + 2 * pad_d - dilation_d * (T - 1) - 1) // stride_d + 1
    out_h = (H + 2 * pad_h - dilation_h * (R - 1) - 1) // stride_h + 1
    out_w = (W + 2 * pad_w - dilation_w * (S - 1) - 1) // stride_w + 1

    # Allocate output tensor
    output = torch.zeros(
        (N, K, out_d, out_h, out_w), device=input.device, dtype=input.dtype
    )

    # Triton grid configuration
    grid = lambda META: (
        triton.cdiv(
            out_d * out_h * out_w, META["BLOCK_SIZE"]
        ),  # Number of output elements divided into blocks
        N,  # Number of batches
        K,  # Number of output channels
    )

    # Launch kernel
    conv3d_kernel[grid](
        input,
        weight,
        output,
        N,
        C,
        D,
        H,
        W,
        K,
        T,
        R,
        S,
        stride_d,
        stride_h,
        stride_w,
        pad_d,
        pad_h,
        pad_w,
        dilation_d,
        dilation_h,
        dilation_w,
        groups,
    )

    # Add bias if provided
    if bias is not None:
        output += bias.view(1, -1, 1, 1, 1)

    return output
```
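To illustrate the intended call pattern, a minimal usage sketch, assuming a CUDA device and that `conv3d` above is in scope; the shapes are illustrative:

```python
import torch
import torch.nn.functional as F

# Grouped conv: C_in=16, groups=2, so weight has C_in/groups = 8 input channels.
inp = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32)
weight = torch.randn(8, 8, 3, 3, 3, device="cuda", dtype=torch.float32)

out = conv3d(inp, weight, stride=1, padding=1, dilation=1, groups=2)
ref = F.conv3d(inp, weight, stride=1, padding=1, dilation=1, groups=2)

# With stride 1 and padding 1 for a 3x3x3 kernel, spatial dims are preserved:
print(out.shape)  # torch.Size([1, 8, 8, 8, 8])
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```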
Finally, an accuracy test is appended to the convolution test file:
```diff
@@ -874,3 +874,53 @@ def test_accuracy_depthwise2d(
         inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1
     )
     gems_assert_close(res_out, ref_out, dtype)
+
+
+SHAPE_CONV3D = [
+    ((1, 3, 8, 8, 8), (8, 3, 3, 3, 3), 1),
+    ((2, 32, 16, 16, 16), (16, 16, 5, 5, 5), 2),
+    ((1, 12, 8, 8, 8), (8, 3, 3, 3, 3), 4),
+    # ((4, 32, 32, 32, 32), (32, 32, 7, 7, 7), 1),
+    # ((8, 64, 64, 64, 64), (64, 64, 9, 9, 9), 1),
+]
+
+
+@pytest.mark.conv3d
+@pytest.mark.parametrize("shape_input, shape_weight, groups", SHAPE_CONV3D)
+@pytest.mark.parametrize("strides", [1, (2, 2, 2), (3, 3, 3)])
+@pytest.mark.parametrize("paddings", [0, (1, 1, 1), (0, 1, 2)])
+@pytest.mark.parametrize("dilations", [1, (2, 2, 2), (1, 2, 3)])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_accuracy_conv3d(
+    shape_input, shape_weight, groups, strides, paddings, dilations, dtype
+):
+    inp = torch.randn(
+        shape_input, dtype=dtype, device=flag_gems.device, requires_grad=True
+    )
+    ref_inp = to_reference(inp, True)
+    torch.backends.cudnn.allow_tf32 = True
+
+    weight = torch.randn(shape_weight, dtype=dtype, device=flag_gems.device)
+    ref_weight = to_reference(weight, True)
+
+    ref_out = torch.nn.functional.conv3d(
+        ref_inp,
+        ref_weight,
+        bias=None,
+        groups=groups,
+        stride=strides,
+        padding=paddings,
+        dilation=dilations,
+    ).to(dtype)
+
+    res_out = flag_gems.conv3d(
+        inp,
+        weight,
+        bias=None,
+        groups=groups,
+        stride=strides,
+        padding=paddings,
+        dilation=dilations,
+    )
+
+    reduce_dim = shape_weight[-1] * shape_weight[-2] * shape_weight[-3]
+    gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim)
```

Review comment (on `torch.backends.cudnn.allow_tf32 = True`): Usually we set allow_tf32 to False, since the precision of tf32 is not satisfying.

Review comment (on `.to(dtype)`): gems_assert_close will cast the reference tensor to dtype; you don't need to do this again.
Review comment (on SHAPE_CONV3D): Usually we suggest setting 5 shapes for core mode.