
Auto-batching proxy for inference requests

Description

We've implemented an auto-batching proxy service that serves as a wrapper over another inference service (see the Makefile for details on how we launch that service for development purposes). Internally, the batching proxy will 🥁 batch individual embedding requests, while to the end-user the API looks exactly as if they were the only client of the inference service. This batching makes requests to the upstream service more efficient (and helps reduce costs).

In our hometown in the 1990s-2000s, there used to be drivers hanging around the railway station who - if you missed your train or just did not bother to buy a ticket - would offer you a ride to another town, but a shared one. They would gather (batch) a few fellows like ourselves and then start the trip. But there were rules: they could not take more than N people (depending on the vehicle size), and the person who came first would not wait for too long (normally no longer than an hour).

In a similar fashion, our batching service has MAX_BATCH_SIZE and MAX_WAIT_TIME (in milliseconds) parameters configurable via the environment (see the other configurable options in .env.example).
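For illustration, the two knobs could be read at start-up along these lines - a minimal sketch; the BatchConfig struct, field names, and defaults are ours for the example, not the project's actual code:

use std::time::Duration;

/// Illustrative holder for the two batching knobs.
struct BatchConfig {
    max_batch_size: usize,
    max_wait_time: Duration,
}

impl BatchConfig {
    fn from_env() -> Self {
        // Fall back to made-up defaults if a variable is absent or invalid.
        let max_batch_size = std::env::var("MAX_BATCH_SIZE")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(8);
        let max_wait_millis = std::env::var("MAX_WAIT_TIME")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(1000);
        Self {
            max_batch_size,
            max_wait_time: Duration::from_millis(max_wait_millis),
        }
    }
}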

Solution

Stack

Our REST API wrapper is powered by the axum web framework, which is our framework of choice. We have not added OpenAPI definitions to this project, but if we need them, we will integrate the utoipa crate. All other crates we depend on are pretty standard.

Interface

The API currently provides one single endpoint, POST /embed, so trying it out is super straightforward:

curl 127.0.0.1:8081/embed -X POST -d '{"inputs":["What is Vector Search?", "Hello, world!"]}' -H 'Content-Type: application/json'

Notice how it looks exactly the same (but for the port number) as if you were querying the upstream service directly:

curl 127.0.0.1:8080/embed -X POST -d '{"inputs":["What is Vector Search?", "Hello, world!"]}' -H 'Content-Type: application/json'
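In both cases the response is a JSON array of embedding vectors - one vector per input string, in the same order as the inputs (the values and truncation below are illustrative):

[[-0.045, ..., -0.123], [0.412, ..., -0.412]]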

The endpoint is currently unprotected, and we may want to revisit this and add API key authentication and throttling.

Handler

On the app's start-up, we launch a task with the web server and a dedicated task for our inference service worker - effectively an actor responsible for inference, hiding the implementation details from the axum handler to separate concerns. Instances of the handler (tasks processing end-users' requests) communicate with that worker using channels and messages. Once a handler receives a request, it forwards it to the worker and awaits the worker's response (embeddings or an error) via a oneshot channel; once it gets the response, it sends its JSON representation to the end-user.
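A minimal sketch of that hand-off, assuming a tokio mpsc channel to the worker plus a oneshot channel per request - the EmbedRequest, EmbedJob, and AppState names are ours for illustration, not the project's actual types:

use axum::{extract::State, http::StatusCode, Json};
use tokio::sync::{mpsc, oneshot};

#[derive(serde::Deserialize)]
struct EmbedRequest {
    inputs: Vec<String>,
}

/// One end-user request on its way to the worker, carrying the oneshot
/// sender the worker will use to answer this particular request.
struct EmbedJob {
    inputs: Vec<String>,
    respond_to: oneshot::Sender<Result<Vec<Vec<f32>>, String>>,
}

#[derive(Clone)]
struct AppState {
    worker_tx: mpsc::Sender<EmbedJob>,
}

async fn embed(
    State(state): State<AppState>,
    Json(req): Json<EmbedRequest>,
) -> Result<Json<Vec<Vec<f32>>>, (StatusCode, String)> {
    let (tx, rx) = oneshot::channel();
    // Forward the request to the worker...
    state
        .worker_tx
        .send(EmbedJob { inputs: req.inputs, respond_to: tx })
        .await
        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "worker is gone".to_string()))?;
    // ...and await our slice of the batched response (or an upstream error).
    let embeddings = rx
        .await
        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "worker dropped the reply".to_string()))?
        .map_err(|e| (StatusCode::BAD_GATEWAY, e))?;
    Ok(Json(embeddings))
}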

Worker

The worker just listens for messages from the axum handlers. The worker keeps some state: it has a message queue with a capacity set by MAX_BATCH_SIZE and a timeout set by MAX_WAIT_TIME - whichever limit is hit first makes the worker send the batch to the upstream service. If an error is received from the upstream inference service, it gets "broadcast" to the handlers. If the embeddings are received, the worker makes sure not to send the entirety of them to each handler, but only the segment that corresponds to the handler's inputs. We rely here on the fact that the upstream service returns an array of embeddings in which the embedding at index N is the result for the query at index N in the inputs container of our request.

To give a concrete example, imagine the batch size is set to 2, the first request contains the inputs array ["hello", "world"], and the second request has only one item, ["bye"] - the worker will flatten these two into one array and send it to the upstream service as ["hello", "world", "bye"]. The response our worker gets will have the following shape: [[-0.045, ... , -0.123144], [0.412, ..., -0.412], [0.1241, ..., 0.123]]. The worker still "remembers" at that point that it needs to send 2 embeddings to the first handler and 1 embedding to the second handler instance.

Also - replaying the example above - if the batch size is set to 2 and the worker has received a message from only one handler (["hello", "world"]) when the time-out (configured via MAX_WAIT_TIME) is reached, the worker will send just ["hello", "world"] to the upstream server.

Each time the batch gets "flushed", the timeout gets unset and the queue gets emptied.
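Putting that together, the worker's loop might look roughly like the sketch below - it reuses the EmbedJob type from the handler example, counts queued messages against MAX_BATCH_SIZE as described above, and call_upstream is a hypothetical stand-in for the real HTTP call to the upstream /embed endpoint:

use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::{sleep_until, Instant};

async fn run_worker(
    mut rx: mpsc::Receiver<EmbedJob>,
    max_batch_size: usize,
    max_wait_time: Duration,
) {
    let mut batch: Vec<EmbedJob> = Vec::new();
    let mut deadline: Option<Instant> = None;

    loop {
        tokio::select! {
            maybe_job = rx.recv() => {
                let Some(job) = maybe_job else { break };
                if deadline.is_none() {
                    // The first message in a fresh batch starts the clock.
                    deadline = Some(Instant::now() + max_wait_time);
                }
                batch.push(job);
                if batch.len() >= max_batch_size {
                    flush(&mut batch).await;
                    deadline = None;
                }
            }
            // Disabled until the first message arrives; the async block
            // defers the unwrap so it only runs when the branch is enabled.
            _ = async { sleep_until(deadline.unwrap()).await }, if deadline.is_some() => {
                flush(&mut batch).await;
                deadline = None;
            }
        }
    }
}

async fn flush(batch: &mut Vec<EmbedJob>) {
    if batch.is_empty() {
        return;
    }
    // Flatten all inputs into a single upstream request.
    let inputs: Vec<String> = batch.iter().flat_map(|j| j.inputs.clone()).collect();
    match call_upstream(&inputs).await {
        Ok(embeddings) => {
            // Hand each handler back only the slice matching its own inputs,
            // relying on the upstream preserving input order.
            let mut offset = 0;
            for job in batch.drain(..) {
                let n = job.inputs.len();
                let _ = job.respond_to.send(Ok(embeddings[offset..offset + n].to_vec()));
                offset += n;
            }
        }
        Err(e) => {
            // "Broadcast" the error: every waiting handler gets a copy.
            for job in batch.drain(..) {
                let _ = job.respond_to.send(Err(e.clone()));
            }
        }
    }
}

// Hypothetical stand-in; the real worker would POST the flattened inputs
// to the upstream service (e.g. with reqwest) and parse the JSON response.
async fn call_upstream(inputs: &[String]) -> Result<Vec<Vec<f32>>, String> {
    Ok(inputs.iter().map(|_| vec![0.0_f32; 3]).collect())
}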

Demo

NB: make sure you have GNU Make and Docker installed.

Populate your very own local .env file with:

make dotenv

You can now launch the auto-batching proxy together with the inference service with a single command:

docker compose up --build

The command above will build our proxy app, launch the upstream inference service first, make sure it is ready, and then launch the proxy app. The initial image build takes some time, plus the model needs some warm-up, so the "cold" start can take up to a few minutes.

If you tweak the MAX_WAIT_TIME and MAX_BATCH_SIZE parameters in your .env file, make sure to restart the containers, e.g. as shown below.
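One way to do that (note that docker compose restart alone does not pick up configuration changes, so recreate the containers instead):

docker compose up --build --force-recreate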

Benchmarking

We set MAX_WAIT_TIME to 1000 (1 second), MAX_BATCH_SIZE to 8 (the upstream service's text embeddings router batch cap), and RUST_LOG to "auto_batching_proxy=error,axum=error".

We then launched the services as described above and used the oha utility to generate some load.

With proxy

The command used (see load target in Makefile):

oha -c 200 -z 30s --latency-correction -m POST -d '{"inputs":["What is Vector Search?", "Hello, world!"]}' -H 'Content-Type: application/json' http://localhost:8081/embed

Which gave the following results:

  Success rate: 100.00%
  Total:        30.0039 sec
  Slowest:      2.0037 sec
  Fastest:      0.2129 sec
  Average:      1.6135 sec
  Requests/sec: 126.6503

  Total data:   62.68 MiB
  Size/request: 17.83 KiB
  Size/sec:     2.09 MiB

Without proxy

We used the same utility on the same hardware with the same max batch size and max wait time, but specified the upstream service's port in the command for direct communication. The command used (note the port number, see how we map this host port in our compose file, and also take a look at the load/noproxy target in the Makefile):

oha -c 200 -z 30s --latency-correction -m POST -d '{"inputs":["What is Vector Search?", "Hello, world!"]}' -H 'Content-Type: application/json' http://localhost:8080/embed

Which gave the following results:

  Success rate: 100.00%
  Total:        30.0047 sec
  Slowest:      2.1063 sec
  Fastest:      0.0452 sec
  Average:      1.6371 sec
  Requests/sec: 124.8803

  Total data:   64.19 MiB
  Size/request: 18.53 KiB
  Size/sec:     2.14 MiB

Observations

The reports above are examples from a single test run. In general - over a few load-test runs - we observe pretty close requests-per-second numbers. The slowest requests are also pretty close to each other, while the fastest request without the proxy is about 2.5x faster (~30-100ms vs ~100-200ms), i.e. our wrapper does introduce some overhead. We compensate for this with gains elsewhere - in resource savings on the upstream service side and reduced costs for each individual user.

Subscribing to debug and trace events and writing them to stdout slows our application down (~20% bandwidth reduction), so we ended up testing with the error+ event level.

We also tried loading our auto-batching proxy with MAX_BATCH_SIZE set to 1 (and all other parameters the same), which gave us results close to those without the proxy. Here are the stats from one of the runs:

  Success rate: 100.00%
  Total:        30.0061 sec
  Slowest:      2.0051 sec
  Fastest:      0.0672 sec
  Average:      1.5677 sec
  Requests/sec: 129.8068

  Total data:   65.14 MiB
  Size/request: 18.05 KiB
  Size/sec:     2.17 MiB

Which checks out: with the current implementation, the 8th client in the proxied scenario with 8 messages per batch will wait until the preceding 7 clients get their slices of the upstream inference service's response. We could play around with this and try to improve the implementation to reduce the proxy overhead.

Dev Setup

Make sure you have cargo, GNU Make, and Docker installed, and hit:

make setup

You should now be able to start the back-end in watch mode with:

make watch

You can send requests with:

curl 127.0.0.1:8081/embed -X POST -d '{"inputs":["What is Vector Search?", "Hello, world!"]}' -H 'Content-Type: application/json'

You can also tweak the configuration in the generated .env file (populated via make setup); the dev server will restart automatically (if you are using the make watch command as described above).
