diff --git a/docs/performance/model-optimizations/float16.md b/docs/performance/model-optimizations/float16.md index 972f5fe516f6b..a0335ccbac70f 100644 --- a/docs/performance/model-optimizations/float16.md +++ b/docs/performance/model-optimizations/float16.md @@ -62,7 +62,9 @@ from onnxconverter_common import auto_mixed_precision import onnx model = onnx.load("path/to/model.onnx") -model_fp16 = auto_convert_mixed_precision(model, test_data, rtol=0.01, atol=0.001, keep_io_types=True) +# Assuming x is the input to the model +feed_dict = {'input': x.numpy()} +model_fp16 = auto_convert_mixed_precision(model, feed_dict, rtol=0.01, atol=0.001, keep_io_types=True) onnx.save(model_fp16, "path/to/model_fp16.onnx") ``` @@ -73,6 +75,7 @@ auto_convert_mixed_precision(model, feed_dict, validate_fn=None, rtol=None, atol ``` - `model`: The ONNX model to convert. +- `feed_dict`: Test data used to measure the accuracy of the model during conversion. The format is similar to the input feed of `InferenceSession.run` (a map of input names to values). - `validate_fn`: A function accepting two lists of numpy arrays (the outputs of the float32 model and the mixed-precision model, respectively) that returns `True` if the results are sufficiently close and `False` otherwise. Can be used instead of or in addition to `rtol` and `atol`. - `rtol`, `atol`: Absolute and relative tolerances used for validation. See [numpy.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html) for more information. - `keep_io_types`: Whether model inputs/outputs should be left as float32.