Skip to content

Commit

Permalink
test: vmap
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Sep 11, 2024
1 parent 611a596 commit 440b759
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F
from torch.func import jacfwd
import pytest
from torchcomp.core import compressor_core


from .test_grad import create_test_inputs


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
def test_vmap(device: str):
batch_size = 4
samples = 128
x, zi, at, rt = tuple(x.to(device) for x in create_test_inputs(batch_size, samples))
y = torch.randn_like(x)

x.requires_grad = True
zi.requires_grad = True
at.requires_grad = True
rt.requires_grad = True

args = (x, zi, at, rt)

def func(x, zi, at, rt):
return F.mse_loss(compressor_core(x, zi, at, rt), y)

jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

loss = func(*args)
loss.backward()
for jac, arg in zip(jacs, args):
assert torch.allclose(jac, arg.grad)

0 comments on commit 440b759

Please sign in to comment.