Potential native torch.compile alternative #41

Open
jxiong21029 opened this issue Dec 7, 2024 · 4 comments
@jxiong21029

Hello,

I was playing around with torch inductor, and I tried compiling a simple implementation of causal conv1d:

@torch.compile(fullgraph=True, dynamic=False)
def causal_conv1d_torch_unroll(x, w):
    # Accumulate K shifted copies of x, each scaled by one filter tap of the
    # depthwise causal conv, then apply the SiLU activation.
    out = 0
    for i in range(w.shape[1]):
        out = (
            out + F.pad(x, (w.shape[1] - i - 1, -w.shape[1] + i + 1)) * w[:, i : i + 1]
        )
    return F.silu(out)

On my hardware (an A40), it appears to produce the same outputs as causal_conv1d_fn at roughly the same speed.

This might be useful, since torch.compile currently graph-breaks on the causal_conv1d call in the default Mamba implementation (I'm not sure whether this is always the case or only an issue on my end, though).

Are there any disadvantages to this approach compared to the CUDA implementation here?

My testing code is below:

import time

import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn

torch.backends.cudnn.benchmark = True


@torch.compile(fullgraph=True, dynamic=False)
def causal_conv1d_torch(x, w):
    # Left-pad by K - 1 (= 3 here) so the depthwise conv is causal.
    x = F.pad(x, (3, 0))
    out = F.conv1d(x, w.unsqueeze(1), groups=w.size(0))
    return F.silu(out)


@torch.compile(fullgraph=True, dynamic=False)
def causal_conv1d_torch_unroll(x, w):
    out = 0
    for i in range(w.shape[1]):
        out = (
            out + F.pad(x, (w.shape[1] - i - 1, -w.shape[1] + i + 1)) * w[:, i : i + 1]
        )
    return F.silu(out)


def causal_conv1d_cuda(x, w):
    return causal_conv1d_fn(x, w, activation="silu")


B = 16
L = 16384
C = 1536
K = 4
dtype = torch.bfloat16

test_dtype = torch.float32
x = torch.randn(B, C, L, device="cuda", dtype=test_dtype)
w = torch.randn(C, K, device="cuda", dtype=test_dtype)
print(
    f"torch_conv1d allclose {torch.allclose(causal_conv1d_torch(x, w), causal_conv1d_cuda(x, w))}\n"
    f"torch_conv1d max_diff {(causal_conv1d_torch(x, w) - causal_conv1d_cuda(x, w)).abs().max()}\n"
    f"torch_unroll allclose {torch.allclose(causal_conv1d_torch_unroll(x, w), causal_conv1d_cuda(x, w))}\n"
    f"torch_unroll max_diff {(causal_conv1d_torch_unroll(x, w) - causal_conv1d_cuda(x, w)).abs().max()}"
)

for i in range(200):
    if i == 20:
        start = time.time()
    x = torch.randn(B, C, L, device="cuda", dtype=dtype)
    w = torch.randn(C, K, device="cuda", dtype=dtype)
    causal_conv1d_torch(x, w)
    torch.cuda.synchronize()
torch_compiled_conv1d_sec = time.time() - start
print(
    f"torch_compiled_conv1d: {torch_compiled_conv1d_sec * 1e6 / 180:.0f} microsec/iter"
)


for i in range(200):
    if i == 20:
        start = time.time()
    x = torch.randn(B, C, L, device="cuda", dtype=dtype)
    w = torch.randn(C, K, device="cuda", dtype=dtype)
    causal_conv1d_torch_unroll(x, w)
    torch.cuda.synchronize()
torch_compiled_unroll_sec = time.time() - start
print(
    f"torch_compiled_unroll: {torch_compiled_unroll_sec * 1e6 / 180:.0f} microsec/iter"
)


for i in range(200):
    if i == 20:
        start = time.time()
    x = torch.randn(B, C, L, device="cuda", dtype=dtype)
    w = torch.randn(C, K, device="cuda", dtype=dtype)
    causal_conv1d_cuda(x, w)
    torch.cuda.synchronize()

cuda_causal_conv1d_sec = time.time() - start
print(f"cuda: {cuda_causal_conv1d_sec * 1e6 / 180:.0f} microsec/iter")

print(
    f"cuda vs conv1d: {torch_compiled_conv1d_sec / cuda_causal_conv1d_sec:.2%} throughput\n"
    f"cuda vs unroll: {torch_compiled_unroll_sec / cuda_causal_conv1d_sec:.2%} throughput"
)

Outputs:

torch_conv1d allclose True
torch_conv1d max_diff 0.0
torch_unroll allclose True
torch_unroll max_diff 0.0
torch_compiled_conv1d: 16875 microsec/iter
torch_compiled_unroll: 4710 microsec/iter
cuda: 4561 microsec/iter
cuda vs conv1d: 369.98% throughput
cuda vs unroll: 103.26% throughput
@tridao
Contributor

tridao commented Dec 7, 2024

Thanks! Can you try benchmarking the backward pass?
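
For reference, a rough sketch of how the backward pass could be timed, reusing the shapes and functions from the script above (bench_backward is just an illustrative helper, not part of the repo):

import time

import torch


def bench_backward(fn, x, w, iters=200, warmup=20):
    # Times forward + backward of `fn`; input creation stays outside the timed region.
    x = x.detach().clone().requires_grad_(True)
    w = w.detach().clone().requires_grad_(True)
    for i in range(iters):
        if i == warmup:
            torch.cuda.synchronize()
            start = time.time()
        out = fn(x, w)
        out.sum().backward()  # simple scalar loss so backward has something to run
        x.grad = None
        w.grad = None
    torch.cuda.synchronize()
    return (time.time() - start) / (iters - warmup)

e.g. bench_backward(causal_conv1d_torch_unroll, x, w) vs bench_backward(causal_conv1d_cuda, x, w).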

@tridao
Contributor

tridao commented Dec 7, 2024

Also, right now the timing also includes the torch.randn calls. Can you try measuring only the conv1d and not the torch.randn?
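
For reference, a sketch of one way to do that, using CUDA events and creating the inputs once outside the timed loop (reusing B, C, L, K, dtype and the compiled function from the script above):

import torch

x = torch.randn(B, C, L, device="cuda", dtype=dtype)
w = torch.randn(C, K, device="cuda", dtype=dtype)

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

for _ in range(20):  # warmup (includes compilation)
    causal_conv1d_torch_unroll(x, w)
torch.cuda.synchronize()

start_evt.record()
for _ in range(180):
    causal_conv1d_torch_unroll(x, w)
end_evt.record()
torch.cuda.synchronize()
print(f"{start_evt.elapsed_time(end_evt) * 1e3 / 180:.0f} microsec/iter")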

@jxiong21029
Author

Thanks for the feedback!

When I corrected the benchmarking code, the forward pass was still the same speed as the CUDA kernel, but the backward pass was indeed 2-3x slower than the CUDA kernel, closer to the speed of torch.nn.Conv1d.

However, if I implement the backward pass in the same style as the forward:

class CausalConv1dUnrollFn(torch.autograd.Function):
    @staticmethod
    def forward(x, w):
        # Same unrolled depthwise causal conv as above (without the activation).
        out = 0
        for i in range(w.shape[1]):
            out += F.pad(x, (w.shape[1] - i - 1, -w.shape[1] + i + 1)) * w[:, i : i + 1]
        return out

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, w = inputs
        ctx.save_for_backward(x, w)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors

        k = w.shape[1]
        # grad_weight[:, j] = sum over batch and time of grad_output[t] * x[t - (k - 1 - j)]
        grad_weight = torch.empty_like(w)
        for i in range(k):
            grad_weight[:, k - i - 1] = (F.pad(grad_output, (-i, i)) * x).sum(
                dim=(0, 2)
            )

        # grad_x is the transposed (anti-causal) conv: shift grad_output left by k - i - 1.
        grad_x = 0
        for i in range(k):
            grad_x += F.pad(grad_output, (i + 1 - k, k - i - 1)) * w[:, i : i + 1]

        return grad_x, grad_weight

and if the entire forward + backward pass is compiled with the above implementation, i.e.

def torch_unroll_fwd(x, w):
    return F.silu(CausalConv1dUnrollFn.apply(x, w))


@torch.compile()
def torch_unroll_bwd(x, w):
    out = torch_unroll_fwd(x, w)
    loss_fn(out).backward()  # loss_fn: scalar loss used in the benchmark (not shown here)

Then the overall throughput is:

  • much faster than nn.Conv1d, even with compilation
  • competitive with the CUDA kernel when the rest of the forward + backward pass runs in eager mode
  • but still ~20% slower than the CUDA kernel when the surrounding forward + backward pass is torch.compiled (the CUDA call does cause a graph break, but is still faster).

That said, I haven't used torch.autograd.Function before, so this implementation may be suboptimal.
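
For completeness, a minimal sketch of how the snippet above could be exercised end to end; loss_fn isn't shown in my snippet, so a plain sum is used as a stand-in here:

import torch


def loss_fn(out):
    # Stand-in scalar loss; the benchmark's actual loss_fn is not shown above.
    return out.float().sum()


x = torch.randn(B, C, L, device="cuda", dtype=dtype, requires_grad=True)
w = torch.randn(C, K, device="cuda", dtype=dtype, requires_grad=True)
torch_unroll_bwd(x, w)
print(x.grad.shape, w.grad.shape)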

@tridao
Contributor

tridao commented Dec 9, 2024

One option is just to decorate the causal_conv1d functions so that they work with torch.compile.
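
For example, something along these lines with torch.library.custom_op (PyTorch >= 2.4) might work; the op name is illustrative and only the forward is shown here, so a backward would still need to be registered (e.g. via torch.library.register_autograd) for training:

import torch
from causal_conv1d import causal_conv1d_fn


# Illustrative registration, not part of the package: wrapping the CUDA kernel as a
# custom op lets torch.compile keep it in the graph instead of graph-breaking on it.
@torch.library.custom_op("causal_conv1d::fwd", mutates_args=())
def causal_conv1d_op(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return causal_conv1d_fn(x, weight, activation="silu")


@causal_conv1d_op.register_fake
def _(x, weight):
    # Shape/dtype propagation so torch.compile can trace without running the kernel.
    return torch.empty_like(x)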
