-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
v0.2: drop support for torch<2 (#6)
* feat: torch2 autograd scheme
* drop support for torch<2
* feat: use torchlpc for fast avg, drop torchaudio dependency
* feat: vmap
* test: vmap
- Loading branch information
1 parent
d044041
commit 1525115
Showing
5 changed files
with
97 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
numpy | ||
numba | ||
torch | ||
torchaudio | ||
torch>=2 | ||
torchlpc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from torch.func import jacfwd | ||
import pytest | ||
from torchcomp.core import compressor_core | ||
|
||
|
||
from .test_grad import create_test_inputs | ||
|
||
|
||
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA not available"
            ),
        ),
    ],
)
def test_vmap(device: str):
    """Cross-check forward-mode jacobians against reverse-mode gradients.

    ``jacfwd`` is built on ``vmap``, so this exercises the vmap path of
    ``compressor_core`` and asserts its gradients agree with ``backward()``.
    """
    batch_size, samples = 4, 128

    # Generate the shared test inputs and move them to the target device.
    inputs = tuple(t.to(device) for t in create_test_inputs(batch_size, samples))
    x, zi, at, rt = inputs
    # Random regression target for the MSE loss (no grad tracking needed).
    target = torch.randn_like(x)

    # Enable autograd on every input so backward() populates .grad.
    for t in inputs:
        t.requires_grad = True

    def loss_fn(x, zi, at, rt):
        return F.mse_loss(compressor_core(x, zi, at, rt), target)

    # Forward-mode: jacobian of the scalar loss w.r.t. every argument at once.
    fwd_grads = jacfwd(loss_fn, argnums=tuple(range(len(inputs))))(*inputs)

    # Reverse-mode: ordinary autograd through backward().
    loss_fn(*inputs).backward()

    # Both differentiation modes must agree per argument.
    for fwd, t in zip(fwd_grads, inputs):
        assert torch.allclose(fwd, t.grad)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters