v0.2: drop support for torch<2 (#6)
* feat: torch2 autograd scheme

* drop support for torch<2

* feat: use torchlpc for fast avg, drop torchaudio dependency

* feat: vmap

* test: vmap
yoyolicoris authored Sep 11, 2024
1 parent d044041 commit 1525115
Showing 5 changed files with 97 additions and 44 deletions.
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,5 +1,4 @@
 numpy
 numba
-torch
-torchaudio
+torch>=2
 torchlpc
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@
     long_description_content_type="text/markdown",
     url="https://github.com/yoyololicon/torchcomp",
     packages=["torchcomp"],
-    install_requires=["torch", "torchaudio", "torchlpc", "numpy", "numba"],
+    install_requires=["torch>=2", "torchlpc", "numpy", "numba"],
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
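Both requirements.txt and setup.py above now express the same constraint, torch>=2. A minimal sketch of a runtime guard a downstream script could add to fail early on older installs; the check and message are illustrative, not part of torchcomp:

import torch

# Hypothetical version check mirroring the install_requires constraint above.
if int(torch.__version__.split(".")[0]) < 2:
    raise RuntimeError(
        f"torchcomp v0.2 expects torch>=2, found torch {torch.__version__}"
    )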
44 changes: 44 additions & 0 deletions tests/test_vmap.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+from torch.func import jacfwd
+import pytest
+from torchcomp.core import compressor_core
+
+
+from .test_grad import create_test_inputs
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_vmap(device: str):
+    batch_size = 4
+    samples = 128
+    x, zi, at, rt = tuple(x.to(device) for x in create_test_inputs(batch_size, samples))
+    y = torch.randn_like(x)
+
+    x.requires_grad = True
+    zi.requires_grad = True
+    at.requires_grad = True
+    rt.requires_grad = True
+
+    args = (x, zi, at, rt)
+
+    def func(x, zi, at, rt):
+        return F.mse_loss(compressor_core(x, zi, at, rt), y)
+
+    jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)
+
+    loss = func(*args)
+    loss.backward()
+    for jac, arg in zip(jacs, args):
+        assert torch.allclose(jac, arg.grad)
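The new test cross-checks forward-mode Jacobians from jacfwd against backward-mode gradients of the same loss. Because core.py (further down) also registers a vmap rule on the Function, torch.func.vmap can map compressor_core over an extra leading dimension. A minimal sketch with made-up shapes; the coefficient ranges are assumptions chosen to be valid smoothing coefficients in (0, 1):

import torch
from torch.func import vmap
from torchcomp.core import compressor_core

outer, batch, samples = 3, 4, 128
x = torch.randn(outer, batch, samples)      # input signal per (outer, batch)
zi = torch.rand(outer, batch)               # initial state per batch element
at = torch.rand(outer, batch) * 0.8 + 0.1   # attack coefficients in (0, 1)
rt = torch.rand(outer, batch) * 0.8 + 0.1   # release coefficients in (0, 1)

# vmap maps over dim 0 of every argument; each inner call sees (batch, samples).
y = vmap(compressor_core)(x, zi, at, rt)
assert y.shape == (outer, batch, samples)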
10 changes: 4 additions & 6 deletions torchcomp/__init__.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from typing import Union
-from torchaudio.functional import lfilter
+from torchlpc import sample_wise_lpc
 
 from .core import compressor_core
 
@@ -88,11 +88,9 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
     ).broadcast_to(rms.shape[0])
     assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)
 
-    return lfilter(
-        rms,
-        torch.stack([torch.ones_like(avg_coef), avg_coef - 1], 1),
-        torch.stack([avg_coef, torch.zeros_like(avg_coef)], 1),
-        False,
+    return sample_wise_lpc(
+        rms * avg_coef,
+        avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
     )
 
 
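The removed lfilter call and the new sample_wise_lpc call realize the same one-pole smoother, y[n] = avg_coef * x[n] + (1 - avg_coef) * y[n - 1]. A small sketch (shapes and names here are illustrative) checking that all-pole formulation against a plain Python recursion:

import torch
from torchlpc import sample_wise_lpc

B, T = 2, 16
alpha = torch.rand(B) * 0.9 + 0.05  # per-batch smoothing coefficient in (0, 1)
rms = torch.rand(B, T)

# All-pole form: drive alpha * x through a first-order filter whose denominator
# coefficient is (alpha - 1), i.e. y[n] = alpha * x[n] + (1 - alpha) * y[n - 1].
y = sample_wise_lpc(
    rms * alpha[:, None],
    (alpha[:, None, None] - 1).broadcast_to((B, T, 1)),
)

# Reference recursion with zero initial state.
ref = torch.empty_like(rms)
state = torch.zeros(B)
for n in range(T):
    state = alpha * rms[:, n] + (1 - alpha) * state
    ref[:, n] = state

assert torch.allclose(y, ref, atol=1e-5)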
82 changes: 47 additions & 35 deletions torchcomp/core.py
@@ -18,8 +18,8 @@ def compressor_cuda_kernel(
     B: int,
     T: int,
 ):
-    b = cuda.blockIdx.x
-    i = cuda.threadIdx.x
+    b: int = cuda.blockIdx.x
+    i: int = cuda.threadIdx.x
 
     if b >= B or i > 0:
         return
@@ -93,8 +93,8 @@ def compressor_cuda(
 class CompressorFunction(Function):
     @staticmethod
     def forward(
-        ctx: Any, x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
-    ) -> torch.Tensor:
+        x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if x.is_cuda:
             y, at_mask = compressor_cuda(
                 x.detach(), zi.detach(), at.detach(), rt.detach()
@@ -108,19 +108,21 @@ def forward(
             )
             y = torch.from_numpy(y).to(x.device)
             at_mask = torch.from_numpy(at_mask).to(x.device)
-        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        return y, at_mask
 
-        # for jvp
-        ctx.x = x
-        ctx.y = y
-        ctx.zi = zi
-        ctx.at = at
-        ctx.rt = rt
-        ctx.at_mask = at_mask
-        return y
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
+        x, zi, at, rt = inputs
+        y, at_mask = output
+        ctx.mark_non_differentiable(at_mask)
+        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        ctx.save_for_forward(x, y, zi, at, rt, at_mask)
+        return ctx
 
     @staticmethod
-    def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(
+        ctx: Any, grad_y: torch.Tensor, _
+    ) -> Tuple[Optional[torch.Tensor], ...]:
         x, y, zi, at, rt, at_mask = ctx.saved_tensors
         grad_x = grad_zi = grad_at = grad_rt = None
 
@@ -153,19 +155,6 @@ def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         if ctx.needs_input_grad[3]:
             grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1)
 
-        if hasattr(ctx, "y"):
-            del ctx.y
-        if hasattr(ctx, "x"):
-            del ctx.x
-        if hasattr(ctx, "zi"):
-            del ctx.zi
-        if hasattr(ctx, "at"):
-            del ctx.at
-        if hasattr(ctx, "rt"):
-            del ctx.rt
-        if hasattr(ctx, "at_mask"):
-            del ctx.at_mask
-
         return grad_x, grad_zi, grad_at, grad_rt
 
     @staticmethod
@@ -175,12 +164,13 @@ def jvp(
         grad_zi: torch.Tensor,
         grad_at: torch.Tensor,
         grad_rt: torch.Tensor,
-    ) -> torch.Tensor:
-        x, y, zi, at, rt, at_mask = ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
+    ) -> Tuple[torch.Tensor, None]:
+        x, y, zi, at, rt, at_mask = ctx.saved_tensors
         coeffs = torch.where(at_mask, at.unsqueeze(1), rt.unsqueeze(1))
 
         fwd_x = 0 if grad_x is None else grad_x * coeffs
 
+        fwd_combined: torch.Tensor
         if grad_at is None and grad_rt is None:
             fwd_combined = fwd_x
         else:
@@ -192,13 +182,35 @@ def jvp(
             fwd_combined = fwd_x + grad_beta * (
                 x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
             )
+        return (
+            sample_wise_lpc(
+                fwd_combined,
+                coeffs.unsqueeze(2) - 1,
+                grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+            ),
+            None,
+        )
 
-        del ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
-        return sample_wise_lpc(
-            fwd_combined,
-            coeffs.unsqueeze(2) - 1,
-            grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+    @staticmethod
+    def vmap(info, in_dims, *args):
+        def maybe_expand_bdim_at_front(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        x, zi, at, rt = tuple(
+            map(
+                lambda x: x.reshape(-1, *x.shape[2:]),
+                map(maybe_expand_bdim_at_front, args, in_dims),
+            )
         )
+
+        y, at_mask = CompressorFunction.apply(x, zi, at, rt)
+        return (
+            y.reshape(info.batch_size, -1, *y.shape[1:]),
+            at_mask.reshape(info.batch_size, -1, *at_mask.shape[1:]),
+        ), 0
 
 
-compressor_core: Callable = CompressorFunction.apply
+def compressor_core(*args, **kwargs) -> torch.Tensor:
+    return CompressorFunction.apply(*args, **kwargs)[0]
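The forward/setup_context split above is the torch>=2 custom autograd interface the commit message calls the "torch2 autograd scheme": forward no longer receives ctx, and state for backward/jvp is stashed in a separate setup_context step. A minimal, self-contained sketch of that pattern on a toy scaling op, not torchcomp's operator:

import torch
from torch.autograd import Function


class Scale(Function):
    # torch>=2 style: forward takes only the op's inputs, no ctx argument.
    @staticmethod
    def forward(x: torch.Tensor, k: float) -> torch.Tensor:
        return x * k

    # setup_context receives (inputs, output) and stashes whatever backward needs.
    @staticmethod
    def setup_context(ctx, inputs, output):
        _, k = inputs
        ctx.k = k

    @staticmethod
    def backward(ctx, grad_out):
        # d(x * k)/dx = k; the plain-float k gets no gradient.
        return grad_out * ctx.k, None


x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))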
