MMP: Remove hardcodings, query from config instead (#3844)
Signed-off-by: yathindra kota <[email protected]>
quic-ykota authored Feb 27, 2025
1 parent dd3bdbb commit 105a4cf
Showing 5 changed files with 137 additions and 29 deletions.
@@ -42,7 +42,7 @@

from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_ROUND_MODE_TO_PYMO
from aimet_common.utils import AimetLogger, log_with_error_and_assert_if_false
from aimet_torch.utils import is_leaf_module
from aimet_torch.utils import is_leaf_module, get_param_channel_axis


logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
@@ -143,14 +143,8 @@ def enable_per_channel_quantization(self):
Changes all parameter quantizers (if any) to per-channel mode.
"""
for param_name, param_quantizer in self.param_quantizers.items():
channel_axis = 0
if isinstance(self._module_to_wrap, (torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d)):
channel_axis = 1 if param_name == 'weight' else 0

# pylint: disable = protected-access
param_quantizer.channel_axis = channel_axis
param_quantizer.channel_axis = get_param_channel_axis(self._module_to_wrap, param_name)

@staticmethod
def forward(_):
@@ -168,10 +162,11 @@ class LazyQuantizer(ABC):
"""
Quantizer builder class for supporting both v1 and v2 blocks
"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-instance-attributes, too-many-arguments
def __init__(self, bitwidth: int, round_mode, quant_scheme: QuantScheme,
use_symmetric_encodings: bool, enabled_by_default: bool,
data_type: QuantizationDataType = QuantizationDataType.int):
data_type: QuantizationDataType = QuantizationDataType.int, input_shape: tuple = None,
ch_axis: int = None):
self.round_mode = MAP_ROUND_MODE_TO_PYMO[round_mode]
self.quant_scheme = quant_scheme
self.use_symmetric_encodings = use_symmetric_encodings
@@ -185,8 +180,8 @@ def __init__(self, bitwidth: int, round_mode, quant_scheme: QuantScheme,
self.is_parm = False
self.is_singleton = False
self._encoding_min_max_fixed_vals = None
self.input_tensor_shape = None # None indicates unknown
self.channel_axis = None
self.input_tensor_shape = input_shape # None indicates unknown
self.channel_axis = ch_axis

@property
def encoding_min_max_fixed_vals(self) -> Optional[Tuple[float, float]]:
15 changes: 15 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -1138,3 +1138,18 @@ def __getattr__(name: str):
else:
msg = f"module '{__name__}' has no attribute '{name}'"
raise AttributeError(msg) from e


def get_param_channel_axis(module: torch.nn.Module, param_name: str):
"""
Given a module and its param name, this method returns the channel axis of the given parameter.
:param module: torch.nn.Module
:param param_name: str representing the name of the parameter
"""
channel_axis = 0
if isinstance(module, (torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d)):
channel_axis = 1 if param_name == 'weight' else 0
return channel_axis
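
For reference, a minimal sketch (not part of the commit) of what the new helper returns. It relies only on PyTorch's parameter layouts: ConvTranspose modules store output channels on axis 1 of the weight tensor, while regular Conv and Linear layers store them on axis 0.

    import torch
    from aimet_torch.utils import get_param_channel_axis

    conv = torch.nn.Conv2d(3, 8, kernel_size=3)              # weight shape (8, 3, 3, 3): output channels on axis 0
    deconv = torch.nn.ConvTranspose2d(3, 8, kernel_size=3)   # weight shape (3, 8, 3, 3): output channels on axis 1

    assert get_param_channel_axis(conv, 'weight') == 0
    assert get_param_channel_axis(deconv, 'weight') == 1
    assert get_param_channel_axis(deconv, 'bias') == 0        # bias and all other params fall back to axis 0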
@@ -40,6 +40,7 @@
from typing import overload, Union, List, Tuple, Dict, get_args, Type, Optional, IO
import torch

from aimet_common.quantsim_config.json_config_importer import JsonConfigImporter
from aimet_common.utils import AimetLogger
from aimet_torch.v2.utils import flatten_list
from aimet_torch.v2.mixed_precision.utils import UserRequest, RequestType, SupportedDType, ModuleProduct, broadcast_tuples
@@ -68,7 +69,8 @@ def __init__(self, sim: QuantizationSimModel):
"""
self._sim = sim
self.user_requests = []
self.mp_handler = MpHandler(sim)
# pylint: disable=protected-access
self.mp_handler = MpHandler(sim, JsonConfigImporter.import_json_config_file(self._sim._config_file))

def _store_user_request(self, request_type: RequestType, module: Union[torch.nn.Module, Type, ModuleProduct],
activation: Union[List[SupportedDType], SupportedDType] = None,
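
A brief usage sketch (`model` and `dummy_input` are placeholders; the per-channel config helper is the same one used by the tests below): the configurator now parses the sim's config file once at construction and hands the resulting dict to MpHandler, so the config_file passed to QuantizationSimModel is what later drives the symmetry and per-channel decisions.

    from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
    from aimet_torch.v2.quantsim import QuantizationSimModel
    from aimet_torch.v2.mixed_precision import MixedPrecisionConfigurator

    # `model` and `dummy_input` stand in for a real torch.nn.Module and a sample input tensor.
    sim = QuantizationSimModel(model, dummy_input, config_file=get_path_for_per_channel_config())
    mp_configurator = MixedPrecisionConfigurator(sim)   # MpHandler is built with the parsed config dict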
@@ -39,12 +39,15 @@
# pylint: disable=logging-fstring-interpolation

import copy
import functools
from typing import Dict, List, Tuple, Optional, Union, IO

import torch.nn

from aimet_common.defs import QuantizationDataType, QuantScheme
from aimet_common.utils import AimetLogger
from aimet_torch.onnx_utils import map_torch_types_to_onnx
from aimet_torch.utils import get_param_channel_axis
from aimet_torch.v2.nn.modules.custom import QuantizedConcat
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantsim import QuantizationSimModel
@@ -65,7 +68,13 @@ class MpHandler:
requests and apply to the sim
"""
def __init__(self, sim: QuantizationSimModel):
def __init__(self, sim: QuantizationSimModel, configs: dict):
"""
:param sim: QuantSim object
:param configs: configs parsed from the config file
"""
self._sim = sim
self._configs = configs
self.cg_traverser = ConnectedGraphTraverser(sim)
self.mp_requests = {}

@@ -268,7 +277,9 @@ def validate_supported_kernels_for_module(module, input_activation, param) -> bool:
return mp_requests

@staticmethod
def _apply_request_to_quantizer(quantizer: QuantizerBase, candidate: Precision):
def _apply_request_to_quantizer(quantizer: QuantizerBase, candidate: Precision, quant_scheme: QuantScheme,
symm: bool, round_mode: str = 'nearest', tensor_shape: tuple = None,
ch_axis: int = None):
"""
Helper function to apply mixed precision candidate to a quantizer
:param quantizer: quantizer object
@@ -278,12 +289,13 @@ def _apply_request_to_quantizer(quantizer: QuantizerBase, candidate: Precision):
if not isinstance(quantizer, FloatQuantizeDequantize):
# convert to float QDQ
quantizer = _V2LazyQuantizer(candidate.bitwidth,
'nearest',
QuantScheme.post_training_tf,
quantizer.symmetric,
round_mode,
quant_scheme,
symm,
enabled_by_default=True,
data_type=QuantizationDataType.float
).realize()
data_type=QuantizationDataType.float,
input_shape=tensor_shape,
ch_axis=ch_axis).realize()

if candidate.bitwidth == 16:
quantizer.exponent_bits = 5
@@ -297,12 +309,13 @@ def _apply_request_to_quantizer(quantizer: QuantizerBase, candidate: Precision)
if isinstance(quantizer, FloatQuantizeDequantize):
# convert to int QDQ
quantizer = _V2LazyQuantizer(candidate.bitwidth,
'nearest',
QuantScheme.post_training_tf,
quantizer.symmetric,
round_mode,
quant_scheme,
symm,
enabled_by_default=True,
data_type=QuantizationDataType.int
).realize()
data_type=QuantizationDataType.int,
input_shape=tensor_shape,
ch_axis=ch_axis).realize()

quantizer.bitwidth = candidate.bitwidth
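
A hypothetical call (the `module` variable and candidate values are placeholders, and names are used as they appear in this module) illustrating how the new arguments are threaded through: instead of the previously hardcoded 'nearest' rounding, post_training_tf scheme and quantizer-derived symmetry, the caller now supplies the quant scheme, rounding mode and symmetry from the sim and config, plus the parameter shape and channel axis needed to realize per-channel quantizers.

    new_qtzr = MpHandler._apply_request_to_quantizer(
        module.param_quantizers['weight'],
        Precision(QuantizationDataType.int, 8),
        QuantScheme.post_training_tf,
        symm=True,                                      # taken from the config's is_symmetric field
        round_mode='nearest',
        tensor_shape=tuple(module.weight.shape),        # lets the realized quantizer match the param shape
        ch_axis=get_param_channel_axis(module, 'weight'),
    )
    module.param_quantizers['weight'] = new_qtzr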

@@ -487,33 +500,75 @@ def _resolve_request_outputs_helper(module):
self._log_mp_requests(mp_requests, "Mixed Precision Requests After Propagation", log_file)
return mp_requests

@functools.cached_property
def _get_param_is_symm_fields(self):
"""Generates dict of {op_name: is_symmetric} fields corresponding to the weight parameter"""
is_symm_fields = {'defaults': self._configs["defaults"]["params"].get("is_symmetric", False)}
for op_name, settings in self._configs["op_type"].items():
if settings.get('params', {}).get('weight', {}).get('is_symmetric') is not None:
is_symm_fields[op_name] = settings['params']['weight']['is_symmetric']

return is_symm_fields

@functools.cached_property
def _get_param_pcq_mapping(self):
"""Generates dict of {op_name: per_channel_quantization} fields corresponding to the weight parameter"""

pcq_fields = {'defaults': self._configs["defaults"].get("per_channel_quantization", False)}
for op_name, settings in self._configs["op_type"].items():
if settings.get('per_channel_quantization', None) is not None:
pcq_fields[op_name] = settings['per_channel_quantization']

return pcq_fields
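
A rough sketch of the config dict shape these two cached properties assume, derived only from the keys looked up above (the real structure is whatever JsonConfigImporter produces from the quantsim config file; the values here are illustrative):

    configs = {
        "defaults": {
            "params": {"is_symmetric": True},
            "per_channel_quantization": True,
        },
        "op_type": {
            "Gemm": {
                "per_channel_quantization": False,
                "params": {"weight": {"is_symmetric": False}},
            },
        },
    }
    # With this input:
    #   _get_param_is_symm_fields -> {'defaults': True, 'Gemm': False}
    #   _get_param_pcq_mapping    -> {'defaults': True, 'Gemm': False}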


def _apply_requests_to_sim(self, mp_requests: Dict):
"""
Apply MP configuration to the sim object
:param mp_requests: MP requests after preprocessing, applying backend awareness (if present), propagating to
parent modules
"""
#pylint: disable=protected-access

for module, request in mp_requests.items():
if request.input_candidates:
assert len(module.input_quantizers) == len(request.input_candidates)
for idx, qtzr in enumerate(module.input_quantizers):
if request.input_candidates[idx] and qtzr:
module.input_quantizers[idx] = self._apply_request_to_quantizer(qtzr,
request.input_candidates[idx])
request.input_candidates[idx],
self._sim._quant_scheme,
False,
self._sim._rounding_mode)

if request.param_candidate:
assert all(param_key in module.param_quantizers for param_key in request.param_candidate.keys())
for param_key, param_candidate in request.param_candidate.items():
if param_candidate and param_key in module.param_quantizers.keys() and module.param_quantizers[param_key]:
module_type = map_torch_types_to_onnx.get(module.qcls_to_cls[type(module)], [None])[0]

ch_axis = None
if self._get_param_pcq_mapping.get(module_type, self._get_param_pcq_mapping.get("defaults")):
ch_axis = get_param_channel_axis(module, param_key)

module.param_quantizers[param_key] = \
self._apply_request_to_quantizer(module.param_quantizers[param_key], param_candidate)
self._apply_request_to_quantizer(module.param_quantizers[param_key], param_candidate,
self._sim._quant_scheme,
self._get_param_is_symm_fields.get(module_type,
self._get_param_is_symm_fields.get("defaults")),
self._sim._rounding_mode,
tensor_shape=tuple(module.__getattr__(param_key).shape), #pylint: disable=unnecessary-dunder-call
ch_axis=ch_axis)

if request.output_candidates:
assert len(module.output_quantizers) == len(request.output_candidates)
for idx, qtzr in enumerate(module.output_quantizers):
if request.output_candidates[idx] and qtzr:
module.output_quantizers[idx] = self._apply_request_to_quantizer(qtzr,
request.output_candidates[idx])
request.output_candidates[idx],
self._sim._quant_scheme,
False,
self._sim._rounding_mode)

def _resolve_contentions_at_module(self, current_module, mp_request, visited_modules, mp_requests: Dict,
strict: bool = True):
@@ -44,7 +44,10 @@
from torch import nn

from aimet_common.defs import QuantizationDataType

from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.v2.nn import BaseQuantizationMixin
from aimet_torch.v2.quantization.affine import QuantizeDequantize
from aimet_torch.v2.quantization.base.quantizer import QuantizerBase
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.mixed_precision import MixedPrecisionConfigurator, SupportedDType, Precision
@@ -432,6 +435,8 @@ def test_mp_6(self):
if 'weight' in m.param_quantizers:
assert request.param_candidate == {'weight': Precision(QuantizationDataType.float, 16)}



@pytest.mark.parametrize("candidate, qsim_bw", [('int16', 8), ('fp16', 8), ('fp16', 16)])
def test_mp_7(self, candidate: SupportedDType, qsim_bw: int):
""" Basic test that user request was applied to model correctly """
@@ -1590,4 +1595,40 @@ def forward(self, *inputs):
mp_configurator = MixedPrecisionConfigurator(sim)
mp_configurator.set_precision(torch.nn.Linear, 'int16', param={'weight': 'int16'})
with pytest.raises(RuntimeError):
mp_configurator.apply()
mp_configurator.apply()

def test_mp_46(self):
"""
Test symmetric settings
"""

model = SingleResidual()

torch.manual_seed(0)
input_tensor = torch.randn((1, 3, 32, 32))
sim = QuantizationSimModel(model, input_tensor, default_data_type=QuantizationDataType.float, default_output_bw=16, default_param_bw=16,
config_file=get_path_for_per_channel_config())

mp_configurator = MixedPrecisionConfigurator(sim)

mp_configurator.set_precision(torch.nn.Conv2d, 'int8', {'weight': 'int8'})
mp_configurator.apply()

sim.compute_encodings(lambda model, _: model(input_tensor), None)

for m in sim.model.modules():
if isinstance(m, torch.nn.Conv2d):
for q in m.input_quantizers:
if q:
assert isinstance(q, QuantizeDequantize)
assert q.symmetric == False
assert q.shape == ()
for q in m.output_quantizers:
if q:
assert isinstance(q, QuantizeDequantize)
assert q.symmetric == False
assert q.shape == ()

assert isinstance(m.param_quantizers['weight'], QuantizeDequantize)
assert m.param_quantizers['weight'].symmetric == True
assert m.param_quantizers['weight'].shape != ()
