From d5908da9c32194d8362bb5170fca8d0f24f7fe3e Mon Sep 17 00:00:00 2001 From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Date: Tue, 1 Oct 2024 11:21:14 -0700 Subject: [PATCH] [docs] Fix fp16 mixed precision example to reflect correct input variable names (#22250) ### Description Correct variable name from `test_data` to `feed_dict` to fix example code in mixed precision example docs. ### Motivation and Context Fixes #21822 --- docs/performance/model-optimizations/float16.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/performance/model-optimizations/float16.md b/docs/performance/model-optimizations/float16.md index 972f5fe516f6b..a0335ccbac70f 100644 --- a/docs/performance/model-optimizations/float16.md +++ b/docs/performance/model-optimizations/float16.md @@ -62,7 +62,9 @@ from onnxconverter_common import auto_mixed_precision import onnx model = onnx.load("path/to/model.onnx") -model_fp16 = auto_convert_mixed_precision(model, test_data, rtol=0.01, atol=0.001, keep_io_types=True) +# Assuming x is the input to the model +feed_dict = {'input': x.numpy()} +model_fp16 = auto_convert_mixed_precision(model, feed_dict, rtol=0.01, atol=0.001, keep_io_types=True) onnx.save(model_fp16, "path/to/model_fp16.onnx") ``` @@ -73,6 +75,7 @@ auto_convert_mixed_precision(model, feed_dict, validate_fn=None, rtol=None, atol ``` - `model`: The ONNX model to convert. +- `feed_dict`: Test data used to measure the accuracy of the model during conversion. Format is similar to InferenceSession.run (map of input names to values) - `validate_fn`: A function accepting two lists of numpy arrays (the outputs of the float32 model and the mixed-precision model, respectively) that returns `True` if the results are sufficiently close and `False` otherwise. Can be used instead of or in addition to `rtol` and `atol`. - `rtol`, `atol`: Absolute and relative tolerances used for validation. See [numpy.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html) for more information. - `keep_io_types`: Whether model inputs/outputs should be left as float32.