Potential native torch.compile alternative #41
Thanks! Can you try benchmarking the backward pass?

Also, currently when you're measuring the time you're also counting the …
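A minimal timing harness for this kind of comparison (a sketch, not the thread's actual benchmark code; the `benchmark` helper and its defaults are made up here) would run warm-up iterations first, so torch.compile's compilation falls outside the timed region, and synchronize the GPU before reading the clock, since kernels launch asynchronously:

```python
import time
import torch

def benchmark(fn, *args, warmup=10, iters=100):
    # Warm-up runs trigger compilation and autotuning outside the timed region.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    # Kernels launch asynchronously; synchronize before reading the clock.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```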
Thanks for the feedback! When I corrected the benchmarking code, the forward pass was still as fast as the CUDA kernel, but the backward pass was indeed 2-3x slower than the CUDA kernel, closer to the speed of torch.nn.Conv1d. However, if I implement the backward pass in the same unrolled style as the forward pass:

```python
import torch
import torch.nn.functional as F

class CausalConv1dUnrollFn(torch.autograd.Function):
    @staticmethod
    def forward(x, w):
        # Unrolled causal conv: shift x right by (k - 1 - i) and scale by tap w[:, i].
        out = 0
        for i in range(w.shape[1]):
            out += F.pad(x, (w.shape[1] - i - 1, -w.shape[1] + i + 1)) * w[:, i : i + 1]
        return out

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, w = inputs
        ctx.save_for_backward(x, w)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        k = w.shape[1]
        grad_weight = torch.empty_like(w)
        for i in range(k):
            # dL/dw[:, k-1-i]: correlate grad_output (shifted left by i) with x.
            grad_weight[:, k - i - 1] = (F.pad(grad_output, (-i, i)) * x).sum(
                dim=(0, 2)
            )
        grad_x = 0
        for i in range(k):
            # Adjoint of the forward shift: grad_output shifted left by (k - 1 - i).
            grad_x += F.pad(grad_output, (i - k + 1, k - i - 1)) * w[:, i : i + 1]
        return grad_x, grad_weight
```

and the entire forward + backward pass is compiled with the above implementation, i.e.

```python
def torch_unroll_fwd(x, w):
    return F.silu(CausalConv1dUnrollFn.apply(x, w))

@torch.compile()
def torch_unroll_bwd(x, w):
    out = torch_unroll_fwd(x, w)
    # loss_fn is defined in the benchmark harness (not shown here).
    loss_fn(out).backward()
```

then the overall throughput is:
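One way to sanity-check a hand-written backward like the one above is torch.autograd.gradcheck, which compares the analytic gradients against finite-difference estimates. A sketch with small, arbitrary shapes; gradcheck wants double precision for its numerical tolerances:

```python
import torch

# Verify CausalConv1dUnrollFn's hand-written backward against numerical gradients.
x = torch.randn(2, 3, 16, dtype=torch.double, requires_grad=True)  # (batch, dim, seqlen)
w = torch.randn(3, 4, dtype=torch.double, requires_grad=True)      # (dim, width)
torch.autograd.gradcheck(CausalConv1dUnrollFn.apply, (x, w))
```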
Though I haven't used …
One option is just to decorate the causal_conv1d functions so that they work with torch.compile.
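For example, with PyTorch 2.4+ one could register the kernel as a custom op, so torch.compile treats it as an opaque call instead of graph-breaking. This is a sketch, not code from this repo: the op name is arbitrary, and the (x, weight, bias, activation) call signature of causal_conv1d_fn is assumed:

```python
import torch
from causal_conv1d import causal_conv1d_fn

@torch.library.custom_op("mylib::causal_conv1d", mutates_args=())
def causal_conv1d_op(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # torch.compile does not trace into this body, so the CUDA kernel is used as-is.
    return causal_conv1d_fn(x, weight, bias, activation="silu")

@causal_conv1d_op.register_fake
def _(x, weight, bias):
    # Shape/dtype propagation for the compiler: output matches the input activation.
    return torch.empty_like(x)
```

Training through such a wrapper would additionally need torch.library.register_autograd wired up to the kernel's backward.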
Hello,
I was playing around with torch inductor, and I tried compiling a simple implementation of causal conv1d:
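(The original snippet isn't reproduced in this copy; judging from the unrolled forward pass quoted in the comments above, it was presumably along these lines. A sketch, assuming the (batch, dim, seqlen) layout and fused SiLU of causal_conv1d_fn:)

```python
import torch
import torch.nn.functional as F

@torch.compile
def torch_unroll(x, w):
    # x: (batch, dim, seqlen); w: (dim, width). Build the causal conv as a sum of
    # right-shifted copies of x, each scaled by one filter tap, then apply SiLU.
    out = 0
    for i in range(w.shape[1]):
        out += F.pad(x, (w.shape[1] - i - 1, -w.shape[1] + i + 1)) * w[:, i : i + 1]
    return F.silu(out)
```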
On my hardware (an A40) it looks like it matches causal_conv1d_fn in both outputs and throughput.
This might be useful, since currently torch.compile with the default Mamba implementation graph-breaks on the causal_conv1d call (I'm not sure if this is always true or only an issue on my end, though).
Are there any disadvantages to this approach, compared to using the CUDA implementation here?
My testing code is below:
Outputs: