Skip to content

Commit

Permalink
Change tracing interface
Browse files Browse the repository at this point in the history
This removes the old tracing interface that relied on torch execution
hooks due to a lack of flexibility. Torch hooks can only be set on an
`nn.Module` instance, which made the class hierarchy unnecessarily
complicated.

This change instead provides an internal tracing function that can be
triggered by populating the `traced_inputs` member of an `ApproxLayer`
instance. The tracing is then done within the `torch.autograd.Function`,
right before GeMM is called.
  • Loading branch information
etrommer committed Dec 20, 2023
1 parent 87185c7 commit 030fecf
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 106 deletions.
3 changes: 2 additions & 1 deletion src/torchapprox/layers/approx_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def approx_fwd(self, x_q, w_q, quant_params: QuantizationParameters):
self.conv_args,
self.htp_model,
self.output_dims(x_q),
self.approx_op.lut,
self.lut,
self.traced_inputs,
)

return y
Expand Down
69 changes: 64 additions & 5 deletions src/torchapprox/layers/approx_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import enum
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional, no_type_check
from typing import TYPE_CHECKING, Callable, Optional, no_type_check, Union
from dataclasses import dataclass

import torch
import torch.ao.quantization as tq

from torchapprox.operators import LUTGeMM
import numpy as np
import numpy.typing as npt

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -37,15 +37,36 @@ class QuantizationParameters:
w_zero_point: torch.FloatTensor


@dataclass
class TracedGeMMInputs:
    """
    Accumulator for the operands that are passed to `trace`.

    Every traced tensor is detached from the autograd graph, moved to the CPU
    and converted to float before it is stored, so the stored trace does not
    depend on the device or dtype of the tensors it was recorded from.
    """

    # Traced activations; `None` until `trace` has been called once.
    features: Optional[torch.FloatTensor] = None
    # Traced weights; only the first set of weights passed to `trace` is kept.
    weights: Optional[torch.FloatTensor] = None

    def trace(self, x_q: torch.Tensor, w_q: torch.Tensor):
        """
        Record one pair of operands.

        Args:
            x_q: Activations. Appended to the already traced activations
                along dimension 0.
            w_q: Weights. Stored on the first call, ignored afterwards.
        """
        # Normalize *every* batch, not just the first one. Concatenating the raw
        # tensor would mix devices, promote the dtype of the stored trace and
        # keep the autograd graph of each traced batch alive.
        x_traced = x_q.detach().cpu().float()
        if self.features is None:
            self.features = x_traced
        else:
            self.features = torch.cat([self.features, x_traced])

        if self.weights is None:
            self.weights = w_q.detach().cpu().float()


class ApproxLayer(ABC):
"""
Derivable Abstract Base Class for implementing Approximate Neural Network layers
"""

def __init__(self, qconfig: Optional[tq.QConfig] = None):
self.approx_op: LUTGeMM = LUTGeMM()
def __init__(
self, qconfig: Optional[tq.QConfig] = None, learnable_noise: bool = False
):
self.inference_mode: InferenceMode = InferenceMode.QUANTIZED

self._lut: Optional[torch.ShortTensor] = None
self.lut = self.accurate_lut()

self.htp_model: Optional[Callable] = None
self.traced_inputs: Optional[TracedGeMMInputs] = None

self._stdev: torch.Tensor = torch.tensor([0.0])
self._mean: torch.Tensor = torch.tensor([0.0])
Expand All @@ -68,6 +89,44 @@ def default_qconfig() -> tq.QConfig:
)
return tq.QConfig(activation=act_qconfig, weight=weight_qconfig)

@staticmethod
def accurate_lut() -> npt.NDArray[np.int32]:
    """
    Build the lookup table of exact products.

    The indices 0..255 are mapped to the values 0..127, -128..-1 (indices of
    128 and above have 256 subtracted). Entry `[i, j]` of the result is the
    product of the values that belong to index `i` and index `j`.

    Returns:
        256x256 array of 32-Bit integers
    """
    operands = np.arange(256)
    # Wrap the upper half of the index range around to negative values
    operands = np.where(operands >= 128, operands - 256, operands)
    return np.outer(operands, operands).astype(np.int32)

@property
def lut(self) -> torch.Tensor:
    """
    The Lookup table to use for approximate multiplication.

    When assigning a new LUT, the following inputs are accepted:
    - `numpy.ndarray`: converted to a contiguous, signed 32-Bit integer `torch.Tensor`
    - `torch.Tensor`: stored as-is, the datatype needs to be signed 32-Bit integer (`torch.int`)

    In both cases a 2D array of size 256x256 is required. Unused entries will be ignored
    when simulating multiplication where the operand width is less than 8 Bit.

    Raises (on assignment):
        ValueError: The new LUT is neither a `torch.Tensor` nor a `numpy.ndarray`
        AssertionError: The new LUT has the wrong shape or a `torch.Tensor` has the wrong datatype
    """
    return self._lut

@lut.setter
def lut(self, new_lut: Union[np.ndarray, torch.Tensor]):
    # Check the type before touching `.shape`, so that unsupported inputs
    # (e.g. nested lists) fail with the descriptive error below instead of
    # an unrelated `AttributeError`.
    if not isinstance(new_lut, (torch.Tensor, np.ndarray)):
        raise ValueError(
            f"Unknown LUT input type: {type(new_lut)}, supported types: torch.Tensor, np.ndarray"
        )

    assert len(new_lut.shape) == 2, "LUT needs to be 2D square matrix"
    assert (
        new_lut.shape[0] == new_lut.shape[1] == 256
    ), "Only 8x8 Bit LUTs are currently supported."

    if isinstance(new_lut, torch.Tensor):
        assert new_lut.dtype == torch.int, "LUT needs to be signed 32 Bit Integer"
        self._lut = new_lut
    else:
        # numpy input: convert to the 32-Bit integer tensor format
        self._lut = torch.from_numpy(new_lut).contiguous().int()

@property
def stdev(self) -> float:
"""
Expand Down
11 changes: 9 additions & 2 deletions src/torchapprox/layers/approx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


from .approx_layer import ApproxLayer, QuantizationParameters
from torchapprox.operators.approxgemm import ApproxGeMM


class ApproxLinear(ApproxLayer, QATLinear):
Expand Down Expand Up @@ -43,5 +44,11 @@ def quant_fwd(self, x, w):
return torch.nn.functional.linear(x, w)

def approx_fwd(self, x, w, quant_params: QuantizationParameters):
y = self.approx_op(x, w, quant_params, self.htp_model)
return y
return ApproxGeMM.apply(
x,
w,
self.lut,
quant_params,
self.htp_model,
self.traced_inputs,
)
3 changes: 1 addition & 2 deletions src/torchapprox/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
Low-level NN operator implementations for GPU & CPU
"""
__all__ = ["LUTGeMM", "ApproxConv2dOp"]
__all__ = ["ApproxConv2dOp"]

from .lut import LUTGeMM
from .conv2d import ApproxConv2dOp
19 changes: 17 additions & 2 deletions src/torchapprox/operators/approxgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from torchapprox.operators.backend import approx

if TYPE_CHECKING:
from torchapprox.layers.approx_layer import QuantizationParameters
from torchapprox.layers.approx_layer import (
QuantizationParameters,
TracedGeMMInputs,
)


class ApproxGeMM(torch.autograd.Function):
Expand All @@ -23,6 +26,7 @@ def forward( # type: ignore
lut: torch.Tensor,
quant_params: "QuantizationParameters",
htp_model: Optional[Callable],
traced_inputs: Optional["TracedGeMMInputs"],
) -> torch.Tensor:
"""
Approximate forward operation
Expand All @@ -35,6 +39,9 @@ def forward( # type: ignore
(w / quant_params.w_scale[:, None]) + quant_params.w_zero_point[:, None]
).T

if traced_inputs:
traced_inputs.trace(x_q, w_q)

if htp_model is None:
y_q = approx(x_q.char(), w_q.char(), lut).float()
else:
Expand All @@ -59,6 +66,7 @@ def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
_,
_,
_,
_,
) = inputs
ctx.save_for_backward(x, w)

Expand All @@ -80,4 +88,11 @@ def backward(ctx, grad_output):
# grad_a = torch.sum(torch.matmul(grad, b.transpose(1, 2)), axis=0)
# grad_b = torch.matmul(grad.transpose(1, 2), a).transpose(1, 2)

return grad_x, grad_w, None, None, None, None, None
return (
grad_x,
grad_w,
None,
None,
None,
None,
)
23 changes: 13 additions & 10 deletions src/torchapprox/operators/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union, TYPE_CHECKING

if TYPE_CHECKING:
from torchapprox.layers.approx_layer import QuantizationParameters
from torchapprox.layers.approx_layer import QuantizationParameters, TracedGeMMInputs

import torch

Expand Down Expand Up @@ -52,10 +52,6 @@ def use_fast_dwconv(self) -> bool:
- False otherwise
"""

# if not self.weight.is_cuda:
# return False
# if self.approx_op.lut is None:
# return False
if self.dilation[0] > 1 or self.dilation[1] > 1:
return False
if self.groups != self.in_channels:
Expand Down Expand Up @@ -158,6 +154,7 @@ def _im2col_conv2d(
conv_args: Conv2dArgs,
lut: torch.ShortTensor,
out_dims: Tuple[int, int],
traced_inputs: Optional["TracedGeMMInputs"],
) -> torch.FloatTensor:
# Pre-allocate output tensor
y_q = torch.empty(
Expand Down Expand Up @@ -196,6 +193,10 @@ def _im2col_conv2d(
conv_args.out_channels // conv_args.groups, -1
)

if traced_inputs:
assert conv_args.groups == 1, "Tracing of depthwise Conv2D is not supported"
traced_inputs.trace(x_unfold_s8, w_flat_s8)

# ApproxGeMM
y_q[:, out_ch_lower:out_ch_upper] = approx(
w_flat_s8,
Expand All @@ -219,25 +220,27 @@ def forward(
htp_model: Optional[Callable],
out_dims: Tuple[int, int],
lut: torch.ShortTensor,
traced_inputs: Optional["TracedGeMMInputs"],
):
x_q = torch.round((x / quant_params.x_scale) + quant_params.x_zero_point)
w_q = torch.round(
(w / quant_params.w_scale[:, None, None, None])
+ quant_params.w_zero_point[:, None, None, None]
)

if htp_model is not None:
trace = traced_inputs is not None
if htp_model is not None and not trace:
# HTP model
y_q = htp_model(
torch.nn.functional.conv2d, x_q, w_q, conv_args.backward_args()
)
torch.round(y_q)
elif conv_args.use_fast_dwconv() and x.is_cuda and w.is_cuda:
elif (conv_args.use_fast_dwconv() and x.is_cuda and w.is_cuda) and not trace:
# Depthwise Conv CUDA Kernel
y_q = dwconv2d(x_q, w_q, lut, conv_args.stride, conv_args.padding)
else:
# im2col & gemm kernel (supports CPU & GPU)
y_q = _im2col_conv2d(x_q, w_q, conv_args, lut, out_dims)
y_q = _im2col_conv2d(x_q, w_q, conv_args, lut, out_dims, traced_inputs)

if quant_params.x_zero_point == 0 and torch.all(quant_params.w_zero_point == 0):
y_q = _symmetric_requantize(y_q, quant_params)
Expand All @@ -262,7 +265,7 @@ def forward(

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
x, w, _, conv_args, _, _, _ = inputs
x, w, _, conv_args, _, _, _, _ = inputs
ctx.save_for_backward(x, w)
ctx.conf = conv_args.backward_args()

Expand All @@ -273,4 +276,4 @@ def backward(ctx, grad):
grad_input, grad_weight = _conv_bwd_ste(
grad, x, w, conf, ctx.needs_input_grad[0], ctx.needs_input_grad[1]
)
return grad_input, grad_weight, None, None, None, None, None, None, None
return grad_input, grad_weight, None, None, None, None, None, None, None, None
82 changes: 0 additions & 82 deletions src/torchapprox/operators/lut.py

This file was deleted.

4 changes: 2 additions & 2 deletions test/test_approx_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_compile(device, lut):
x = torch.rand(128, 42).requires_grad_()
quant.prepare_qat(w, {torch.nn.Linear: tal.ApproxLinear}, inplace=True)

w.wrapped.approx_op.lut = lut
w.wrapped.lut = lut
w.wrapped.inference_mode = tal.InferenceMode.APPROXIMATE
w_comp = torch.compile(w)
w_comp(x)
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_layer_empty_lut(device, layer):
)

layer.inference_mode = tal.InferenceMode.APPROXIMATE
layer.approx_op.lut = np.zeros((256, 256))
layer.lut = np.zeros((256, 256))

x = torch.randint(-128, 128, size=input_dims, device=device, dtype=torch.float32)
res = layer(
Expand Down

0 comments on commit 030fecf

Please sign in to comment.