Skip to content

Commit 802cc84

Browse files
author
root
committed
add calibration and quantization
Signed-off-by: root <[email protected]>
1 parent 0bb20a8 commit 802cc84

File tree

4 files changed

+153
-1
lines changed

4 files changed

+153
-1
lines changed

monai/engines/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class Trainer(Workflow):
4646
4747
"""
4848

49-
def run(self) -> None: # type: ignore[override]
49+
def run(self, *args) -> None: # type: ignore[override]
5050
"""
5151
Execute training based on Ignite Engine.
5252
If call this function multiple times, it will continuously run from the previous state.

monai/handlers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler
3030
from .metrics_saver import MetricsSaver
3131
from .mlflow_handler import MLFlowHandler
32+
from .model_quantizer import ModelQuantizer
33+
from .model_calibrator import ModelCalibrater
3234
from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
3335
from .panoptic_quality import PanopticQuality
3436
from .parameter_scheduler import ParamSchedulerHandler

monai/handlers/model_calibrator.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from typing import TYPE_CHECKING
15+
16+
import torch
17+
import modelopt.torch.quantization as mtq
18+
from functools import partial
19+
from monai.utils import IgniteInfo, min_version, optional_import
20+
21+
22+
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
23+
Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
24+
if TYPE_CHECKING:
25+
from ignite.engine import Engine
26+
else:
27+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
28+
29+
30+
class ModelCalibrater:
    """
    Ignite handler that calibrates and quantizes a model with ``modelopt.torch.quantization``.

    When the attached engine fires ``Events.STARTED``, the handler calls ``mtq.quantize`` with a
    forward loop that runs the engine, and then saves ``model.state_dict()`` to ``export_path``.

    Args:
        model: the model to be calibrated and quantized.
        export_path: the file path that the state dict of the model is saved to after calibration.
        config: the calibration config passed to ``mtq.quantize``, defaults to ``mtq.INT8_SMOOTHQUANT_CFG``.

    """

    def __init__(
        self,
        model: torch.nn.Module,
        export_path: str,
        config: dict = mtq.INT8_SMOOTHQUANT_CFG,
    ) -> None:
        self.model = model
        self.export_path = export_path
        self.config = config
        # True while ``mtq.quantize`` is driving the engine, see ``__call__``.
        self._calibrating = False

    def attach(self, engine: Engine) -> None:
        """
        Register this handler on ``Events.STARTED`` of the given engine.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(Events.STARTED, self)

    @staticmethod
    def _model_wrapper(engine, model):
        """
        Forward loop handed to ``mtq.quantize``: run the engine once.

        Args:
            engine: the Ignite Engine to run.
            model: the model supplied by the caller of the forward loop. It is not used here;
                presumably the engine already holds the same network -- TODO confirm.
        """
        engine.run()

    def __call__(self, engine: Engine) -> None:
        """
        Calibrate and quantize ``self.model``, then save its state dict to ``self.export_path``.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        # The forward loop below calls ``engine.run()`` on the engine this handler is attached to.
        # That nested run is expected to fire ``Events.STARTED`` again, which would re-enter this
        # handler and start another quantization from inside the first one. Ignore the nested call.
        if self._calibrating:
            return

        self._calibrating = True
        try:
            quant_fun = partial(self._model_wrapper, engine)
            # Keep the object returned by ``mtq.quantize`` so the model that is saved below is the
            # quantized one, instead of discarding the return value.
            self.model = mtq.quantize(self.model, self.config, quant_fun)
        finally:
            self._calibrating = False

        torch.save(self.model.state_dict(), self.export_path)

monai/handlers/model_quantizer.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import logging
15+
import warnings
16+
from types import MethodType
17+
from typing import TYPE_CHECKING, Sequence
18+
19+
import torch
20+
21+
from monai.networks.utils import copy_model_state
22+
from monai.utils import IgniteInfo, min_version, optional_import
23+
from torch.ao.quantization.quantizer import Quantizer
24+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
25+
XNNPACKQuantizer,
26+
get_symmetric_quantization_config,
27+
)
28+
from torch.ao.quantization.quantize_pt2e import (
29+
prepare_qat_pt2e,
30+
convert_pt2e,
31+
)
32+
33+
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
34+
Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
35+
if TYPE_CHECKING:
36+
from ignite.engine import Engine
37+
else:
38+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
39+
40+
41+
class ModelQuantizer:
    """
    Ignite handler that prepares a model for PT2E quantization aware training and saves its state dict.

    On ``Events.STARTED`` the model is exported with ``torch.export.export_for_training`` and prepared
    with ``prepare_qat_pt2e``. On ``Events.EPOCH_COMPLETED`` the state dict of the prepared model is
    saved to ``export_path``.

    Args:
        model: the model to be quantized.
        example_inputs: the example inputs for the model quantization. examples::
            (torch.randn(256,256,256),)
        export_path: the file path that the state dict of the prepared model is saved to.
        quantizer: quantizer for the quantization job. Defaults to an ``XNNPACKQuantizer`` with
            ``get_symmetric_quantization_config()`` set as its global config.

    """

    def __init__(
        self,
        model: torch.nn.Module,
        example_inputs: Sequence,
        export_path: str,
        quantizer: Quantizer | None = None,
    ) -> None:
        self.model = model
        self.example_inputs = example_inputs
        self.export_path = export_path
        if quantizer is None:
            quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        self.quantizer = quantizer

    def attach(self, engine: Engine) -> None:
        """
        Register the preparation step on ``Events.STARTED`` and the saving step on ``Events.EPOCH_COMPLETED``.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(Events.STARTED, self.start)
        # ``epoch`` writes the whole state dict to disk, so it is triggered once per epoch rather
        # than after every iteration.
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch)

    def start(self) -> None:
        """
        Export ``self.model`` and replace it with its quantization-aware-training prepared version.
        """
        # ``example_inputs`` is annotated as a generic ``Sequence``; the export API expects a tuple
        # of positional inputs, so a list is converted here and a tuple is passed through unchanged.
        exported = torch.export.export_for_training(self.model, tuple(self.example_inputs))
        self.model = prepare_qat_pt2e(exported.module(), self.quantizer)
        # Rebind ``train``/``eval`` of the prepared model to the exported-model helpers.
        self.model.train = MethodType(torch.ao.quantization.move_exported_model_to_train, self.model)
        self.model.eval = MethodType(torch.ao.quantization.move_exported_model_to_eval, self.model)
        # NOTE(review): only ``self.model`` is rebound here; nothing in this handler updates the
        # network held by the engine -- confirm the engine actually trains the prepared model.

    def epoch(self) -> None:
        """
        Save the state dict of the prepared model to ``self.export_path``, overwriting the previous file.
        """
        # NOTE(review): this saves the prepared model as is; no conversion step is applied
        # in this class before saving -- confirm that is intended.
        torch.save(self.model.state_dict(), self.export_path)

0 commit comments

Comments
 (0)