Commit 3c252a8

Fix the logging of a nested dictionary metric in MLflow (#8169)
Fix Project-MONAI/model-zoo#697

### Description

Flatten the metric dict when the metric is a nested dictionary.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 82298ad commit 3c252a8

File tree

6 files changed: +34 -13 lines changed

monai/handlers/mlflow_handler.py

Lines changed: 4 additions & 2 deletions

@@ -22,7 +22,7 @@
 from torch.utils.data import Dataset
 
 from monai.apps.utils import get_logger
-from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, min_version, optional_import
+from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
@@ -303,7 +303,9 @@ def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None
 
         run_id = self.cur_run.info.run_id
         timestamp = int(time.time() * 1000)
-        metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
+        metrics_arr = [
+            mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in flatten_dict(metrics).items()
+        ]
         self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])
 
     def _parse_artifacts(self):
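For context, a hedged, standalone sketch of the fixed logging path: the nested metric dict is flattened with `flatten_dict` before being wrapped in `mlflow.entities.Metric` objects, since `log_batch` only accepts scalar metric values. The client, experiment id, and metric values below are illustrative placeholders rather than the handler's internals, and assume `mlflow` is installed with a local tracking store.

```python
import time

import mlflow
from mlflow.tracking import MlflowClient

from monai.utils import flatten_dict

# a nested metric dict like the one produced by per-label accuracy metrics
metrics = {"acc": 0.9, "acc_per_label": {"label_0": 0.85, "label_1": 0.95}}

client = MlflowClient()
run = client.create_run(experiment_id="0")  # assumes the default experiment exists
timestamp = int(time.time() * 1000)

# flatten_dict merges the nested entries into scalar metrics that MLflow can store
metrics_arr = [
    mlflow.entities.Metric(key, value, timestamp, 0) for key, value in flatten_dict(metrics).items()
]
client.log_batch(run_id=run.info.run_id, metrics=metrics_arr, params=[], tags=[])
```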

monai/handlers/stats_handler.py

Lines changed: 2 additions & 3 deletions

@@ -19,7 +19,7 @@
 import torch
 
 from monai.apps import get_logger
-from monai.utils import IgniteInfo, is_scalar, min_version, optional_import
+from monai.utils import IgniteInfo, flatten_dict, is_scalar, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 if TYPE_CHECKING:
@@ -211,8 +211,7 @@ def _default_epoch_print(self, engine: Engine) -> None:
 
        """
        current_epoch = self.global_epoch_transform(engine.state.epoch)
-
-        prints_dict = engine.state.metrics
+        prints_dict = flatten_dict(engine.state.metrics)
        if prints_dict is not None and len(prints_dict) > 0:
            out_str = f"Epoch[{current_epoch}] Metrics -- "
            for name in sorted(prints_dict):
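A rough illustration of what `_default_epoch_print` now iterates over (the values and exact output format are illustrative; the real handler formats each entry through its own template): after `flatten_dict`, each per-label value appears under its own name in the epoch summary line.

```python
from monai.utils import flatten_dict

metrics = {"acc": 0.3, "acc_per_label": {"label_0": 0.2, "label_1": 0.4}}
prints_dict = flatten_dict(metrics)

out_str = "Epoch[2] Metrics -- "
for name in sorted(prints_dict):
    out_str += f"{name}: {prints_dict[name]} "
print(out_str)
# Epoch[2] Metrics -- acc: 0.3 label_0: 0.2 label_1: 0.4
```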

monai/networks/utils.py

Lines changed: 6 additions & 7 deletions

@@ -16,6 +16,7 @@
 
 import io
 import re
+import tempfile
 import warnings
 from collections import OrderedDict
 from collections.abc import Callable, Mapping, Sequence
@@ -688,27 +689,25 @@ def convert_to_onnx(
         onnx_inputs = (inputs,)
     else:
         onnx_inputs = tuple(inputs)
-
+    temp_file = None
     if filename is None:
-        f = io.BytesIO()
+        temp_file = tempfile.NamedTemporaryFile()
+        f = temp_file.name
     else:
         f = filename
 
     torch.onnx.export(
         mode_to_export,
         onnx_inputs,
-        f=f,  # type: ignore[arg-type]
+        f=f,
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
         opset_version=opset_version,
         do_constant_folding=do_constant_folding,
         **torch_versioned_kwargs,
     )
-    if filename is None:
-        onnx_model = onnx.load_model_from_string(f.getvalue())
-    else:
-        onnx_model = onnx.load(filename)
+    onnx_model = onnx.load(f)
 
     if do_constant_folding and polygraphy_imported:
         from polygraphy.backend.onnx.loader import fold_constants
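A minimal sketch of the new temp-file export path (the `torch.nn.Linear` model and input are placeholders, not the network passed to `convert_to_onnx`; assumes `torch` and `onnx` are installed): because `f` is now always a filesystem path, a single `onnx.load(f)` call covers both the temporary-file case and the user-supplied `filename` case.

```python
import tempfile

import onnx
import torch

model = torch.nn.Linear(4, 2)    # placeholder model
dummy_input = torch.randn(1, 4)  # placeholder input

temp_file = tempfile.NamedTemporaryFile()  # keep the reference so the file is not deleted yet
f = temp_file.name
torch.onnx.export(model, (dummy_input,), f=f)
onnx_model = onnx.load(f)  # the same call works when an explicit filename is passed instead
print(len(onnx_model.graph.node))
```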

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@
     ensure_tuple_size,
     fall_back_tuple,
     first,
+    flatten_dict,
     get_seed,
     has_option,
     is_immutable,

monai/utils/misc.py

Lines changed: 13 additions & 0 deletions

@@ -916,3 +916,16 @@ def unsqueeze_right(arr: NT, ndim: int) -> NT:
 def unsqueeze_left(arr: NT, ndim: int) -> NT:
     """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
     return arr[(None,) * (ndim - arr.ndim)]
+
+
+def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]:
+    """
+    Flatten the nested dictionary to a flat dictionary.
+    """
+    result = {}
+    for key, value in metrics.items():
+        if isinstance(value, dict):
+            result.update(flatten_dict(value))
+        else:
+            result[key] = value
+    return result
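A quick usage example of the new helper (values are illustrative). Nested values are merged into the top level without any key prefixing, so leaf keys are expected to be unique across nesting levels; a duplicate key would be overwritten by whichever entry is processed last.

```python
from monai.utils import flatten_dict

metrics = {"acc": 0.9, "acc_per_label": {"label_0": 0.8, "label_1": 0.95}}
print(flatten_dict(metrics))
# {'acc': 0.9, 'label_0': 0.8, 'label_1': 0.95}
```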

tests/test_handler_mlflow.py

Lines changed: 8 additions & 1 deletion

@@ -122,6 +122,11 @@ def _train_func(engine, batch):
         def _update_metric(engine):
             current_metric = engine.state.metrics.get("acc", 0.1)
             engine.state.metrics["acc"] = current_metric + 0.1
+            # log nested metrics
+            engine.state.metrics["acc_per_label"] = {
+                "label_0": current_metric + 0.1,
+                "label_1": current_metric + 0.2,
+            }
             engine.state.test = current_metric
 
         # set up testing handler
@@ -138,10 +143,12 @@ def _update_metric(engine):
             state_attributes=["test"],
             experiment_param=experiment_param,
             artifacts=[artifact_path],
-            close_on_complete=True,
+            close_on_complete=False,
         )
         handler.attach(engine)
         engine.run(range(3), max_epochs=2)
+        cur_run = handler.client.get_run(handler.cur_run.info.run_id)
+        self.assertTrue("label_0" in cur_run.data.metrics.keys())
         handler.close()
         # check logging output
         self.assertTrue(len(glob.glob(test_path)) > 0)
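In isolation, what the new assertion checks (this snippet reuses `handler` from the test above, so it is not standalone): with `close_on_complete=False` the run is presumably still reachable through `handler.cur_run` after `engine.run`, so it can be looked up via the MLflow client and the flattened nested metric keys inspected before `handler.close()`.

```python
# reuses handler from the test above; run.data.metrics is a plain dict of metric name -> latest value
cur_run = handler.client.get_run(handler.cur_run.info.run_id)
assert "label_0" in cur_run.data.metrics
assert "label_1" in cur_run.data.metrics  # assumption: logged the same way as label_0
handler.close()
```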
