diff --git a/requirements.txt b/requirements.txt
index b53af3f..bcea830 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
 numpy
 numba
-torch
-torchaudio
+torch>=2
 torchlpc
diff --git a/setup.py b/setup.py
index fa2602d..e2e3dca 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
     long_description_content_type="text/markdown",
     url="https://github.com/yoyololicon/torchcomp",
     packages=["torchcomp"],
-    install_requires=["torch", "torchaudio", "torchlpc", "numpy", "numba"],
+    install_requires=["torch>=2", "torchlpc", "numpy", "numba"],
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
diff --git a/tests/test_vmap.py b/tests/test_vmap.py
new file mode 100644
index 0000000..951f056
--- /dev/null
+++ b/tests/test_vmap.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+from torch.func import jacfwd
+import pytest
+from torchcomp.core import compressor_core
+
+
+from .test_grad import create_test_inputs
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_vmap(device: str):
+    batch_size = 4
+    samples = 128
+    x, zi, at, rt = tuple(x.to(device) for x in create_test_inputs(batch_size, samples))
+    y = torch.randn_like(x)
+
+    x.requires_grad = True
+    zi.requires_grad = True
+    at.requires_grad = True
+    rt.requires_grad = True
+
+    args = (x, zi, at, rt)
+
+    def func(x, zi, at, rt):
+        return F.mse_loss(compressor_core(x, zi, at, rt), y)
+
+    jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)
+
+    loss = func(*args)
+    loss.backward()
+    for jac, arg in zip(jacs, args):
+        assert torch.allclose(jac, arg.grad)
diff --git a/torchcomp/__init__.py b/torchcomp/__init__.py
index cbd6093..e7e216c 100644
--- a/torchcomp/__init__.py
+++ b/torchcomp/__init__.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from typing import Union
-from torchaudio.functional import lfilter
+from torchlpc import sample_wise_lpc
 
 from .core import compressor_core
 
@@ -88,11 +88,9 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
     ).broadcast_to(rms.shape[0])
     assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)
 
-    return lfilter(
-        rms,
-        torch.stack([torch.ones_like(avg_coef), avg_coef - 1], 1),
-        torch.stack([avg_coef, torch.zeros_like(avg_coef)], 1),
-        False,
+    return sample_wise_lpc(
+        rms * avg_coef,
+        avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
     )
 
 
diff --git a/torchcomp/core.py b/torchcomp/core.py
index f1b3b60..44fba2b 100644
--- a/torchcomp/core.py
+++ b/torchcomp/core.py
@@ -18,8 +18,8 @@ def compressor_cuda_kernel(
     B: int,
     T: int,
 ):
-    b = cuda.blockIdx.x
-    i = cuda.threadIdx.x
+    b: int = cuda.blockIdx.x
+    i: int = cuda.threadIdx.x
 
     if b >= B or i > 0:
         return
@@ -93,8 +93,8 @@ def compressor_cuda(
 class CompressorFunction(Function):
     @staticmethod
     def forward(
-        ctx: Any, x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
-    ) -> torch.Tensor:
+        x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if x.is_cuda:
             y, at_mask = compressor_cuda(
                 x.detach(), zi.detach(), at.detach(), rt.detach()
@@ -108,19 +108,21 @@ def forward(
             )
             y = torch.from_numpy(y).to(x.device)
             at_mask = torch.from_numpy(at_mask).to(x.device)
-        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        return y, at_mask
 
-        # for jvp
-        ctx.x = x
-        ctx.y = y
-        ctx.zi = zi
-        ctx.at = at
-        ctx.rt = rt
-        ctx.at_mask = at_mask
-        return y
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
+        x, zi, at, rt = inputs
+        y, at_mask = output
+        ctx.mark_non_differentiable(at_mask)
+        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        ctx.save_for_forward(x, y, zi, at, rt, at_mask)
+        return ctx
 
     @staticmethod
-    def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(
+        ctx: Any, grad_y: torch.Tensor, _
+    ) -> Tuple[Optional[torch.Tensor], ...]:
         x, y, zi, at, rt, at_mask = ctx.saved_tensors
 
         grad_x = grad_zi = grad_at = grad_rt = None
@@ -153,19 +155,6 @@ def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ..
         if ctx.needs_input_grad[3]:
             grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1)
 
-        if hasattr(ctx, "y"):
-            del ctx.y
-        if hasattr(ctx, "x"):
-            del ctx.x
-        if hasattr(ctx, "zi"):
-            del ctx.zi
-        if hasattr(ctx, "at"):
-            del ctx.at
-        if hasattr(ctx, "rt"):
-            del ctx.rt
-        if hasattr(ctx, "at_mask"):
-            del ctx.at_mask
-
         return grad_x, grad_zi, grad_at, grad_rt
 
     @staticmethod
@@ -175,12 +164,13 @@ def jvp(
         grad_zi: torch.Tensor,
         grad_at: torch.Tensor,
         grad_rt: torch.Tensor,
-    ) -> torch.Tensor:
-        x, y, zi, at, rt, at_mask = ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
+    ) -> Tuple[torch.Tensor, None]:
+        x, y, zi, at, rt, at_mask = ctx.saved_tensors
 
         coeffs = torch.where(at_mask, at.unsqueeze(1), rt.unsqueeze(1))
 
         fwd_x = 0 if grad_x is None else grad_x * coeffs
+        fwd_combined: torch.Tensor
         if grad_at is None and grad_rt is None:
             fwd_combined = fwd_x
         else:
@@ -192,13 +182,35 @@
             fwd_combined = fwd_x + grad_beta * (
                 x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
             )
+        return (
+            sample_wise_lpc(
+                fwd_combined,
+                coeffs.unsqueeze(2) - 1,
+                grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+            ),
+            None,
+        )
 
-        del ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
-        return sample_wise_lpc(
-            fwd_combined,
-            coeffs.unsqueeze(2) - 1,
-            grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+    @staticmethod
+    def vmap(info, in_dims, *args):
+        def maybe_expand_bdim_at_front(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        x, zi, at, rt = tuple(
+            map(
+                lambda x: x.reshape(-1, *x.shape[2:]),
+                map(maybe_expand_bdim_at_front, args, in_dims),
+            )
         )
+        y, at_mask = CompressorFunction.apply(x, zi, at, rt)
+        return (
+            y.reshape(info.batch_size, -1, *y.shape[1:]),
+            at_mask.reshape(info.batch_size, -1, *at_mask.shape[1:]),
+        ), 0
+
 
-compressor_core: Callable = CompressorFunction.apply
+def compressor_core(*args, **kwargs) -> torch.Tensor:
+    return CompressorFunction.apply(*args, **kwargs)[0]
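Usage note: the CompressorFunction.vmap rule added above is what lets compressor_core compose with torch.func transforms (the new tests/test_vmap.py exercises it indirectly through jacfwd). Below is a minimal sketch of calling torch.func.vmap directly over an extra leading dimension; the shapes and coefficient values are illustrative assumptions, not taken from the patch or its tests.

import torch
from torch.func import vmap
from torchcomp.core import compressor_core

# Illustrative (assumed) shapes: an "outer" dimension to be mapped over, on top of
# compressor_core's usual (batch, time) input and per-batch zi/at/rt arguments.
outer, batch, time = 3, 4, 128
x = torch.rand(outer, batch, time) + 0.1      # positive gain-envelope input
zi = torch.ones(outer, batch)                 # initial smoother state
at = torch.full((outer, batch), 0.1)          # attack coefficients in (0, 1]
rt = torch.full((outer, batch), 0.05)         # release coefficients in (0, 1]

# vmap maps over dim 0 of every argument; inside, the custom vmap rule folds the
# mapped dimension into the batch dimension and splits it back out afterwards.
y = vmap(compressor_core)(x, zi, at, rt)
print(y.shape)  # torch.Size([3, 4, 128])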