|
15 | 15 | import os |
16 | 16 | import tempfile |
17 | 17 | import unittest |
18 | | -from pathlib import Path |
19 | 18 |
|
20 | 19 | from parameterized import parameterized |
21 | 20 |
|
22 | 21 | from monai.bundle import ConfigParser |
23 | 22 | from monai.data import load_net_with_metadata |
24 | 23 | from monai.networks import save_state |
25 | | -from tests.util import command_line_tests, skip_if_windows |
| 24 | +from tests.utils import command_line_tests, skip_if_windows |
26 | 25 |
|
27 | 26 | TEST_CASE_1 = ["", ""] |
| 27 | + |
28 | 28 | TEST_CASE_2 = ["model", ""] |
| 29 | + |
29 | 30 | TEST_CASE_3 = ["model", "True"] |
30 | 31 |
|
31 | 32 |
|
32 | 33 | @skip_if_windows |
33 | 34 | class TestCKPTExport(unittest.TestCase): |
| 35 | + |
34 | 36 | def setUp(self): |
35 | 37 | self.device = os.environ.get("CUDA_VISIBLE_DEVICES") |
36 | 38 | if not self.device: |
37 | 39 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # default |
38 | | - module_path = Path(__file__).resolve().parents[1].as_posix() |
39 | | - self.meta_file = os.path.join(module_path, "testing_data", "metadata.json") |
40 | | - self.config_file = os.path.join(module_path, "testing_data", "inference.json") |
41 | | - self.tempdir_obj = tempfile.TemporaryDirectory() |
42 | | - tempdir = self.tempdir_obj.name |
43 | | - self.def_args = {"meta_file": "will be replaced by `meta_file` arg"} |
44 | | - self.def_args_file = os.path.join(tempdir, "def_args.yaml") |
45 | | - |
46 | | - self.ckpt_file = os.path.join(tempdir, "model.pt") |
47 | | - self.ts_file = os.path.join(tempdir, "model.ts") |
48 | | - |
49 | | - self.parser = ConfigParser() |
50 | | - self.parser.export_config_file(config=self.def_args, filepath=self.def_args_file) |
51 | | - self.parser.read_config(self.config_file) |
52 | | - self.net = self.parser.get_parsed_content("network_def") |
53 | | - self.cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", self.ts_file] |
54 | | - self.cmd += [ |
55 | | - "--meta_file", |
56 | | - self.meta_file, |
57 | | - "--config_file", |
58 | | - f"['{self.config_file}','{self.def_args_file}']", |
59 | | - "--ckpt_file", |
60 | | - ] |
61 | 40 |
|
62 | 41 | def tearDown(self): |
63 | 42 | if self.device is not None: |
64 | 43 | os.environ["CUDA_VISIBLE_DEVICES"] = self.device |
65 | 44 | else: |
66 | 45 | del os.environ["CUDA_VISIBLE_DEVICES"] # previously unset |
67 | | - self.tempdir_obj.cleanup() |
68 | 46 |
|
69 | 47 | @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) |
70 | 48 | def test_export(self, key_in_ckpt, use_trace): |
71 | | - save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file) |
72 | | - full_cmd = self.cmd + [self.ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", self.def_args_file] |
73 | | - if use_trace == "True": |
74 | | - full_cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"] |
75 | | - command_line_tests(full_cmd) |
76 | | - self.assertTrue(os.path.exists(self.ts_file)) |
77 | | - |
78 | | - _, metadata, extra_files = load_net_with_metadata( |
79 | | - self.ts_file, more_extra_files=["inference.json", "def_args.json"] |
80 | | - ) |
81 | | - self.assertIn("schema", metadata) |
82 | | - self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) |
83 | | - self.assertIn("network_def", json.loads(extra_files["inference.json"])) |
| 49 | + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") |
| 50 | + config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json") |
| 51 | + with tempfile.TemporaryDirectory() as tempdir: |
| 52 | + def_args = {"meta_file": "will be replaced by `meta_file` arg"} |
| 53 | + def_args_file = os.path.join(tempdir, "def_args.yaml") |
| 54 | + |
| 55 | + ckpt_file = os.path.join(tempdir, "model.pt") |
| 56 | + ts_file = os.path.join(tempdir, "model.ts") |
| 57 | + |
| 58 | + parser = ConfigParser() |
| 59 | + parser.export_config_file(config=def_args, filepath=def_args_file) |
| 60 | + parser.read_config(config_file) |
| 61 | + net = parser.get_parsed_content("network_def") |
| 62 | + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) |
| 63 | + |
| 64 | + cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file] |
| 65 | + cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"] |
| 66 | + cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] |
| 67 | + if use_trace == "True": |
| 68 | + cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"] |
| 69 | + command_line_tests(cmd) |
| 70 | + self.assertTrue(os.path.exists(ts_file)) |
| 71 | + |
| 72 | + _, metadata, extra_files = load_net_with_metadata( |
| 73 | + ts_file, more_extra_files=["inference.json", "def_args.json"] |
| 74 | + ) |
| 75 | + self.assertIn("schema", metadata) |
| 76 | + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) |
| 77 | + self.assertIn("network_def", json.loads(extra_files["inference.json"])) |
84 | 78 |
|
85 | 79 | @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) |
86 | 80 | def test_default_value(self, key_in_ckpt, use_trace): |
87 | | - ckpt_file = os.path.join(self.tempdir_obj.name, "models/model.pt") |
88 | | - ts_file = os.path.join(self.tempdir_obj.name, "models/model.ts") |
89 | | - |
90 | | - save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=ckpt_file) |
91 | | - |
92 | | - # check with default value |
93 | | - cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt] |
94 | | - cmd += ["--config_file", self.config_file, "--bundle_root", self.tempdir_obj.name] |
95 | | - if use_trace == "True": |
96 | | - cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"] |
97 | | - command_line_tests(cmd) |
98 | | - self.assertTrue(os.path.exists(ts_file)) |
| 81 | + config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json") |
| 82 | + with tempfile.TemporaryDirectory() as tempdir: |
| 83 | + def_args = {"meta_file": "will be replaced by `meta_file` arg"} |
| 84 | + def_args_file = os.path.join(tempdir, "def_args.yaml") |
| 85 | + ckpt_file = os.path.join(tempdir, "models/model.pt") |
| 86 | + ts_file = os.path.join(tempdir, "models/model.ts") |
| 87 | + |
| 88 | + parser = ConfigParser() |
| 89 | + parser.export_config_file(config=def_args, filepath=def_args_file) |
| 90 | + parser.read_config(config_file) |
| 91 | + net = parser.get_parsed_content("network_def") |
| 92 | + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) |
| 93 | + |
| 94 | + # check with default value |
| 95 | + cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt] |
| 96 | + cmd += ["--config_file", config_file, "--bundle_root", tempdir] |
| 97 | + if use_trace == "True": |
| 98 | + cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"] |
| 99 | + command_line_tests(cmd) |
| 100 | + self.assertTrue(os.path.exists(ts_file)) |
99 | 101 |
|
100 | 102 |
|
101 | 103 | if __name__ == "__main__": |
|
0 commit comments