WIP: Add multi-batchnorm and bias feature
etrommer committed Jan 18, 2024
1 parent 488f8d5 commit d7e3717
Showing 4 changed files with 95 additions and 3 deletions.
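The new pieces let a network keep one LUT/bias pair per approximate multiplier in each approximate layer and one set of BatchNorm statistics per multiplier in each MultiBatchNorm, switching between them by index. A hedged sketch of how the two could be driven together, assuming a net already prepared with the helpers introduced below (the set_multiplier helper is illustrative, not part of this commit):

import torch
from torchapprox.layers import ApproxLayer, MultiBatchNorm

def set_multiplier(net: torch.nn.Module, idx: int) -> None:
    # Select LUT/bias pair `idx` on approximate layers (requires init_shadow_luts())
    # and shadow copy `idx` on every MultiBatchNorm.
    for module in net.modules():
        if isinstance(module, (ApproxLayer, MultiBatchNorm)):
            module.mul_idx = idx

# Typical use during retraining, alternating multipliers between batches:
#     set_multiplier(net, step % num_multipliers)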
2 changes: 2 additions & 0 deletions src/torchapprox/layers/__init__.py
@@ -9,9 +9,11 @@
    "ApproxLinear",
    "ApproxWrapper",
    "layer_mapping_dict",
    "MultiBatchNorm",
]

from .approx_conv2d import ApproxConv2d
from .approx_layer import ApproxLayer, InferenceMode
from .approx_linear import ApproxLinear
from .approx_wrapper import ApproxWrapper, layer_mapping_dict
from .multi_batchnorm import MultiBatchNorm
44 changes: 41 additions & 3 deletions src/torchapprox/layers/approx_layer.py
@@ -1,8 +1,9 @@
# pylint: disable=missing-module-docstring
import copy
import enum
import logging
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Callable, Optional, no_type_check, Union
+from typing import TYPE_CHECKING, Callable, List, Optional, no_type_check, Union
from dataclasses import dataclass

import torch
@@ -58,7 +59,7 @@ class ApproxLayer(ABC):
    """

    def __init__(
-        self, qconfig: Optional[tq.QConfig] = None, learnable_noise: bool = False
+        self,
    ):
        self.inference_mode: InferenceMode = InferenceMode.QUANTIZED

@@ -73,11 +74,15 @@ def __init__(
        )
        self._mean: torch.Tensor = torch.tensor([0.0])

        self._shadow_luts: Optional[List[npt.NDArray]] = None
        self._shadow_biases: Optional[List[torch.nn.Parameter]] = None
        self._mul_idx: Optional[int] = None

    @staticmethod
    def default_qconfig() -> tq.QConfig:
        act_qconfig = tq.FakeQuantize.with_args(
            observer=tq.HistogramObserver,
-            dtype=torch.quint8,
+            dtype=torch.qint8,
            qscheme=torch.per_tensor_affine,
            quant_min=0,
            quant_max=127,
@@ -173,6 +178,39 @@ def opcount(self) -> int:
        forward pass of this layer
        """

    def init_shadow_luts(self, luts: List[npt.NDArray]):
        """
        Prepare layer for inference with multiple approximate multipliers (AMs)

        Args:
            luts: List of LUTs, one for each multiplier
        """
        assert len(luts) >= 1, "LUTs can't be empty"
        self._shadow_biases = [copy.deepcopy(self.bias) for _ in range(len(luts))]
        self._shadow_luts = luts
        self.mul_idx = 0

    @property
    def mul_idx(self) -> Optional[int]:
        """Shadow Multiplier Index

        Returns:
            Index of the currently configured shadow multiplier
        """
        return self._mul_idx

    @mul_idx.setter
    def mul_idx(self, multi_idx: int):
        if self._shadow_biases is None or self._shadow_luts is None:
            raise ValueError(
                "Multi-Retraining was not properly initialized. Call `init_shadow_luts()` first to set a list of LUTs."
            )
        if multi_idx >= len(self._shadow_luts):
            raise ValueError(f"Bad index {multi_idx} for {len(self._shadow_luts)} LUTs")
        self.bias = self._shadow_biases[multi_idx]
        self.lut = self._shadow_luts[multi_idx]
        self._mul_idx = multi_idx

    @abstractmethod
    def quant_fwd(
        self, x: torch.FloatTensor, w: torch.FloatTensor
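A minimal usage sketch of the shadow-LUT API above, assuming that ApproxConv2d keeps the standard torch.nn.Conv2d constructor and that a LUT is a 256x256 NumPy array of precomputed products (layer and LUT names are illustrative):

import numpy as np
from torchapprox.layers import ApproxConv2d

layer = ApproxConv2d(3, 16, kernel_size=3)        # assumed Conv2d-compatible constructor
lut_a = np.zeros((256, 256), dtype=np.int32)      # placeholder LUT for multiplier A
lut_b = np.zeros((256, 256), dtype=np.int32)      # placeholder LUT for multiplier B

layer.init_shadow_luts([lut_a, lut_b])            # one deep-copied bias per LUT, index reset to 0
layer.mul_idx = 1                                 # activate the LUT/bias pair of multiplier B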
31 changes: 31 additions & 0 deletions src/torchapprox/layers/multi_batchnorm.py
@@ -0,0 +1,31 @@
from typing import List, Optional
import copy
import torch
from torch.nn.modules.batchnorm import _NormBase


class MultiBatchNorm(torch.nn.Module):
    """Holds several deep copies of a BatchNorm layer, one per approximate multiplier,
    and dispatches the forward pass to the currently selected copy."""

    def __init__(self, batch_norm: _NormBase, size: int):
        super().__init__()
        assert size >= 1, "Need at least one forward dimension"
        self._shadow_norms: List[_NormBase] = [
            copy.deepcopy(batch_norm) for _ in range(size)
        ]
        self._mul_idx: int = 0
        self.mul_idx = 0
        self.fwd_norm: Optional[_NormBase]

    def forward(self, x):
        return self.fwd_norm(x)

    @property
    def mul_idx(self):
        return self._mul_idx

    @mul_idx.setter
    def mul_idx(self, new_idx: int):
        assert new_idx < len(
            self._shadow_norms
        ), f"Bad Index {new_idx} for size {len(self._shadow_norms)}"
        self._mul_idx = new_idx
        self.fwd_norm = self._shadow_norms[self._mul_idx]
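A short sketch of the wrapper in use, assuming a plain torch.nn.BatchNorm2d (shapes are illustrative); each index keeps its own affine parameters and running statistics, so statistics gathered for one multiplier do not disturb those of another:

import torch
from torchapprox.layers import MultiBatchNorm

bn = torch.nn.BatchNorm2d(16)
multi_bn = MultiBatchNorm(bn, size=2)   # two independent deep copies of `bn`

x = torch.randn(4, 16, 8, 8)
multi_bn.mul_idx = 0                    # route the forward pass through the first copy
y0 = multi_bn(x)
multi_bn.mul_idx = 1                    # switch to the second copy
y1 = multi_bn(x)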
21 changes: 21 additions & 0 deletions src/torchapprox/utils/conversion.py
@@ -7,6 +7,27 @@
import torch.ao.quantization as tq


def convert_batchnorms(
    net: torch.nn.Module,
    size: int,  # number of shadow BatchNorm copies per converted layer
) -> torch.nn.Module:
    """Replace every BatchNorm layer in `net` with a `MultiBatchNorm` holding `size` shadow copies."""
    replace_list = []

    def find_replaceable_modules(parent_module):
        for name, child_module in parent_module.named_children():
            if isinstance(child_module, torch.nn.modules.batchnorm._NormBase):
                replace_list.append((parent_module, name))
        for child in parent_module.children():
            find_replaceable_modules(child)

    find_replaceable_modules(net)

    for parent, name in replace_list:
        orig_layer = getattr(parent, name)
        multi_norm = tal.MultiBatchNorm(orig_layer, size)
        setattr(parent, name, multi_norm)
    return net


def wrap_quantizable(
    net: torch.nn.Module,
    wrappable_layers: Optional[List[tal.ApproxLayer]] = None,
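With the size argument passed through as above, converting a whole network could look like the following sketch (the model and the import path are illustrative):

import torch
from torchapprox.utils.conversion import convert_batchnorms

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
)
net = convert_batchnorms(net, size=2)   # every _NormBase child is replaced by a MultiBatchNorm
print(type(net[1]).__name__)            # MultiBatchNorm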
