
Commit 88d855c

fix format
Signed-off-by: binliu <[email protected]>
1 parent 6ee190e commit 88d855c

File tree

3 files changed: +14 −28 lines changed


monai/handlers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -29,8 +29,8 @@
 from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler
 from .metrics_saver import MetricsSaver
 from .mlflow_handler import MLFlowHandler
-from .model_quantizer import ModelQuantizer
 from .model_calibrator import ModelCalibrater
+from .model_quantizer import ModelQuantizer
 from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
 from .panoptic_quality import PanopticQuality
 from .parameter_scheduler import ParamSchedulerHandler

monai/handlers/model_calibrator.py

Lines changed: 5 additions & 11 deletions
@@ -11,13 +11,13 @@
 
 from __future__ import annotations
 
+from functools import partial
 from typing import TYPE_CHECKING
 
-import torch
 import modelopt.torch.quantization as mtq
-from functools import partial
-from monai.utils import IgniteInfo, min_version, optional_import
+import torch
 
+from monai.utils import IgniteInfo, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
@@ -40,13 +40,7 @@ class ModelCalibrater:
 
     """
 
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        export_path: str,
-        config: dict= mtq.INT8_SMOOTHQUANT_CFG,
-
-    ) -> None:
+    def __init__(self, model: torch.nn.Module, export_path: str, config: dict = mtq.INT8_SMOOTHQUANT_CFG) -> None:
         self.model = model
         self.export_path = export_path
         self.config = config
@@ -65,4 +59,4 @@ def _model_wrapper(engine, model):
     def __call__(self, engine) -> None:
         quant_fun = partial(self._model_wrapper, engine)
         model = mtq.quantize(self.model, self.config, quant_fun)
-        torch.save(self.model.state_dict(), self.export_path)
+        torch.save(model.state_dict(), self.export_path)
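For context, a minimal usage sketch of the handler after this change (the toy model, export path, and direct invocation are assumptions for illustration; `ModelCalibrater`'s signature, `mtq.INT8_SMOOTHQUANT_CFG`, and the save behavior come from the diff):

```python
# Hypothetical driver for ModelCalibrater; the model and export path are made up.
import torch

import modelopt.torch.quantization as mtq
from monai.handlers import ModelCalibrater

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
calibrator = ModelCalibrater(model, export_path="model_int8.pt", config=mtq.INT8_SMOOTHQUANT_CFG)

# When called with an ignite Engine, the handler runs mtq.quantize() with a
# calibration loop driven by the engine, then (after this fix) saves the
# state_dict of the model returned by mtq.quantize() to export_path.
# calibrator(engine)
```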

monai/handlers/model_quantizer.py

Lines changed: 8 additions & 16 deletions
@@ -11,21 +11,16 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 from types import MethodType
 from typing import TYPE_CHECKING
-from collections.abc import Sequence
 
 import torch
+from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
+from torch.ao.quantization.quantizer import Quantizer
+from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
 
 from monai.utils import IgniteInfo, min_version, optional_import
-from torch.ao.quantization.quantizer import Quantizer
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    XNNPACKQuantizer,
-    get_symmetric_quantization_config,
-)
-from torch.ao.quantization.quantize_pt2e import (
-    prepare_qat_pt2e,
-)
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
@@ -49,17 +44,14 @@ class ModelQuantizer:
     """
 
     def __init__(
-        self,
-        model: torch.nn.Module,
-        example_inputs: Sequence,
-        export_path: str,
-        quantizer: Quantizer | None = None,
-
+        self, model: torch.nn.Module, example_inputs: Sequence, export_path: str, quantizer: Quantizer | None = None
     ) -> None:
        self.model = model
        self.example_inputs = example_inputs
        self.export_path = export_path
-        self.quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) if quantizer is None else quantizer
+        self.quantizer = (
+            XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) if quantizer is None else quantizer
+        )
 
     def attach(self, engine: Engine) -> None:
         """

0 commit comments
