|  | 
|  | 1 | +import pathlib | 
|  | 2 | +import tempfile | 
|  | 3 | + | 
|  | 4 | +from keras.src import backend | 
|  | 5 | +from keras.src import tree | 
|  | 6 | +from keras.src.export.export_utils import convert_spec_to_tensor | 
|  | 7 | +from keras.src.export.export_utils import get_input_signature | 
|  | 8 | +from keras.src.export.saved_model import export_saved_model | 
|  | 9 | +from keras.src.utils.module_utils import tensorflow as tf | 
|  | 10 | + | 
|  | 11 | + | 
def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs):
    """Export the model as a ONNX artifact for inference.

    This method lets you export a model to a lightweight ONNX artifact
    that contains the model's forward pass only (its `call()` method)
    and can be served via e.g. ONNX Runtime.

    The original code of the model (including any custom layers you may
    have used) is *no longer* necessary to reload the artifact -- it is
    entirely standalone.

    Args:
        model: The Keras model to export.
        filepath: `str` or `pathlib.Path` object. The path to save the artifact.
        verbose: `bool`. Whether to print a message during export. Defaults to
            `True`.
        input_signature: Optional. Specifies the shape and dtype of the model
            inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
            `backend.KerasTensor`, or backend tensor. If not provided, it will
            be automatically computed. Defaults to `None`.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If no `input_signature` is provided and none can be
            inferred because the model was never called, or (torch backend)
            if the inputs contain dictionaries.
        NotImplementedError: If the current backend is not TensorFlow, JAX
            or Torch.

    **Note:** This feature is currently supported only with TensorFlow, JAX and
    Torch backends.

    **Note:** The dtype policy must be "float32" for the model. You can further
    optimize the ONNX artifact using the ONNX toolkit. Learn more here:
    [https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/).

    **Note:** The dynamic shape feature is not yet supported with Torch
    backend. As a result, you must fully define the shapes of the inputs using
    `input_signature`. If `input_signature` is not provided, all instances of
    `None` (such as the batch size) will be replaced with `1`.

    Example:

    ```python
    # Export the model as a ONNX artifact
    model.export("path/to/location", format="onnx")

    # Load the artifact in a different process/environment
    ort_session = onnxruntime.InferenceSession("path/to/location")
    ort_inputs = {
        k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
    }
    predictions = ort_session.run(None, ort_inputs)
    ```
    """
    if input_signature is None:
        input_signature = get_input_signature(model)
        # A signature can only be inferred from a model that has been
        # called (built) at least once.
        if not input_signature or not model._called:
            raise ValueError(
                "The model provided has never been called. "
                "It must be called at least once before export."
            )

    if backend.backend() in ("tensorflow", "jax"):
        # Export through an intermediate TF SavedModel, then convert it to
        # ONNX. The temporary directory is created next to `filepath` so
        # both live on the same filesystem.
        working_dir = pathlib.Path(filepath).parent
        with tempfile.TemporaryDirectory(dir=working_dir) as temp_dir:
            if backend.backend() == "jax":
                kwargs = _check_jax_kwargs(kwargs)
            export_saved_model(
                model,
                temp_dir,
                verbose,
                input_signature,
                **kwargs,
            )
            saved_model_to_onnx(temp_dir, filepath, model.name)

    elif backend.backend() == "torch":
        import torch

        # Materialize concrete sample tensors from the signature; dynamic
        # dimensions (`None`, e.g. the batch size) are pinned to 1 because
        # the torch exporter requires fully static shapes.
        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        sample_inputs = tuple(sample_inputs)
        # TODO: Make dict model exportable.
        if any(isinstance(x, dict) for x in sample_inputs):
            raise ValueError(
                "Currently, `export_onnx` in the torch backend doesn't support "
                "dictionaries as inputs."
            )

        # Convert to ONNX using TorchScript-based ONNX Exporter.
        # TODO: Use TorchDynamo-based ONNX Exporter once
        # `torch.onnx.dynamo_export()` supports Keras models.
        torch.onnx.export(model, sample_inputs, filepath, verbose=verbose)
    else:
        raise NotImplementedError(
            "`export_onnx` is only compatible with TensorFlow, JAX and "
            "Torch backends."
        )
|  | 105 | + | 
|  | 106 | + | 
|  | 107 | +def _check_jax_kwargs(kwargs): | 
|  | 108 | +    kwargs = kwargs.copy() | 
|  | 109 | +    if "is_static" not in kwargs: | 
|  | 110 | +        kwargs["is_static"] = True | 
|  | 111 | +    if "jax2tf_kwargs" not in kwargs: | 
|  | 112 | +        # TODO: These options will be deprecated in JAX. We need to | 
|  | 113 | +        # find another way to export ONNX. | 
|  | 114 | +        kwargs["jax2tf_kwargs"] = { | 
|  | 115 | +            "enable_xla": False, | 
|  | 116 | +            "native_serialization": False, | 
|  | 117 | +        } | 
|  | 118 | +    if kwargs["is_static"] is not True: | 
|  | 119 | +        raise ValueError( | 
|  | 120 | +            "`is_static` must be `True` in `kwargs` when using the jax " | 
|  | 121 | +            "backend." | 
|  | 122 | +        ) | 
|  | 123 | +    if kwargs["jax2tf_kwargs"]["enable_xla"] is not False: | 
|  | 124 | +        raise ValueError( | 
|  | 125 | +            "`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` " | 
|  | 126 | +            "when using the jax backend." | 
|  | 127 | +        ) | 
|  | 128 | +    if kwargs["jax2tf_kwargs"]["native_serialization"] is not False: | 
|  | 129 | +        raise ValueError( | 
|  | 130 | +            "`native_serialization` must be `False` in " | 
|  | 131 | +            "`kwargs['jax2tf_kwargs']` when using the jax backend." | 
|  | 132 | +        ) | 
|  | 133 | +    return kwargs | 
|  | 134 | + | 
|  | 135 | + | 
def saved_model_to_onnx(saved_model_dir, filepath, name):
    """Convert a TF SavedModel directory to an ONNX file using `tf2onnx`.

    `saved_model_dir` is read as a frozen TF graph, converted, and the
    resulting ONNX model is written to `filepath` under the given `name`.
    """
    from keras.src.utils.module_utils import tf2onnx

    # Load the graph plus the metadata needed by the converter.
    loaded = tf2onnx.tf_loader.from_saved_model(
        saved_model_dir,
        None,
        None,
        return_initialized_tables=True,
        return_tensors_to_rename=True,
    )
    graph_def, input_names, output_names, tables, rename_map = loaded

    # Run the conversion pinned to the CPU device.
    with tf.device("/cpu:0"):
        _ = tf2onnx.convert._convert_common(
            graph_def,
            name=name,
            target=[],
            custom_op_handlers={},
            extra_opset=[],
            input_names=input_names,
            output_names=output_names,
            tensors_to_rename=rename_map,
            initialized_tables=tables,
            output_path=filepath,
        )
0 commit comments