From 0d26c624cb75d38506928bf575b8b35113a609e4 Mon Sep 17 00:00:00 2001 From: tanayvarshney Date: Wed, 1 Feb 2023 14:33:34 -0800 Subject: [PATCH] added a jax example --- Quick_Deploy/JAX/README.md | 89 +++++++++++++++++++ Quick_Deploy/JAX/client.py | 63 +++++++++++++ .../JAX/model_repository/resnet50/1/model.py | 58 ++++++++++++ .../model_repository/resnet50/config.pbtxt | 50 +++++++++++ Quick_Deploy/ONNX/README.md | 2 +- 5 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 Quick_Deploy/JAX/README.md create mode 100644 Quick_Deploy/JAX/client.py create mode 100644 Quick_Deploy/JAX/model_repository/resnet50/1/model.py create mode 100644 Quick_Deploy/JAX/model_repository/resnet50/config.pbtxt diff --git a/Quick_Deploy/JAX/README.md b/Quick_Deploy/JAX/README.md new file mode 100644 index 00000000..3c6be7b6 --- /dev/null +++ b/Quick_Deploy/JAX/README.md @@ -0,0 +1,89 @@ + + +# Deploying a JAX Model + +This README showcases how to deploy a simple ResNet model on Triton Inference Server. While Triton doesn't yet have a dedicated JAX backend, JAX/Flax models can be deployed using [Python Backend](https://github.com/triton-inference-server/python_backend). If you are new to Triton, it is recommended to watch this [getting started video](https://www.youtube.com/watch?v=NQDtfSi5QF4) and review [Part 1](https://github.com/triton-inference-server/tutorials/tree/main/Conceptual_Guide/Part_1-model_deployment) of the conceptual guide before proceeding. For the purposes of demonstration, we are using a pre-trained model provided by [flaxmodels](https://github.com/matthias-wright/flaxmodels). + +## Step 1: Set Up Triton Inference Server + +To use Triton, we need to build a model repository. The structure of the repository as follows: +``` +model_repository +| ++-- resnet50 + | + +-- config.pbtxt + +-- 1 + | + +-- model.py +``` +For this example, we have pre-built the model repository. Next, we install the required dependencies and launch the Triton Inference Server. + +``` +# Replace the yy.mm in the image name with the release year and month +# of the Triton version needed, eg. 22.12 +docker run --gpus=all -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/workspace/ -v/$(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:-py3 bash + +pip install --upgrade pip +pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git + +``` + +## Step 2: Using a Triton Client to Query the Server + +Let's breakdown the client application. First, we setup a connection with the Triton Inference Server. +``` +client = httpclient.InferenceServerClient(url="localhost:8000") +``` +Then we set the input and output arrays. +``` +# Set Inputs +input_tensors = [ + httpclient.InferInput("image", image.shape, datatype="FP32") +] +input_tensors[0].set_data_from_numpy(image) + +# Set outputs +outputs = [ + httpclient.InferRequestedOutput("fc_out") +] +``` +Lastly, we query send a request to the Triton Inference Server. + +``` +# Query +query_response = client.infer(model_name="resnet50", + inputs=input_tensors, + outputs=outputs) + +# Output +out = query_response.as_numpy("fc_out") +``` + diff --git a/Quick_Deploy/JAX/client.py b/Quick_Deploy/JAX/client.py new file mode 100644 index 00000000..f33f9710 --- /dev/null +++ b/Quick_Deploy/JAX/client.py @@ -0,0 +1,63 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np +from tritonclient.utils import * +from PIL import Image +import tritonclient.http as httpclient +import requests + + +def main(): + client = httpclient.InferenceServerClient(url="localhost:8000") + + # Inputs + url = "http://images.cocodataset.org/val2017/000000161642.jpg" + image = np.asarray(Image.open(requests.get(url, stream=True).raw)).astype(np.float32) + image = np.expand_dims(image, axis=0) + + # Set Inputs + input_tensors = [ + httpclient.InferInput("image", image.shape, datatype="FP32") + ] + input_tensors[0].set_data_from_numpy(image) + + # Set outputs + outputs = [ + httpclient.InferRequestedOutput("fc_out") + ] + + # Query + query_response = client.infer(model_name="resnet50", + inputs=input_tensors, + outputs=outputs) + + # Output + out = query_response.as_numpy("fc_out") + print(out.shape) + +if __name__ == "__main__": + main() diff --git a/Quick_Deploy/JAX/model_repository/resnet50/1/model.py b/Quick_Deploy/JAX/model_repository/resnet50/1/model.py new file mode 100644 index 00000000..40dc8bd1 --- /dev/null +++ b/Quick_Deploy/JAX/model_repository/resnet50/1/model.py @@ -0,0 +1,58 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import triton_python_backend_utils as pb_utils +import jax +import jax.numpy as jnp +import flaxmodels as fm + +import numpy as np +from flax.jax_utils import replicate + +class TritonPythonModel: + + def initialize(self, args): + self.key = jax.random.PRNGKey(0) + self.resnet18 = fm.ResNet18(output='logits', pretrained='imagenet') + + + def execute(self, requests): + responses = [] + for request in requests: + inp = pb_utils.get_input_tensor_by_name(request, "image") + input_image = inp.as_numpy() + + params = self.resnet18.init(self.key, input_image) + out = self.resnet18.apply(params, input_image, train=False) + + inference_response = pb_utils.InferenceResponse(output_tensors=[ + pb_utils.Tensor( + "fc_out", + np.array(out), + ) + ]) + responses.append(inference_response) + return responses diff --git a/Quick_Deploy/JAX/model_repository/resnet50/config.pbtxt b/Quick_Deploy/JAX/model_repository/resnet50/config.pbtxt new file mode 100644 index 00000000..809a4bdd --- /dev/null +++ b/Quick_Deploy/JAX/model_repository/resnet50/config.pbtxt @@ -0,0 +1,50 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "resnet50" +backend: "python" +max_batch_size: 8 + +input [ + { + name: "image" + data_type: TYPE_FP32 + dims: [-1, -1, -1] + } +] +output [ + { + name: "fc_out" + data_type: TYPE_FP32 + dims: [-1, -1] + } +] + +instance_group [ + { + kind: KIND_GPU + } +] diff --git a/Quick_Deploy/ONNX/README.md b/Quick_Deploy/ONNX/README.md index a21046c7..8c2711a6 100644 --- a/Quick_Deploy/ONNX/README.md +++ b/Quick_Deploy/ONNX/README.md @@ -53,7 +53,7 @@ wget -O model_repository/densenet_onnx/1/model.onnx \ docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:-py3 tritonserver --model-repository=/models ``` -## Step 3: Using a Triton Client to Query the Server +## Step 2: Using a Triton Client to Query the Server Install dependencies & download an example image to test inference.