
OP_REQUIRES failed at xla_ops : UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND #2214

Open
kmkolasinski opened this issue Apr 7, 2024 · 8 comments

Comments

@kmkolasinski

Bug Report

Does TensorFlow Serving support XLA-compiled SavedModels, or am I doing something wrong?

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Dockerfile
  • TensorFlow Serving installed from (source or binary): No
  • TensorFlow Serving version: 2.13.1-gpu

Describe the problem

Hi, I'm trying to run XLA-compiled models via TensorFlow Serving; however, it does not work for me.

Here is the notebook I used to create XLA/AMP-compiled SavedModels of very simple classifiers like ResNet50:
https://github.com/kmkolasinski/triton-saved-model/blob/main/notebooks/export-classifier.ipynb
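
For reference, a rough, minimal sketch of how such an XLA/AMP-compiled SavedModel with a predict_images signature can be produced (this is not the exact notebook code; the model, output key, and export path are illustrative, and only the signature name matches the client below):

import tensorflow as tf

# AMP: run the model with a mixed float16/float32 precision policy
tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.applications.ResNet50(weights=None)

# jit_compile=True requests XLA compilation of the serving function
@tf.function(jit_compile=True)
def predict_images(images):
    images = tf.cast(images, tf.float32)
    images = tf.keras.applications.resnet50.preprocess_input(images)
    probs = model(images, training=False)
    return {"probs": tf.cast(probs, tf.float32)}

tf.saved_model.save(
    model,
    "models-repository/resnet50-xla-amp/1",
    signatures={
        "predict_images": predict_images.get_concrete_function(
            tf.TensorSpec([None, 224, 224, 3], tf.uint8, name="images")
        )
    },
)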

When running the TF Serving server, I can see the following warning in the console:

tf_serving_server-1  | 2024-04-07 07:57:23.842168: W external/org_tensorflow/tensorflow/core/framework/op_kernel.cc:1828] OP_REQUIRES failed at xla_ops.cc:503 : UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?

I get a similar message on the client side:

_InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
	status = StatusCode.UNIMPLEMENTED
	details = "2 root error(s) found.
  (0) UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?
	 [[{{function_node __inference_predict_images_90382}}{{node StatefulPartitionedCall}}]]
	 [[cluster_6_1/merge_oidx_0/_655]]
  (1) UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?
	 [[{{function_node __inference_predict_images_90382}}{{node StatefulPartitionedCall}}]]
0 successful operations.
0 derived errors ignored."
	debug_error_string = "UNKNOWN:Error received from peer ipv6:%5B::1%5D:8500 {grpc_message:"2 root error(s) found.\n  (0) UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?\n\t [[{{function_node __inference_predict_images_90382}}{{node StatefulPartitionedCall}}]]\n\t [[cluster_6_1/merge_oidx_0/_655]]\n  (1) UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?\n\t [[{{function_node __inference_predict_images_90382}}{{node StatefulPartitionedCall}}]]\n0 successful operations.\n0 derived errors ignored.", grpc_status:12, created_time:"2024-04-07T09:57:23.84265796+02:00"}"
>

Exact Steps to Reproduce

You can find my repo, where I compare Triton Server (Python backend) with TF Serving, here: https://github.com/kmkolasinski/triton-saved-model. In the notebooks directory you will find:

  • a notebook which creates the SavedModels (with and without XLA)
  • another notebook which runs predictions; the gRPC client it uses is shown below
from dataclasses import dataclass

import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc


@dataclass(frozen=True)
class TFServingGRPCClient:
    model_name: str
    url: str = "localhost:8500"

    def predict(
        self, signature: str, *, images: np.ndarray, timeout: int = 60
    ) -> predict_pb2.PredictResponse:

        options = [("grpc.max_receive_message_length", 1 << 30)]
        channel = grpc.insecure_channel(self.url, options=options)
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

        grpc_request = predict_pb2.PredictRequest()
        grpc_request.model_spec.name = self.model_name
        grpc_request.model_spec.signature_name = signature

        grpc_request.inputs["images"].CopyFrom(tf.make_tensor_proto(images))

        predict_response = stub.Predict(grpc_request, timeout)
        channel.close()
        return predict_response

images = tf.random.uniform((100, 224, 224, 3), minval=0, maxval=255, dtype=tf.int32)
images = tf.cast(images, tf.uint8)

client = TFServingGRPCClient("resnet50-no-opt")
result = client.predict("predict_images", images=images)
# works correctly

client = TFServingGRPCClient("resnet50-xla-amp")
result = client.predict("predict_images", images=images)
# fails with UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND:

Is this expected behavior? I am aware of this flag:

 --xla_cpu_compilation_enabled=false     bool    EXPERIMENTAL; CAN BE REMOVED ANYTIME! Enable XLA:CPU JIT (default is disabled). With XLA:CPU JIT disabled, models utilizing this feature will return bad Status on first compilation request.

however, I was not able to find any useful resources on how to use it to test my case.

@singhniraj08 singhniraj08 self-assigned this Apr 22, 2024
@singhniraj08

@kmkolasinski, the --xla_cpu_compilation_enabled=true parameter should be passed as an additional argument to enable XLA:CPU JIT (it is disabled by default).
Can you try creating a TF Serving Docker container with this additional parameter, as shown in the example, and see if model inference works? Thank you!
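
For example (a rough sketch with placeholder paths and model name; if I remember correctly, the stock tensorflow/serving image forwards any extra arguments after the image name to tensorflow_model_server):

docker run --gpus all -p 8500:8500 -p 8501:8501 \
  -v /path/to/models/resnet50-xla-amp:/models/resnet50-xla-amp \
  -e MODEL_NAME=resnet50-xla-amp \
  tensorflow/serving:2.13.1-gpu \
  --xla_cpu_compilation_enabled=true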

@kmkolasinski
Author

Hi @singhniraj08, yes, I tried this approach; please check out my Dockerfile (see the CMD command) which I used for testing: https://github.com/kmkolasinski/triton-saved-model/blob/main/tf_serving/Dockerfile

Here is the docker-compose file which I used to run serving: https://github.com/kmkolasinski/triton-saved-model/blob/main/docker-compose.yml

# Firstly, use https://github.com/kmkolasinski/triton-saved-model/blob/main/notebooks%2Fexport-classifier.ipynb to export various classifiers
docker compose up tf_serving_server

I prepared this repository, which reproduces the issue: https://github.com/kmkolasinski/triton-saved-model/tree/main

@kmkolasinski
Author

Hi @YanghuaHuang, did you have time to take a look at this issue?

@YanghuaHuang
Member

Sorry for the late reply. Assigning to @guanxinq to triage, who has better knowledge of this.

@kmkolasinski
Author

Hi @YanghuaHuang, thanks. I just wonder whether we can use XLA-compiled models in TF Serving or not. If so, how can we achieve it? I couldn't find any information about this.

@YanghuaHuang
Member

I think TF Serving does support XLA on CPU but not on GPU. But I could be wrong.

@gharibian Hey Dero, can you help on this? Thanks!

@kmkolasinski
Author

Thanks for the answer. If this is true, the message

Could not find compiler for platform CUDA: NOT_FOUND

makes perfect sense to me now, and that's a pity. I assumed that TF Serving uses the same C++ backend to run SavedModel graphs as the TF Python libraries, so any SavedModel I can run via Python code I can also run via TF Serving. Let's wait for confirmation from @gharibian.

@kmkolasinski
Author

Hey @gharibian, did you have time to take a look at this thread?
