[Operator] register batch_norm backward
StrongSpoon committed Feb 6, 2025
1 parent e2be412 commit 01bee17
Showing 5 changed files with 192 additions and 192 deletions.
3 changes: 2 additions & 1 deletion src/flag_gems/__init__.py
@@ -25,7 +25,8 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
             ("arange.start_step", arange_start, Autograd.disable),
             ("arange.start", arange_start, Autograd.disable),
             ("arange", arange, Autograd.disable),
-            ("batch_norm", batch_norm, Autograd.enable),
+            ("native_batch_norm", batch_norm, Autograd.disable),
+            ("native_batch_norm_backward", batch_norm_backward, Autograd.disable),
             ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
             ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
             ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
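Note: with both ops registered as Autograd.disable, no Python-side autograd wrapper is needed anymore. PyTorch's built-in derivative formula for aten::native_batch_norm routes the backward pass to aten::native_batch_norm_backward, which now resolves to the Triton kernel. A minimal sketch of how the two registrations are exercised (assumes a CUDA build with flag_gems installed; shapes are illustrative):

import torch
import flag_gems

flag_gems.enable()  # patches the aten ops registered above

x = torch.randn(8, 16, 32, 32, device="cuda", requires_grad=True)
w = torch.randn(16, device="cuda", requires_grad=True)
b = torch.randn(16, device="cuda", requires_grad=True)

# The forward call hits native_batch_norm; autograd then routes the
# backward pass to native_batch_norm_backward.
out, save_mean, save_invstd = torch.native_batch_norm(
    x, w, b, None, None, True, 0.1, 1e-5
)
out.sum().backward()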
3 changes: 2 additions & 1 deletion src/flag_gems/ops/__init__.py
@@ -8,7 +8,7 @@
 from .argmax import argmax
 from .argmin import argmin
 from .attention import scaled_dot_product_attention
-from .batch_norm import batch_norm
+from .batch_norm import batch_norm, batch_norm_backward
 from .bitwise_and import (
     bitwise_and_scalar,
     bitwise_and_scalar_tensor,
@@ -151,6 +151,7 @@
     "arange",
     "arange_start",
     "batch_norm",
+    "batch_norm_backward",
     "bitwise_and_tensor",
     "bitwise_and_scalar",
     "bitwise_and_scalar_tensor",
263 changes: 111 additions & 152 deletions src/flag_gems/ops/batch_norm.py
@@ -8,7 +8,6 @@
 from .. import runtime
 from ..runtime import torch_device_fn
 from ..utils import libentry, tl_extra_shim
-from ..utils.type_utils import get_accumulator_dtype

 rsqrt = tl_extra_shim.rsqrt

@@ -63,8 +62,6 @@ def batch_norm_forward_kernel(
     output_spatial_stride,
     momentum,
     eps,
-    affine: tl.constexpr,
-    save_stats: tl.constexpr,
     is_train: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -114,9 +111,8 @@ def batch_norm_forward_kernel(
         inv_std = rsqrt(var + eps)
         mean = final_mean

-        if save_stats:
-            tl.store(feat_pid + mean_pointer, mean)
-            tl.store(feat_pid + inv_std_pointer, inv_std)
+        tl.store(feat_pid + mean_pointer, mean)
+        tl.store(feat_pid + inv_std_pointer, inv_std)

         running_mean_pointer += feat_pid
         running_var_pointer += feat_pid
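Note: the mean/inv_std stores are now unconditional because native_batch_norm always returns the saved batch statistics as its second and third outputs, so there is no save_stats toggle left to honor. For orientation, the training-mode running-stats update the kernel performs, restated in plain PyTorch (a sketch; the unbiased-variance correction follows batch_norm's documented semantics):

import torch

def update_running_stats(x3d, running_mean, running_var, momentum=0.1):
    # x3d is (batch, feature, spatial); statistics are per feature.
    count = x3d.shape[0] * x3d.shape[2]
    batch_mean = x3d.mean(dim=(0, 2))
    batch_var = x3d.var(dim=(0, 2), unbiased=False)
    running_mean.mul_(1 - momentum).add_(momentum * batch_mean)
    # The running variance is updated with the unbiased estimate.
    running_var.mul_(1 - momentum).add_(momentum * batch_var * count / (count - 1))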
@@ -135,12 +131,13 @@ def batch_norm_forward_kernel(
         mean = tl.load(feat_pid + running_mean_pointer)
         inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

-    if affine:
-        weight = tl.load(feat_pid + weight_pointer)
-        bias = tl.load(feat_pid + bias_pointer)
-
+    if weight_pointer:
+        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
     else:
         weight = 1.0
+    if bias_pointer:
+        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
+    else:
         bias = 0.0

     for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
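Note: the affine constexpr flag is gone; the kernel now branches on the pointers themselves. Triton specializes arguments passed as None, so `if weight_pointer:` is resolved when the kernel is compiled and the fallback constants (1.0 and 0.0) are folded in for the non-affine case. A standalone illustration of the same idiom (a hypothetical kernel, not part of this diff; assumes a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, scale_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if scale_ptr:  # compile-time False when scale_ptr is None
        x = x * tl.load(scale_ptr)
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
scale_kernel[(4,)](x, None, out, x.numel(), BLOCK=256)  # scale elided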
@@ -203,7 +200,9 @@ def batch_norm_backward_kernel(
     input_grad_batch_stride,
     input_grad_feat_stride,
     input_grad_spatial_stride,
-    affine: tl.constexpr,
+    input_grad_mask: tl.constexpr,
+    weight_grad_mask: tl.constexpr,
+    bias_grad_mask: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -250,11 +249,16 @@ def batch_norm_backward_kernel(
     term1 = tl.sum(term1)
     term2 = tl.sum(term2)

-    if affine:
-        weight = tl.load(feat_pid + weight_pointer)
-        weight_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        bias_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+    if weight_grad_mask:
+        tl.store(feat_pid + weight_grad_pointer, term1)
+    if bias_grad_mask:
+        tl.store(feat_pid + bias_grad_pointer, term2)
+
+    if not input_grad_mask:
+        return
+
+    if weight_pointer:
+        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
     else:
         weight = 1.0

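Note: the three constexpr masks mirror the bool[3] output_mask argument of aten::native_batch_norm_backward, which tells the kernel which of (input_grad, weight_grad, bias_grad) autograd actually needs. Since term1 and term2 are exactly the per-feature parameter gradients, they are stored immediately and the program returns early when no input gradient was requested. The two reductions, restated in plain PyTorch as a reference sketch:

def reference_param_grads(dy, x, mean, inv_std):
    # dy, x: (batch, feature, spatial); mean, inv_std: (feature,)
    xhat = (x - mean[None, :, None]) * inv_std[None, :, None]
    weight_grad = (dy * xhat).sum(dim=(0, 2))  # term1
    bias_grad = dy.sum(dim=(0, 2))  # term2
    return weight_grad, bias_grad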
@@ -306,152 +310,107 @@ def batch_norm_backward_kernel(
             mask=batch_mask[:, None] & spatial_mask[None, :],
         )

-        if affine:
-            weight_grad_acc += curr_pre_lin * curr_output_grad
-            bias_grad_acc += curr_output_grad
-
-    if affine:
-        tl.store(feat_pid + weight_grad_pointer, tl.sum(weight_grad_acc))
-        tl.store(feat_pid + bias_grad_pointer, tl.sum(bias_grad_acc))
-
-
-class BatchNorm(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        input: Tensor,
-        weight=None,
-        bias=None,
-        running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
-        running_var=None,
-        training=False,  # (self.running_mean is None) and (self.running_var is None)
-        momentum=0.1,
-        eps=1e-05,
-        cudnn_enable=True,
-    ):
-        logging.debug("GEMS BATCHNORM FORWARD")
-
-        input_3d = make_3d_for_bn(input)
-
-        affine = weight is not None and bias is not None
-        requires_grad = (
-            input.requires_grad
-            or (affine and weight.requires_grad)
-            or (affine and bias.requires_grad)
-        )
-
-        batch_dim, feat_dim, spatial_dim = input_3d.shape
-        output = torch.empty_like(input_3d)
-
-        if requires_grad:
-            acc_type = get_accumulator_dtype(input.dtype)
-            mean = torch.empty(feat_dim, device=input.device, dtype=acc_type)
-            inv_std = torch.empty(feat_dim, device=input.device, dtype=acc_type)
-
-        else:
-            mean = inv_std = None
-
-        running_mean = input if running_mean is None else running_mean
-        running_var = input if running_var is None else running_var
-
-        # Launches 1D grid where each program operates over one feature.
-        with torch_device_fn.device(input.device):
-            batch_norm_forward_kernel[(feat_dim,)](
-                input_3d,
-                weight,
-                bias,
-                mean,
-                inv_std,
-                output,
-                running_mean,
-                running_var,
-                batch_dim,
-                spatial_dim,
-                *input_3d.stride(),
-                *output.stride(),
-                momentum,
-                eps,
-                affine=affine,
-                save_stats=requires_grad,
-                is_train=training,
-            )
-
-        ctx.affine = affine
-        if requires_grad:
-            ctx.save_for_backward(input, mean, inv_std, weight)
-
-        return output.view_as(input)
-
-    @staticmethod
-    def backward(ctx, output_grad):
-        logging.debug("GEMS BATCHNORM BACKWARD")
-        (input, mean, inv_std, weight) = ctx.saved_tensors
-        input_3d = make_3d_for_bn(input)
-        output_grad_3d = make_3d_for_bn(output_grad)
-
-        batch_dim, feat_dim, spatial_dim = input_3d.shape
-        input_grad = torch.empty_like(input_3d)
-
-        if ctx.affine:
-            weight_grad = torch.empty((feat_dim,), device=input.device)
-            bias_grad = torch.empty_like(weight_grad)
-
-        else:
-            weight_grad = bias_grad = None
-
-        # Launches 1D grid where each program operates over one feature.
-        with torch_device_fn.device(input.device):
-            batch_norm_backward_kernel[(feat_dim,)](
-                output_grad_3d,
-                input_3d,
-                mean,
-                inv_std,
-                weight,
-                input_grad,
-                weight_grad,
-                bias_grad,
-                batch_dim,
-                spatial_dim,
-                *output_grad_3d.stride(),
-                *input_3d.stride(),
-                *input_grad.stride(),
-                affine=ctx.affine,
-            )
-
-        # Pads output with None because a gradient is necessary for
-        # all input arguments.
-        return (
-            input_grad.view_as(input),
-            weight_grad,
-            bias_grad,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-        )
-
-
-def batch_norm(
-    input,
-    weight=None,
-    bias=None,
-    running_mean=None,
-    running_var=None,
-    training=False,
-    momentum=0.1,
-    eps=1e-05,
-    cudnn_enable=True,
-):
-    return BatchNorm.apply(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        training,
-        momentum,
-        eps,
-        cudnn_enable,
-    )
+
+def batch_norm(
+    input: Tensor,
+    weight=None,
+    bias=None,
+    running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
+    running_var=None,
+    training=False,  # (self.running_mean is None) and (self.running_var is None)
+    momentum=0.1,
+    eps=1e-05,
+):
+    logging.debug("GEMS BATCHNORM FORWARD")
+
+    input_3d = make_3d_for_bn(input)
+
+    batch_dim, feat_dim, spatial_dim = input_3d.shape
+    output = torch.empty_like(input_3d)
+
+    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+
+    running_mean = input if running_mean is None else running_mean
+    running_var = input if running_var is None else running_var
+
+    # Launches 1D grid where each program operates over one feature.
+    with torch_device_fn.device(input.device):
+        batch_norm_forward_kernel[(feat_dim,)](
+            input_3d,
+            weight,
+            bias,
+            mean,
+            inv_std,
+            output,
+            running_mean,
+            running_var,
+            batch_dim,
+            spatial_dim,
+            *input_3d.stride(),
+            *output.stride(),
+            momentum,
+            eps,
+            is_train=training,
+        )
+
+    return output.view_as(input), mean, inv_std
+
+
+def batch_norm_backward(
+    grad_out,
+    input,
+    weight=None,
+    running_mean=None,
+    running_var=None,
+    save_mean=None,
+    save_invstd=None,
+    train=False,
+    eps=1e-05,
+    output_mask=None,
+):
+    logging.debug("GEMS BATCHNORM BACKWARD")
+    input_3d = make_3d_for_bn(input)
+    output_grad_3d = make_3d_for_bn(grad_out)
+
+    batch_dim, feat_dim, spatial_dim = input_3d.shape
+
+    if output_mask[0]:
+        input_grad = torch.empty_like(input_3d)
+    else:
+        input_grad = None
+    if output_mask[1]:
+        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
+    else:
+        weight_grad = None
+    if output_mask[2]:
+        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
+    else:
+        bias_grad = None
+
+    # Launches 1D grid where each program operates over one feature.
+    with torch_device_fn.device(input.device):
+        batch_norm_backward_kernel[(feat_dim,)](
+            output_grad_3d,
+            input_3d,
+            save_mean,
+            save_invstd,
+            weight,
+            input_grad,
+            weight_grad,
+            bias_grad,
+            batch_dim,
+            spatial_dim,
+            *output_grad_3d.stride(),
+            *input_3d.stride(),
+            *input_grad.stride(),
+            *output_mask,
+        )
+
+    # Pads output with None because a gradient is necessary for
+    # all input arguments.
+    return (
+        input_grad.view_as(input),
+        weight_grad,
+        bias_grad,
+    )
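Note: the new module-level entry points return (output, save_mean, save_invstd) and (input_grad, weight_grad, bias_grad), matching the ATen signatures. A quick self-check against PyTorch's eager reference (a sketch: assumes a CUDA device; tolerances are illustrative):

import torch
import torch.nn.functional as F
from flag_gems.ops import batch_norm, batch_norm_backward

x = torch.randn(4, 8, 16, device="cuda", requires_grad=True)
w = torch.randn(8, device="cuda", requires_grad=True)
b = torch.randn(8, device="cuda", requires_grad=True)

out, mean, invstd = batch_norm(x, w, b, training=True)
ref = F.batch_norm(x, None, None, w, b, training=True)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

dy = torch.randn_like(out)
ref.backward(dy)  # reference grads via eager autograd
dx, dw, db = batch_norm_backward(
    dy, x, w, save_mean=mean, save_invstd=invstd, train=True,
    output_mask=[True, True, True],
)
torch.testing.assert_close(dx, x.grad, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(dw, w.grad, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(db, b.grad, rtol=1e-4, atol=1e-4)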
3 changes: 0 additions & 3 deletions src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -966,9 +966,6 @@ batch_norm:
META: {}
num_warps: warps
warps:
- 1
- 2
- 4
- 8
- 16
- 32
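Note: in this YAML, warps lists the num_warps candidates the runtime expands into autotuning configs for the batch_norm kernels; this commit prunes three of the six candidates. How such a list is typically consumed, sketched in Python (not flag_gems' actual runtime code; values are illustrative):

import triton

def configs_from_warps(warps):
    # One autotuning candidate per warp count; META stays empty here.
    return [triton.Config({}, num_warps=w) for w in warps]

candidates = configs_from_warps([4, 8])  # illustrative values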