Commit 439fcaf

feat(router): add prometheus metrics scrape endpoint (#71)
1 parent 7b3d460

8 files changed: +239 −33 lines (4 of the 8 changed files appear below)

Cargo.lock

Lines changed: 120 additions & 0 deletions
Generated lockfile; diff not rendered by default.

router/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ axum-tracing-opentelemetry = "0.9.0"
 text-generation-client = { path = "client" }
 clap = { version = "4.1.4", features = ["derive", "env"] }
 futures = "0.3.26"
+metrics = "0.20.1"
+metrics-exporter-prometheus = { version = "0.11.0", features = [] }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
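
The dependency additions above are only part of the feature: the /metrics scrape endpoint itself is wired up in the router's server code, which is among the 8 changed files but not shown in this excerpt. A minimal sketch of what that wiring could look like with metrics-exporter-prometheus 0.11 and axum (the function name and routing below are illustrative, not taken from the commit):

// Illustrative sketch, not commit code: install a global Prometheus
// recorder so the `metrics` macros have a sink, then serve the rendered
// exposition text on a /metrics route.
use axum::{routing::get, Router};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};

fn metrics_router() -> Router {
    let handle: PrometheusHandle = PrometheusBuilder::new()
        .install_recorder()
        .expect("failed to install Prometheus recorder");

    // PrometheusHandle is cheaply cloneable; render() produces the text
    // exposition format that a Prometheus server scrapes.
    Router::new().route("/metrics", get(move || std::future::ready(handle.render())))
}

Once the recorder is installed, every counter, gauge, and histogram recorded through the macros in the diffs below shows up in the scrape output.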

router/src/infer.rs

Lines changed: 43 additions & 11 deletions
@@ -3,7 +3,6 @@ use crate::validation::{Validation, ValidationError};
 use crate::GenerateRequest;
 use crate::{Entry, Queue, Token};
 use nohash_hasher::IntMap;
-use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
@@ -81,6 +80,7 @@ impl Infer {
             .limit_concurrent_requests
             .try_acquire_owned()
             .map_err(|err| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
                 tracing::error!("{err}");
                 err
             })?;
@@ -172,6 +172,7 @@ impl Infer {
             })
         } else {
             let err = InferError::IncompleteGeneration;
+            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
             tracing::error!("{err}");
             Err(err)
         }
@@ -201,7 +202,7 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
         while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-            let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -212,6 +213,7 @@ async fn batching_task(
                 // Get current batch info
                 let batch_size = batch.size;
                 let mut batches = vec![batch];
+                metrics::gauge!("tgi_batch_current_size", batch_size as f64);

                 // If the current batch is too small, we try to add more requests to it
                 if batch_size <= limit_min_batch_size {
@@ -241,10 +243,9 @@ async fn batching_task(
                     });

                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch =
-                        wrap_future(client.prefill(new_batch), &mut new_entries)
-                            .instrument(span)
-                            .await;
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
@@ -268,29 +269,59 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });

-                cached_batch = wrap_future(client.decode(batches), &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
             }
+            metrics::gauge!("tgi_batch_current_size", 0.0);
         }
     }
 }

-/// Wrap a future inside a match statement to handle errors and send the responses to Infer
 #[instrument(skip_all)]
-async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
+async fn prefill(
+    client: &mut ShardedClient,
+    batch: Batch,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
-    match future.await {
+    let start_time = Instant::now();
+
+    match client.prefill(batch).await {
+        Ok((generations, next_batch)) => {
+            send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+            next_batch
+        }
+        // If we have an error, we discard the whole batch
+        Err(err) => {
+            send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+            None
+        }
+    }
+}
+
+#[instrument(skip_all)]
+async fn decode(
+    client: &mut ShardedClient,
+    batches: Vec<Batch>,
+    entries: &mut IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let start_time = Instant::now();
+
+    match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
             None
         }
     }
@@ -303,6 +334,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // Create and enter a span to link this function back to the entry
         let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
         let err = InferError::GenerationError(error.to_string());
+        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
         tracing::error!("{err}");

         // unwrap_or is valid here as we don't care if the receiver is gone.
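
The generic wrap_future helper is replaced by two concrete functions, prefill and decode, so that each path can record its own duration histogram and success/failure counters under a "method" label, starting the clock before the client call is made rather than receiving an already-constructed future. The instrumentation pattern they share can be read as the following sketch (an illustration only; the timed helper and its type parameters are hypothetical, not code from the commit):

// Illustrative helper, not commit code: time a fallible async call and
// record a duration histogram plus a success/failure counter, both
// tagged with the method name.
use std::time::Instant;

async fn timed<T, E>(
    method: &'static str,
    fut: impl std::future::Future<Output = Result<T, E>>,
) -> Result<T, E> {
    let start_time = Instant::now();
    let result = fut.await;
    // metrics 0.20 accepts a Duration here and records it in seconds.
    metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => method);
    match &result {
        Ok(_) => metrics::increment_counter!("tgi_batch_inference_success", "method" => method),
        Err(_) => metrics::increment_counter!("tgi_batch_inference_failure", "method" => method),
    }
    result
}

The commit spells the two bodies out instead, since each function takes the ShardedClient by &mut and calls a different client method.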

router/src/queue.rs

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,7 @@ impl State {
         // Push entry in the queue
         self.entries.push((self.next_id, entry));
         self.next_id += 1;
+        metrics::increment_gauge!("tgi_queue_size", 1.0);
     }

     // Get the next batch
@@ -190,6 +191,8 @@ impl State {
         // Increment batch id
         self.next_batch_id += 1;

+        metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
+        metrics::histogram!("tgi_batch_next_size", batch.size as f64);
         Some((batch_entries, batch, next_batch_span))
     }
 }
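
Note the two different gauge operations: increment_gauge! adjusts the current value by a delta (+1.0 per appended entry), while gauge! sets an absolute value (the exact number of entries left once a batch has been cut). A compact illustration (not commit code):

// Illustrative only: delta vs. absolute gauge updates on the same metric.
fn queue_size_example(remaining: usize) {
    metrics::increment_gauge!("tgi_queue_size", 1.0); // an entry was appended: +1
    metrics::gauge!("tgi_queue_size", remaining as f64); // reset to the exact queue length
}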
