1919from parameterized import parameterized
2020
2121from monai .bundle import ConfigParser
22- from monai .bundle .scripts import ckpt_export
2322from monai .data import load_net_with_metadata
2423from monai .networks import save_state
25- from tests .utils import skip_if_windows
24+ from tests .utils import command_line_tests , skip_if_windows
2625
2726TEST_CASE_1 = ["" , "" ]
2827
3332
3433@skip_if_windows
3534class TestCKPTExport (unittest .TestCase ):
36-
3735 def setUp (self ):
3836 self .device = os .environ .get ("CUDA_VISIBLE_DEVICES" )
3937 if not self .device :
@@ -52,6 +50,8 @@ def setUp(self):
5250 self .parser .export_config_file (config = self .def_args , filepath = self .def_args_file )
5351 self .parser .read_config (self .config_file )
5452 self .net = self .parser .get_parsed_content ("network_def" )
53+ self .cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "network_def" , "--filepath" , self .ts_file ]
54+ self .cmd += ["--meta_file" , self .meta_file , "--config_file" , f"['{ self .config_file } ','{ self .def_args_file } ']" , "--ckpt_file" ]
5555
5656 def tearDown (self ):
5757 if self .device is not None :
@@ -61,47 +61,34 @@ def tearDown(self):
6161 self .tempdir_obj .cleanup ()
6262
6363 @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
64- def test_ckpt_export_default (self , key_in_ckpt , use_trace ):
65- ckpt_file = os .path .join (self .tempdir_obj .name , "models/model.pt" )
66- ts_file = os .path .join (self .tempdir_obj .name , "models/model.ts" )
67-
68- save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = ckpt_file )
69- ckpt_export (
70- net_id = "network_def" ,
71- filepath = ts_file ,
72- meta_file = self .meta_file ,
73- config_file = self .config_file ,
74- ckpt_file = ckpt_file ,
75- key_in_ckpt = key_in_ckpt ,
76- args_file = self .def_args_file ,
77- use_trace = use_trace ,
78- input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
79- )
80- self .assertTrue (os .path .exists (ts_file ))
81-
82- @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
83- def test_ckpt_export (self , key_in_ckpt , use_trace ):
84- save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file )
85- ckpt_export (
86- net_id = "network_def" ,
87- filepath = self .ts_file ,
88- meta_file = self .meta_file ,
89- config_file = [self .config_file , self .def_args_file ],
90- ckpt_file = self .ckpt_file ,
91- key_in_ckpt = key_in_ckpt ,
92- args_file = self .def_args_file ,
93- use_trace = use_trace ,
94- input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
95- )
64+ def test_export (self , key_in_ckpt , use_trace ):
65+ save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file ) # noqa: E117
66+ full_cmd = self .cmd + [self .ckpt_file , "--key_in_ckpt" , key_in_ckpt , "--args_file" , self .def_args_file ]
67+ if use_trace == "True" :
68+ full_cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
69+ command_line_tests (full_cmd )
9670 self .assertTrue (os .path .exists (self .ts_file ))
9771
98- _ , metadata , extra_files = load_net_with_metadata (
99- self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ]
100- )
72+ _ , metadata , extra_files = load_net_with_metadata (self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ])
10173 self .assertIn ("schema" , metadata )
10274 self .assertIn ("meta_file" , json .loads (extra_files ["def_args.json" ]))
10375 self .assertIn ("network_def" , json .loads (extra_files ["inference.json" ]))
10476
77+ @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
78+ def test_default_value (self , key_in_ckpt , use_trace ):
79+ ckpt_file = os .path .join (self .tempdir_obj .name , "models/model.pt" )
80+ ts_file = os .path .join (self .tempdir_obj .name , "models/model.ts" )
81+
82+ save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = ckpt_file )
83+
84+ # check with default value
85+ cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "--key_in_ckpt" , key_in_ckpt ]
86+ cmd += ["--config_file" , self .config_file , "--bundle_root" , self .tempdir_obj .name ]
87+ if use_trace == "True" :
88+ cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
89+ command_line_tests (cmd )
90+ self .assertTrue (os .path .exists (ts_file ))
91+
10592
10693if __name__ == "__main__" :
10794 unittest .main ()
0 commit comments