Commit

Fix minor issues in Multi-Retraining layers
etrommer committed Jan 30, 2024
1 parent 69928e3 commit 709e7f0
Showing 2 changed files with 16 additions and 14 deletions.
22 changes: 13 additions & 9 deletions src/torchapprox/layers/approx_layer.py
@@ -203,15 +203,19 @@ def mul_idx(self) -> Optional[int]:

     @mul_idx.setter
     def mul_idx(self, multi_idx: int):
-        if self._shadow_biases is None or self._shadow_luts is None:
-            raise ValueError(
-                "Multi-Retraining was not properly initialized. Call `init_shadow_luts()` first to set a list of LUTs."
-            )
-        if multi_idx >= len(self._shadow_luts):
-            raise ValueError(f"Bad index {multi_idx} for {len(self._shadow_luts)} LUTs")
-        self.bias = self._shadow_biases[multi_idx]
-        self.lut = self._shadow_luts[multi_idx]
-        self._mul_idx = multi_idx
+        if self._shadow_luts is not None:
+            assert multi_idx <= len(
+                self._shadow_luts
+            ), f"Bad index {multi_idx} for {len(self._shadow_luts)} LUTs"
+            self.lut = self._shadow_luts[multi_idx]
+            self._mul_idx = multi_idx
+        if self._shadow_biases is not None:
+            assert multi_idx <= len(
+                self._shadow_biases
+            ), f"Bad index {multi_idx} for {len(self._shadow_biases)} biases"
+            self.bias = self._shadow_biases[multi_idx]
+            self._mul_idx = multi_idx
+        return
 
     @abstractmethod
     def quant_fwd(
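For context, here is a minimal, self-contained sketch of the switching logic this hunk installs. DemoApproxLayer is a hypothetical stand-in for torchapprox's ApproxLayer: only the attribute names mirror the diff, and the sketch uses a strict < bound, since multi_idx must be a valid list index.

from typing import Optional


class DemoApproxLayer:
    """Hypothetical stand-in for torchapprox's ApproxLayer."""

    def __init__(self, luts, biases=None):
        self._shadow_luts = luts      # one LUT per approximate multiplier
        self._shadow_biases = biases  # optional matching list of biases
        self._mul_idx: Optional[int] = None
        self.lut = None
        self.bias = None

    @property
    def mul_idx(self) -> Optional[int]:
        return self._mul_idx

    @mul_idx.setter
    def mul_idx(self, multi_idx: int):
        # LUTs and biases are switched independently, so a layer without
        # shadow biases no longer raises on index assignment.
        if self._shadow_luts is not None:
            assert multi_idx < len(self._shadow_luts), (
                f"Bad index {multi_idx} for {len(self._shadow_luts)} LUTs"
            )
            self.lut = self._shadow_luts[multi_idx]
            self._mul_idx = multi_idx
        if self._shadow_biases is not None:
            self.bias = self._shadow_biases[multi_idx]
            self._mul_idx = multi_idx


layer = DemoApproxLayer(luts=["lut_a", "lut_b"])
layer.mul_idx = 1  # select the second approximate multiplier
assert layer.lut == "lut_b" and layer.bias is None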
8 changes: 3 additions & 5 deletions src/torchapprox/utils/conversion.py
@@ -7,14 +7,12 @@
 import torch.ao.quantization as tq
 
 
-def convert_batchnorms(
-    net: torch.nn.Module,
-) -> torch.nn.Module:
+def convert_batchnorms(net: torch.nn.Module, size: int) -> torch.nn.Module:
     replace_list = []
 
     def find_replacable_modules(parent_module):
         for name, child_module in parent_module.named_children():
-            if isinstance(child_module, torch.nn.modules._NormBase):
+            if isinstance(child_module, torch.nn.modules.batchnorm._BatchNorm):
                 replace_list.append((parent_module, name))
         for child in parent_module.children():
             find_replacable_modules(child)
@@ -23,7 +21,7 @@ def find_replacable_modules(parent_module):

     for parent, name in replace_list:
         orig_layer = getattr(parent, name)
-        multi_norm = tal.MultiBatchNorm(orig_layer)
+        multi_norm = tal.MultiBatchNorm(orig_layer, size)
         setattr(parent, name, multi_norm)
     return net

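A short usage sketch of the changed entry point follows. The toy model is illustrative, and size is assumed to count the shadow BatchNorm copies a MultiBatchNorm keeps (one per LUT in a Multi-Retraining run).

import torch

from torchapprox.utils.conversion import convert_batchnorms

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),  # _BatchNorm subclass, gets wrapped in a MultiBatchNorm
    torch.nn.ReLU(),
)
net = convert_batchnorms(net, 4)  # size = number of shadow BatchNorm copies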
