Skip to content

Commit 1667eb7

Browse files
committed
Revert "Directly tests export_ckpt function instead of using command_line_tests"
This reverts commit d4b01e6.
1 parent ba16743 commit 1667eb7

File tree

1 file changed

+25
-38
lines changed

1 file changed

+25
-38
lines changed

tests/test_bundle_ckpt_export.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
from parameterized import parameterized
2020

2121
from monai.bundle import ConfigParser
22-
from monai.bundle.scripts import ckpt_export
2322
from monai.data import load_net_with_metadata
2423
from monai.networks import save_state
25-
from tests.utils import skip_if_windows
24+
from tests.utils import command_line_tests, skip_if_windows
2625

2726
TEST_CASE_1 = ["", ""]
2827

@@ -33,7 +32,6 @@
3332

3433
@skip_if_windows
3534
class TestCKPTExport(unittest.TestCase):
36-
3735
def setUp(self):
3836
self.device = os.environ.get("CUDA_VISIBLE_DEVICES")
3937
if not self.device:
@@ -52,6 +50,8 @@ def setUp(self):
5250
self.parser.export_config_file(config=self.def_args, filepath=self.def_args_file)
5351
self.parser.read_config(self.config_file)
5452
self.net = self.parser.get_parsed_content("network_def")
53+
self.cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", self.ts_file]
54+
self.cmd += ["--meta_file", self.meta_file, "--config_file", f"['{self.config_file}','{self.def_args_file}']", "--ckpt_file"]
5555

5656
def tearDown(self):
5757
if self.device is not None:
@@ -61,47 +61,34 @@ def tearDown(self):
6161
self.tempdir_obj.cleanup()
6262

6363
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
64-
def test_ckpt_export_default(self, key_in_ckpt, use_trace):
65-
ckpt_file = os.path.join(self.tempdir_obj.name, "models/model.pt")
66-
ts_file = os.path.join(self.tempdir_obj.name, "models/model.ts")
67-
68-
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=ckpt_file)
69-
ckpt_export(
70-
net_id="network_def",
71-
filepath=ts_file,
72-
meta_file=self.meta_file,
73-
config_file=self.config_file,
74-
ckpt_file=ckpt_file,
75-
key_in_ckpt=key_in_ckpt,
76-
args_file=self.def_args_file,
77-
use_trace=use_trace,
78-
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
79-
)
80-
self.assertTrue(os.path.exists(ts_file))
81-
82-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
83-
def test_ckpt_export(self, key_in_ckpt, use_trace):
84-
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file)
85-
ckpt_export(
86-
net_id="network_def",
87-
filepath=self.ts_file,
88-
meta_file=self.meta_file,
89-
config_file=[self.config_file, self.def_args_file],
90-
ckpt_file=self.ckpt_file,
91-
key_in_ckpt=key_in_ckpt,
92-
args_file=self.def_args_file,
93-
use_trace=use_trace,
94-
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
95-
)
64+
def test_export(self, key_in_ckpt, use_trace):
65+
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file) # noqa: E117
66+
full_cmd = self.cmd + [self.ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", self.def_args_file]
67+
if use_trace == "True":
68+
full_cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
69+
command_line_tests(full_cmd)
9670
self.assertTrue(os.path.exists(self.ts_file))
9771

98-
_, metadata, extra_files = load_net_with_metadata(
99-
self.ts_file, more_extra_files=["inference.json", "def_args.json"]
100-
)
72+
_, metadata, extra_files = load_net_with_metadata(self.ts_file, more_extra_files=["inference.json", "def_args.json"])
10173
self.assertIn("schema", metadata)
10274
self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
10375
self.assertIn("network_def", json.loads(extra_files["inference.json"]))
10476

77+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
78+
def test_default_value(self, key_in_ckpt, use_trace):
79+
ckpt_file = os.path.join(self.tempdir_obj.name, "models/model.pt")
80+
ts_file = os.path.join(self.tempdir_obj.name, "models/model.ts")
81+
82+
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=ckpt_file)
83+
84+
# check with default value
85+
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
86+
cmd += ["--config_file", self.config_file, "--bundle_root", self.tempdir_obj.name]
87+
if use_trace == "True":
88+
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
89+
command_line_tests(cmd)
90+
self.assertTrue(os.path.exists(ts_file))
91+
10592

10693
if __name__ == "__main__":
10794
unittest.main()

0 commit comments

Comments
 (0)