[Operator] register log_softmax backward
StrongSpoon committed Feb 6, 2025
1 parent 321b3b2 commit e2be412
Showing 4 changed files with 79 additions and 73 deletions.
3 changes: 2 additions & 1 deletion src/flag_gems/__init__.py
@@ -172,7 +172,8 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
     ("any.dim", any_dim, Autograd.disable),
     ("any.dims", any_dims, Autograd.disable),
     ("quantile", quantile, Autograd.disable),
-    ("log_softmax.int", log_softmax, Autograd.enable),
+    ("_log_softmax", log_softmax, Autograd.disable),
+    ("_log_softmax_backward_data", log_softmax_backward, Autograd.disable),
     ("outer", outer, Autograd.enable),
     ("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
     ("nll_loss_forward", nll_loss_forward, Autograd.disable),
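
With both _log_softmax and _log_softmax_backward_data registered as Autograd.disable, the custom autograd wrapper is no longer needed: ATen's derivative formula for _log_softmax already routes the backward pass through _log_softmax_backward_data, so both calls now dispatch to the FlagGems Triton kernels. A minimal sketch of exercising the pair end to end (the tensor shape and dim value are illustrative, not part of this commit):

import torch
import flag_gems

x = torch.randn(4, 8, device=flag_gems.device, requires_grad=True)  # illustrative shape

with flag_gems.use_gems():
    # forward dispatches to the registered _log_softmax kernel
    y = torch.nn.functional.log_softmax(x, dim=-1)
    # backward dispatches to the registered _log_softmax_backward_data kernel
    y.sum().backward()

print(x.grad.shape)  # torch.Size([4, 8])
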
3 changes: 2 additions & 1 deletion src/flag_gems/ops/__init__.py
@@ -57,7 +57,7 @@
 from .isnan import isnan
 from .layernorm import layer_norm, layer_norm_backward
 from .le import le, le_scalar
-from .log_softmax import log_softmax
+from .log_softmax import log_softmax, log_softmax_backward
 from .logical_and import logical_and
 from .logical_not import logical_not
 from .logical_or import logical_or
@@ -278,6 +278,7 @@
     "var_mean",
     "vector_norm",
     "log_softmax",
+    "log_softmax_backward",
     "outer",
     "cross_entropy_loss",
     "where_self_out",
123 changes: 57 additions & 66 deletions src/flag_gems/ops/log_softmax.py
@@ -93,73 +93,64 @@ def log_softmax_backward_kernel(
     tl.store(in_grad_ptrs, in_grad, mask=mask)


-class LogSoftmax(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, dim, dtype):
-        logging.debug("GEMS LOG_SOFTMAX")
-
-        assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
-        dim = dim % x.ndim
-        M = 1
-        N = x.shape[dim]
-        for i in range(dim):
-            M *= x.shape[i]
-        inp = x.contiguous()
-        if dtype is None:
-            dtype = x.dtype
-        out = torch.empty_like(inp, dtype=dtype)
-        K = inp.numel() // M // N
-
-        grid = lambda meta: (
-            triton.cdiv(M, meta["BLOCK_M"]),
-            K,
-        )
-        with torch_device_fn.device(inp.device):
-            log_softmax_kernel[grid](
-                out,
-                inp,
-                M,
-                N,
-                K,
-                num_warps=8,
-            )
-        ctx.save_for_backward(out)
-        ctx.dim = dim
-        return out
-
-    @staticmethod
-    def backward(ctx, out_grad):
-        logging.debug("GEMS LOG_SOFTMAX VJP")
-
-        dim = ctx.dim
-        (out,) = ctx.saved_tensors
-
-        assert dim >= -out.ndim and dim < out.ndim, "Invalid dim"
-        dim = dim % out.ndim
-        M = 1
-        N = out.shape[dim]
-        for i in range(dim):
-            M *= out.shape[i]
-
-        out_grad = out_grad.contiguous()
-        in_grad = torch.empty_like(out)
-        K = out.numel() // M // N
-
-        grid = lambda meta: (
-            triton.cdiv(M, meta["BLOCK_M"]),
-            K,
-        )
-        with torch_device_fn.device(in_grad.device):
-            log_softmax_backward_kernel[grid](
-                out,
-                out_grad,
-                in_grad,
-                M,
-                N,
-                K,
-            )
-        return in_grad, None, None
-
-
-def log_softmax(x, dim=-1, dtype=None):
-    return LogSoftmax.apply(x, dim, dtype)
+def log_softmax(self, dim, half_to_float=False):
+    logging.debug("GEMS LOG_SOFTMAX")
+
+    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
+    dim = dim % self.ndim
+    M = 1
+    N = self.shape[dim]
+    for i in range(dim):
+        M *= self.shape[i]
+    inp = self.contiguous()
+    if half_to_float:
+        dtype = torch.float32
+    else:
+        dtype = self.dtype
+    out = torch.empty_like(inp, dtype=dtype)
+    K = inp.numel() // M // N
+
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_M"]),
+        K,
+    )
+    with torch_device_fn.device(inp.device):
+        log_softmax_kernel[grid](
+            out,
+            inp,
+            M,
+            N,
+            K,
+            num_warps=8,
+        )
+    return out
+
+
+def log_softmax_backward(grad_output, output, dim, input_dtype):
+    logging.debug("GEMS LOG_SOFTMAX VJP")
+
+    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
+    dim = dim % output.ndim
+    M = 1
+    N = output.shape[dim]
+    for i in range(dim):
+        M *= output.shape[i]
+
+    grad_output = grad_output.contiguous()
+    in_grad = torch.empty_like(output, dtype=input_dtype)
+    K = output.numel() // M // N
+
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_M"]),
+        K,
+    )
+    with torch_device_fn.device(in_grad.device):
+        log_softmax_backward_kernel[grid](
+            output,
+            grad_output,
+            in_grad,
+            M,
+            N,
+            K,
+        )
+    return in_grad
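
For reference, _log_softmax_backward_data computes the standard log-softmax gradient: for output = log_softmax(input, dim), grad_input = grad_output - exp(output) * grad_output.sum(dim, keepdim=True). A plain-PyTorch sketch of that identity, checked against autograd (the helper name and shapes are illustrative, not part of this commit):

import torch

def log_softmax_backward_reference(grad_output, output, dim):
    # grad_input = grad_output - softmax(input) * sum(grad_output) along dim,
    # where softmax(input) == exp(log_softmax(input)) == exp(output)
    return grad_output - torch.exp(output) * grad_output.sum(dim=dim, keepdim=True)

x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
y = torch.nn.functional.log_softmax(x, dim=1)
g = torch.randn_like(y)

(expected,) = torch.autograd.grad(y, x, g)
torch.testing.assert_close(log_softmax_backward_reference(g, y.detach(), dim=1), expected)
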
23 changes: 18 additions & 5 deletions tests/test_reduction_ops.py
@@ -339,19 +339,32 @@ def test_accuracy_count_nonzero(shape, dtype):
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_accuracy_log_softmax(shape, dtype):
     dim = 1
-    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
     ref_inp = to_reference(inp, True)

     ref_out = torch.nn.functional.log_softmax(ref_inp, dim=dim)
     with flag_gems.use_gems():
         res_out = torch.nn.functional.log_softmax(inp, dim=dim)
     gems_assert_close(res_out, ref_out, dtype)

-    out_grad = torch.randn_like(res_out)
-    ref_grad = to_reference(out_grad, True)
-
-    (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
-    (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
+
+@pytest.mark.log_softmax
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_log_softmax_backward(shape, dtype):
+    res_grad = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    res_out = torch.randn_like(res_grad)
+    ref_grad = to_reference(res_grad, True)
+    ref_out = to_reference(res_out, True)
+    dim = 1
+
+    ref_in_grad = torch.ops.aten._log_softmax_backward_data(
+        ref_grad, ref_out, dim, ref_grad.dtype
+    )
+    with flag_gems.use_gems():
+        res_in_grad = torch.ops.aten._log_softmax_backward_data(
+            res_grad, res_out, dim, dtype
+        )
     gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim])

