Support grad_to_none, properly benchmark layernorm backward
Summary:
We don't want to benchmark gradient accumulation, which is why Triton's
`do_bench` provides the `grad_to_none` argument: it resets the grads of the
listed tensors to `None` outside the timed region.
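
For context, a rough sketch of how `triton.testing.do_bench` applies
`grad_to_none` in its measurement loop (simplified and paraphrased, not the
exact Triton source):

    # Tensors passed via grad_to_none get their .grad cleared before the timed
    # region, so a backward pass inside fn() never accumulates into stale grads.
    for i in range(n_repeat):
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None      # cleared outside the timing events
        start_event[i].record()
        fn()                       # only this call is timed
        end_event[i].record()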

This diff plumbs that argument through the TritonBench harness (`triton_op.py`) and uses it in the layer_norm operator.

I also fixed an error in which the backward pass for the x input included some
extra compute: x is built as `-2.3 + 0.5 * randn`, but `requires_grad` should
not be set on the `randn` itself, or the `mul` and `add` get differentiated as
well even though they are just input preparation. Instead, `requires_grad_()`
is now called on the finished input tensor.
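
A minimal before/after sketch of the input-preparation fix (shapes and device
chosen for illustration; the real values come from `get_input_iter` below):

    import torch

    M, N = 4096, 1024

    # Before: requires_grad=True on the randn puts the `0.5 *` and `-2.3 +`
    # ops into the autograd graph, so every backward pass also differentiates
    # the input-preparation arithmetic.
    x_before = -2.3 + 0.5 * torch.randn((M, N), device="cuda", requires_grad=True)

    # After: build the input first, then mark the finished tensor as requiring
    # grad, keeping input prep out of the benchmarked backward pass.
    x_after = -2.3 + 0.5 * torch.randn((M, N), device="cuda")
    x_after.requires_grad_()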

Reviewed By: chenyang78, sijiac

Differential Revision: D55882037

fbshipit-source-id: 73619731033f2f49fdf7872e47d0cba6f58329b3
bertmaher authored and facebook-github-bot committed Apr 8, 2024
1 parent 5813012 commit d9a9600
Showing 2 changed files with 18 additions and 3 deletions.
torchbenchmark/operators/layer_norm/__init__.py (9 additions, 2 deletions)
@@ -9,7 +9,7 @@
     register_metric,
 )
 
-from typing import Callable
+from typing import Callable, List
 
 from . import tutorial
 
@@ -28,15 +28,22 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         dy = 0.1 * torch.randn_like(y)
         return lambda: y.backward(dy, retain_graph=True)
 
+    def get_grad_to_none(self, args) -> List[torch.Tensor]:
+        x = args[0]
+        return [x]
+
     def get_input_iter(self):
         M = 4096
         eps = 1e-5
         for N in [512 * i for i in range(2, 32)]:
             x_shape = (M, N)
             w_shape = (x_shape[-1],)
             x = -2.3 + 0.5 * torch.randn(
-                x_shape, dtype=self.dtype, device="cuda", requires_grad=True
+                x_shape,
+                dtype=self.dtype,
+                device="cuda",
             )
+            x.requires_grad_()
             weight = torch.rand(
                 w_shape, dtype=self.dtype, device="cuda", requires_grad=True
             )
torchbenchmark/util/triton_op.py (9 additions, 1 deletion)
@@ -292,6 +292,8 @@ def get_input_iter(self) -> Generator:
         """Return the dynamic input iterator for the model."""
         raise NotImplementedError("Each operator must implement its own input iterator.")
 
+    def get_grad_to_none(self, args):
+        return None
 
     def plot(self):
         """Plot the comparison between different operator implementations."""
@@ -408,7 +410,13 @@ def _do_bench(self,
         if baseline:
             self.baseline_fn = fn
         if set(["latency", "tflops", "speedup"]) & set(self.required_metrics):
-            latency = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
+            latency = triton.testing.do_bench(
+                fn,
+                warmup=warmup,
+                rep=rep,
+                quantiles=quantiles,
+                grad_to_none=self.get_grad_to_none(self.example_inputs),
+            )
             if "speedup" in self.required_metrics:
                 speedup = numpy.median(self.baseline_metrics.latency) / numpy.median(latency) \
                     if self.baseline_metrics and self.baseline_metrics.latency else None
