1111
1212from __future__ import annotations
1313
14- import logging
15- import warnings
1614from types import MethodType
17- from typing import TYPE_CHECKING , Sequence
15+ from typing import TYPE_CHECKING
16+ from collections .abc import Sequence
1817
1918import torch
2019
21- from monai .networks .utils import copy_model_state
2220from monai .utils import IgniteInfo , min_version , optional_import
2321from torch .ao .quantization .quantizer import Quantizer
2422from torch .ao .quantization .quantizer .xnnpack_quantizer import (
2725)
2826from torch .ao .quantization .quantize_pt2e import (
2927 prepare_qat_pt2e ,
30- convert_pt2e ,
3128)
3229
3330Events , _ = optional_import ("ignite.engine" , IgniteInfo .OPT_IMPORT_VERSION , min_version , "Events" )
@@ -57,7 +54,7 @@ def __init__(
5754 example_inputs : Sequence ,
5855 export_path : str ,
5956 quantizer : Quantizer | None = None ,
60-
57+
6158 ) -> None :
6259 self .model = model
6360 self .example_inputs = example_inputs
@@ -77,6 +74,6 @@ def start(self) -> None:
7774 self .model = prepare_qat_pt2e (self .model , self .quantizer )
7875 self .model .train = MethodType (torch .ao .quantization .move_exported_model_to_train , self .model )
7976 self .model .eval = MethodType (torch .ao .quantization .move_exported_model_to_eval , self .model )
80-
77+
8178 def epoch (self ) -> None :
8279 torch .save (self .model .state_dict (), self .export_path )
0 commit comments