
v0.2: drop support for torch<2, vmap for jacobian/hessian #6

Merged 5 commits on Sep 11, 2024
requirements.txt (3 changes: 1 addition & 2 deletions)
@@ -1,5 +1,4 @@
 numpy
 numba
-torch
-torchaudio
+torch>=2
 torchlpc
setup.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@
     long_description_content_type="text/markdown",
     url="https://github.com/yoyololicon/torchcomp",
     packages=["torchcomp"],
-    install_requires=["torch", "torchaudio", "torchlpc", "numpy", "numba"],
+    install_requires=["torch>=2", "torchlpc", "numpy", "numba"],
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
tests/test_vmap.py (new file, 44 additions)
@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F
from torch.func import jacfwd
import pytest
from torchcomp.core import compressor_core


from .test_grad import create_test_inputs


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA not available"
            ),
        ),
    ],
)
def test_vmap(device: str):
    batch_size = 4
    samples = 128
    x, zi, at, rt = tuple(x.to(device) for x in create_test_inputs(batch_size, samples))
    y = torch.randn_like(x)

    x.requires_grad = True
    zi.requires_grad = True
    at.requires_grad = True
    rt.requires_grad = True

    args = (x, zi, at, rt)

    def func(x, zi, at, rt):
        return F.mse_loss(compressor_core(x, zi, at, rt), y)

    jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

    loss = func(*args)
    loss.backward()
    for jac, arg in zip(jacs, args):
        assert torch.allclose(jac, arg.grad)
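Beyond the first-order Jacobians checked above, the forward-mode (jvp) and vmap rules added in this PR should also let torch.func.hessian (jacfwd composed over jacrev, which batches via vmap internally) run through compressor_core, as the PR title suggests. A minimal sketch of that use, not part of the PR's test suite; shapes and coefficient ranges here are assumptions:

import torch
from torch.func import hessian
from torchcomp.core import compressor_core

# assumed shapes: (batch, time) signal, per-batch coefficients in (0, 1)
batch_size, samples = 2, 32
x = torch.randn(batch_size, samples).double()
zi = torch.rand(batch_size).double()
at = torch.rand(batch_size).double() * 0.5 + 0.25
rt = torch.rand(batch_size).double() * 0.5 + 0.25

def loss(x):
    # scalar-valued function of the input signal only
    return compressor_core(x, zi, at, rt).square().sum()

# hessian = jacfwd(jacrev(...)): exercises both the jvp and backward rules
H = hessian(loss)(x)
print(H.shape)  # torch.Size([2, 32, 2, 32])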
torchcomp/__init__.py (10 changes: 4 additions & 6 deletions)
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from typing import Union
-from torchaudio.functional import lfilter
+from torchlpc import sample_wise_lpc
 
 from .core import compressor_core

@@ -88,11 +88,9 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
     ).broadcast_to(rms.shape[0])
     assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)
 
-    return lfilter(
-        rms,
-        torch.stack([torch.ones_like(avg_coef), avg_coef - 1], 1),
-        torch.stack([avg_coef, torch.zeros_like(avg_coef)], 1),
-        False,
+    return sample_wise_lpc(
+        rms * avg_coef,
+        avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
     )
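Both the old lfilter call and the new sample_wise_lpc call realize the one-pole smoother y[n] = c*x[n] + (1 - c)*y[n-1]; torchlpc expresses it in all-pole form, filtering the pre-scaled input c*x[n] with the time-varying denominator coefficient c - 1. A quick sanity check against an explicit loop (a sketch with assumed shapes, not part of the diff):

import torch
from torchlpc import sample_wise_lpc

batch, steps, c = 3, 16, 0.3
x = torch.randn(batch, steps).double()
coef = torch.full((batch,), c).double()

# reference: explicit recursion y[n] = c*x[n] + (1 - c)*y[n-1]
y_ref = torch.zeros_like(x)
state = torch.zeros(batch, dtype=x.dtype)
for n in range(steps):
    state = coef * x[:, n] + (1 - coef) * state
    y_ref[:, n] = state

# all-pole form used above: y[n] = c*x[n] - (c - 1)*y[n-1]
a = (coef - 1)[:, None, None].broadcast_to((batch, steps, 1))
y = sample_wise_lpc(x * coef[:, None], a)
assert torch.allclose(y, y_ref)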


torchcomp/core.py (82 changes: 47 additions & 35 deletions)
@@ -18,8 +18,8 @@ def compressor_cuda_kernel(
     B: int,
     T: int,
 ):
-    b = cuda.blockIdx.x
-    i = cuda.threadIdx.x
+    b: int = cuda.blockIdx.x
+    i: int = cuda.threadIdx.x
 
     if b >= B or i > 0:
         return
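The new annotations only make the inferred thread-index types explicit for readers and static type checkers; numba infers them either way. The same block-per-batch, single-worker-thread launch pattern in a self-contained toy kernel (a hypothetical example, not from this repo):

import numpy as np
from numba import cuda

@cuda.jit
def scale_kernel(out, x, g, B, T):
    # one block per batch element; only thread 0 does the sequential work,
    # mirroring compressor_cuda_kernel above
    b: int = cuda.blockIdx.x
    i: int = cuda.threadIdx.x
    if b >= B or i > 0:
        return
    for t in range(T):
        out[b, t] = x[b, t] * g[b]

if cuda.is_available():
    B, T = 4, 128
    x = np.random.randn(B, T).astype(np.float32)
    g = np.random.rand(B).astype(np.float32)
    out = np.empty_like(x)
    scale_kernel[B, 1](out, x, g, B, T)  # B blocks, 1 thread each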
@@ -93,8 +93,8 @@ def compressor_cuda(
 class CompressorFunction(Function):
     @staticmethod
     def forward(
-        ctx: Any, x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
-    ) -> torch.Tensor:
+        x: torch.Tensor, zi: torch.Tensor, at: torch.Tensor, rt: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if x.is_cuda:
             y, at_mask = compressor_cuda(
                 x.detach(), zi.detach(), at.detach(), rt.detach()
@@ -108,19 +108,21 @@
             )
             y = torch.from_numpy(y).to(x.device)
             at_mask = torch.from_numpy(at_mask).to(x.device)
-        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
-
-        # for jvp
-        ctx.x = x
-        ctx.y = y
-        ctx.zi = zi
-        ctx.at = at
-        ctx.rt = rt
-        ctx.at_mask = at_mask
-        return y
+        return y, at_mask
+
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
+        x, zi, at, rt = inputs
+        y, at_mask = output
+        ctx.mark_non_differentiable(at_mask)
+        ctx.save_for_backward(x, y, zi, at, rt, at_mask)
+        ctx.save_for_forward(x, y, zi, at, rt, at_mask)
+        return ctx
 
     @staticmethod
-    def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(
+        ctx: Any, grad_y: torch.Tensor, _
+    ) -> Tuple[Optional[torch.Tensor], ...]:
         x, y, zi, at, rt, at_mask = ctx.saved_tensors
         grad_x = grad_zi = grad_at = grad_rt = None
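The forward/setup_context split above is the torch>=2 autograd.Function layout that functorch transforms require: forward becomes a pure function of its inputs, all state is saved in setup_context, and save_for_forward exposes the saved tensors to jvp. The same layout on a toy function (a minimal sketch for illustration, not from this repo):

import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        # no ctx argument: forward must not carry state itself
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)  # consumed by backward (reverse mode)
        ctx.save_for_forward(x)   # consumed by jvp (forward mode)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

    @staticmethod
    def jvp(ctx, x_t):
        (x,) = ctx.saved_tensors
        return 2 * x * x_t

x = torch.randn(3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(Square.apply, (x,))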

@@ -153,19 +155,6 @@ def backward(
         if ctx.needs_input_grad[3]:
             grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1)
 
-        if hasattr(ctx, "y"):
-            del ctx.y
-        if hasattr(ctx, "x"):
-            del ctx.x
-        if hasattr(ctx, "zi"):
-            del ctx.zi
-        if hasattr(ctx, "at"):
-            del ctx.at
-        if hasattr(ctx, "rt"):
-            del ctx.rt
-        if hasattr(ctx, "at_mask"):
-            del ctx.at_mask
-
         return grad_x, grad_zi, grad_at, grad_rt
 
     @staticmethod
@@ -175,12 +164,13 @@ def jvp(
         grad_zi: torch.Tensor,
         grad_at: torch.Tensor,
         grad_rt: torch.Tensor,
-    ) -> torch.Tensor:
-        x, y, zi, at, rt, at_mask = ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
+    ) -> Tuple[torch.Tensor, None]:
+        x, y, zi, at, rt, at_mask = ctx.saved_tensors
         coeffs = torch.where(at_mask, at.unsqueeze(1), rt.unsqueeze(1))
 
         fwd_x = 0 if grad_x is None else grad_x * coeffs
 
+        fwd_combined: torch.Tensor
         if grad_at is None and grad_rt is None:
             fwd_combined = fwd_x
         else:
@@ -192,13 +182,35 @@
             fwd_combined = fwd_x + grad_beta * (
                 x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
             )
-
-        del ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
-        return sample_wise_lpc(
-            fwd_combined,
-            coeffs.unsqueeze(2) - 1,
-            grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
-        )
+        return (
+            sample_wise_lpc(
+                fwd_combined,
+                coeffs.unsqueeze(2) - 1,
+                grad_zi if grad_zi is None else grad_zi.unsqueeze(1),
+            ),
+            None,
+        )
+
+    @staticmethod
+    def vmap(info, in_dims, *args):
+        def maybe_expand_bdim_at_front(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        x, zi, at, rt = tuple(
+            map(
+                lambda x: x.reshape(-1, *x.shape[2:]),
+                map(maybe_expand_bdim_at_front, args, in_dims),
+            )
+        )
+
+        y, at_mask = CompressorFunction.apply(x, zi, at, rt)
+        return (
+            y.reshape(info.batch_size, -1, *y.shape[1:]),
+            at_mask.reshape(info.batch_size, -1, *at_mask.shape[1:]),
+        ), 0
 
 
-compressor_core: Callable = CompressorFunction.apply
+def compressor_core(*args, **kwargs) -> torch.Tensor:
+    return CompressorFunction.apply(*args, **kwargs)[0]
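With the custom vmap rule above, compressor_core composes directly with torch.vmap: an extra leading batch dimension is flattened into the kernel's own batch axis and restored on the way out. A usage sketch (shapes and coefficient ranges are assumptions, not from the PR):

import torch
from torchcomp.core import compressor_core

N, B, T = 3, 2, 64  # N: vmapped dim, B: kernel batch, T: time
x = torch.randn(N, B, T)
zi = torch.rand(N, B)
at = torch.rand(N, B) * 0.5 + 0.25
rt = torch.rand(N, B) * 0.5 + 0.25

y = torch.vmap(compressor_core)(x, zi, at, rt)
print(y.shape)  # torch.Size([3, 2, 64])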