Skip to content

Commit

Permalink
♻️ move metrics into one file (#97)
Browse files Browse the repository at this point in the history
#### Motivation

This PR accomplishes two things:
- Upgrades the `metrics` crate to the latest version, which introduces breaking
API changes
- Duplicates all counter metrics to include `{metric_name}_total` to
align with the prometheus metrics exported by vLLM

#### Modifications

This refactors all usages of the `metrics` crate into a single file, so
that changes to how it's used can be made in one place. This lets us
easily duplicate all the counter metrics.

#### Result

No existing behavior should change, only new `*_total` counters should
be added to the /metrics endpoint.

---------

Signed-off-by: Joe Runde <[email protected]>
  • Loading branch information
joerunde authored May 20, 2024
1 parent fb23def commit 9b4aea8
Show file tree
Hide file tree
Showing 9 changed files with 559 additions and 376 deletions.
748 changes: 448 additions & 300 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"


## Rust builder ################################################################
# Specific debian version so that compatible glibc version is used
FROM rust:1.77.2-bullseye as rust-builder
# Using bookworm for compilation so the rust binaries get linked against libssl.so.3
FROM rust:1.78-bookworm as rust-builder
ARG PROTOC_VERSION

ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
Expand Down
4 changes: 2 additions & 2 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ text-generation-client = { path = "client" }
clap = { version = "^4.5.4", features = ["derive", "env"] }
futures = "^0.3.30"
flume = "^0.11.0"
metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.2", features = [] }
metrics = "0.22.3"
metrics-exporter-prometheus = { version = "0.14.0", features = ["http-listener"] }
moka = { version = "0.12.6", features = ["future"] }
nohash-hasher = "^0.2.0"
num = "^0.4.2"
Expand Down
55 changes: 28 additions & 27 deletions router/src/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ use crate::{
validation::RequestSize,
ErrorResponse, GenerateRequest,
};
use crate::metrics::{increment_counter, increment_labeled_counter, observe_histogram, observe_labeled_histogram, set_gauge};

/// Batcher
#[derive(Clone)]
Expand Down Expand Up @@ -447,9 +448,9 @@ async fn batching_task<B: BatchType>(
batch_size,
);

metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_input_tokens", batch_tokens as f64);
metrics::gauge!(
set_gauge("tgi_batch_current_size", batch_size as f64);
set_gauge("tgi_batch_input_tokens", batch_tokens as f64);
set_gauge(
"tgi_batch_max_remaining_tokens",
batch_max_remaining_tokens.unwrap() as f64
);
Expand Down Expand Up @@ -529,7 +530,7 @@ async fn batching_task<B: BatchType>(
"Extending batch #{} of {} with additional batch #{} of {}",
batch_id, batch_size, new_batch_id, added_batch_size
);
metrics::increment_counter!("tgi_batch_concatenation_count");
increment_counter("tgi_batch_concatenation_count", 1);
}
} else {
combined_batch_id = new_batch_id;
Expand Down Expand Up @@ -560,9 +561,9 @@ async fn batching_task<B: BatchType>(
}
}

metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_input_tokens", 0.0);
metrics::gauge!("tgi_batch_max_remaining_tokens", 0.0);
set_gauge("tgi_batch_current_size", 0.0);
set_gauge("tgi_batch_input_tokens", 0.0);
set_gauge("tgi_batch_max_remaining_tokens", 0.0);
}

info!("Batching loop exiting");
Expand Down Expand Up @@ -625,9 +626,9 @@ impl<'a> TokenProcessor<'a> {
let batch_size = batch.requests.len();
let batch_tokens = batch.total_tokens;
let start_time = Instant::now();
metrics::histogram!("tgi_batch_next_tokens", batch_tokens as f64);
metrics::histogram!(
"tgi_batch_inference_batch_size", batch_size as f64, "method" => "prefill"
observe_histogram("tgi_batch_next_tokens", batch_tokens as f64);
observe_labeled_histogram(
"tgi_batch_inference_batch_size", &[("method", "prefill")], batch_size as f64
);
let (result, prefill_time) = self
._wrap_future(
Expand All @@ -648,8 +649,8 @@ impl<'a> TokenProcessor<'a> {
batches: Vec<CachedBatch>,
queue: &mut Queue<B>,
) -> (Option<CachedBatch>, Duration) {
metrics::histogram!(
"tgi_batch_inference_batch_size", self.entries.len() as f64, "method" => "next_token"
observe_labeled_histogram(
"tgi_batch_inference_batch_size", &[("method", "next_token")], self.entries.len() as f64
);
let start_time = Instant::now();
self._wrap_future(
Expand All @@ -672,7 +673,7 @@ impl<'a> TokenProcessor<'a> {
start_id: Option<u64>,
queue: &mut Queue<B>,
) -> (Option<CachedBatch>, Duration) {
metrics::increment_counter!("tgi_batch_inference_count", "method" => method);
increment_labeled_counter("tgi_batch_inference_count", &[("method", method)], 1);

// We process the shared queue while waiting for the response from the python shard(s)
let queue_servicer = queue.service_queue().fuse();
Expand All @@ -692,27 +693,27 @@ impl<'a> TokenProcessor<'a> {
let completed_request_ids = self.process_next_tokens(generated_tokens, errors);
// Update health
self.generation_health.store(true, Ordering::SeqCst);
metrics::histogram!(
observe_labeled_histogram(
"tgi_batch_inference_duration",
&[("method", method),
("makeup", "single_only")],
elapsed.as_secs_f64(),
"method" => method,
"makeup" => "single_only", // later will possibly be beam_only or mixed
);
metrics::histogram!(
observe_labeled_histogram(
"tgi_batch_inference_forward_duration",
forward_duration,
"method" => method,
"makeup" => "single_only", // later will possibly be beam_only or mixed
&[("method", method),
("makeup", "single_only")],
forward_duration.as_secs_f64(),
);
metrics::histogram!(
observe_labeled_histogram(
"tgi_batch_inference_tokproc_duration",
&[("method", method),
("makeup", "single_only")],
pre_token_process_time.elapsed().as_secs_f64(),
"method" => method,
"makeup" => "single_only", // later will possibly be beam_only or mixed
);
// Probably don't need this additional counter because the duration histogram
// records a total count
metrics::increment_counter!("tgi_batch_inference_success", "method" => method);
increment_labeled_counter("tgi_batch_inference_success", &[("method", method)], 1);
Some(CachedBatch {
batch_id: next_batch_id,
status: completed_request_ids.map(|c| RequestsStatus { completed_ids: c }),
Expand All @@ -729,7 +730,7 @@ impl<'a> TokenProcessor<'a> {
ClientError::Connection(_) => "connection",
_ => "error",
};
metrics::increment_counter!("tgi_batch_inference_failure", "method" => method, "reason" => reason);
increment_labeled_counter("tgi_batch_inference_failure", &[("method", method), ("reason", reason)], 1);
self.send_errors(err, start_id);
None
}
Expand Down Expand Up @@ -980,7 +981,7 @@ impl<'a> TokenProcessor<'a> {
// If receiver closed (request cancelled), cancel this entry
let e = self.entries.remove(&request_id).unwrap();
stop_reason = Cancelled;
metrics::increment_counter!("tgi_request_failure", "err" => "cancelled");
increment_labeled_counter("tgi_request_failure", &[("err", "cancelled")], 1);
//TODO include request context in log message
warn!(
"Aborted streaming request {request_id} cancelled by client \
Expand All @@ -994,7 +995,7 @@ impl<'a> TokenProcessor<'a> {
// If receiver closed (request cancelled), cancel this entry
let e = self.entries.remove(&request_id).unwrap();
stop_reason = Cancelled;
metrics::increment_counter!("tgi_request_failure", "err" => "cancelled");
increment_labeled_counter("tgi_request_failure", &[("err", "cancelled")], 1);
//TODO include request context in log message
warn!(
"Aborted request {request_id} cancelled by client \
Expand Down
55 changes: 26 additions & 29 deletions router/src/grpc_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::{
validation::{RequestSize, ValidationError},
GenerateParameters, GenerateRequest,
};
use crate::metrics::{increment_counter, increment_labeled_counter, observe_histogram};
use crate::pb::fmaas::tokenize_response::Offset;

/// Whether to fail if sampling parameters are provided in greedy-mode requests
Expand Down Expand Up @@ -67,8 +68,6 @@ pub(crate) async fn start_grpc_server<F: Future<Output = ()> + Send + 'static>(
let grpc_service = GenerationServicer {
state: shared_state,
tokenizer,
input_counter: metrics::register_counter!("tgi_request_input_count"),
tokenize_input_counter: metrics::register_counter!("tgi_tokenize_request_input_count"),
};
let grpc_server = builder
.add_service(GenerationServiceServer::new(grpc_service))
Expand All @@ -92,8 +91,6 @@ async fn load_pem(path: String, name: &str) -> Vec<u8> {
pub struct GenerationServicer {
state: ServerState,
tokenizer: AsyncTokenizer,
input_counter: metrics::Counter,
tokenize_input_counter: metrics::Counter,
}

#[tonic::async_trait]
Expand Down Expand Up @@ -124,20 +121,20 @@ impl GenerationService for GenerationServicer {
let br = request.into_inner();
let batch_size = br.requests.len();
let kind = if batch_size == 1 { "single" } else { "batch" };
metrics::increment_counter!("tgi_request_count", "kind" => kind);
increment_labeled_counter("tgi_request_count", &[("kind", kind)], 1);
if batch_size == 0 {
return Ok(Response::new(BatchedGenerationResponse {
responses: vec![],
}));
}
self.input_counter.increment(batch_size as u64);
increment_counter("tgi_request_input_count", batch_size as u64);
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = self
.state
.limit_concurrent_requests
.try_acquire_many(batch_size as u32)
.map_err(|_| {
metrics::increment_counter!("tgi_request_failure", "err" => "conc_limit");
increment_labeled_counter("tgi_request_failure", &[("err", "conc_limit")], 1);
tracing::error!("Model is overloaded");
Status::resource_exhausted("Model is overloaded")
})?;
Expand Down Expand Up @@ -217,11 +214,11 @@ impl GenerationService for GenerationServicer {
}
.map_err(|err| match err {
InferError::RequestQueueFull() => {
metrics::increment_counter!("tgi_request_failure", "err" => "queue_full");
increment_labeled_counter("tgi_request_failure", &[("err", "queue_full")], 1);
Status::resource_exhausted(err.to_string())
}
_ => {
metrics::increment_counter!("tgi_request_failure", "err" => "generate");
increment_labeled_counter("tgi_request_failure", &[("err", "generate")], 1);
tracing::error!("{err}");
Status::from_error(Box::new(err))
}
Expand Down Expand Up @@ -254,15 +251,15 @@ impl GenerationService for GenerationServicer {
) -> Result<Response<Self::GenerateStreamStream>, Status> {
let start_time = Instant::now();
let request = request.extract_context();
metrics::increment_counter!("tgi_request_count", "kind" => "stream");
self.input_counter.increment(1);
increment_labeled_counter("tgi_request_count", &[("kind", "stream")], 1);
increment_counter("tgi_request_input_count", 1);
let permit = self
.state
.limit_concurrent_requests
.clone()
.try_acquire_owned()
.map_err(|_| {
metrics::increment_counter!("tgi_request_failure", "err" => "conc_limit");
increment_labeled_counter("tgi_request_failure", &[("err", "conc_limit")], 1);
tracing::error!("Model is overloaded");
Status::resource_exhausted("Model is overloaded")
})?;
Expand Down Expand Up @@ -292,7 +289,7 @@ impl GenerationService for GenerationServicer {
|ctx, count, reason, request_id, times, out, err| {
let _enter = ctx.span.enter();
if let Some(e) = err {
metrics::increment_counter!("tgi_request_failure", "err" => "generate");
increment_labeled_counter("tgi_request_failure", &[("err", "generate")], 1);
tracing::error!(
"Streaming response failed after {count} tokens, \
output so far: '{:?}': {e}",
Expand Down Expand Up @@ -322,11 +319,11 @@ impl GenerationService for GenerationServicer {
.await
.map_err(|err| match err {
InferError::RequestQueueFull() => {
metrics::increment_counter!("tgi_request_failure", "err" => "queue_full");
increment_labeled_counter("tgi_request_failure", &[("err", "queue_full")], 1);
Status::resource_exhausted(err.to_string())
}
_ => {
metrics::increment_counter!("tgi_request_failure", "err" => "unknown");
increment_labeled_counter("tgi_request_failure", &[("err", "unknown")], 1);
tracing::error!("{err}");
Status::from_error(Box::new(err))
}
Expand All @@ -341,9 +338,9 @@ impl GenerationService for GenerationServicer {
request: Request<BatchedTokenizeRequest>,
) -> Result<Response<BatchedTokenizeResponse>, Status> {
let br = request.into_inner();
metrics::increment_counter!("tgi_tokenize_request_count");
increment_counter("tgi_tokenize_request_count", 1);
let start_time = Instant::now();
self.tokenize_input_counter.increment(br.requests.len() as u64);
increment_counter("tgi_tokenize_request_input_count", br.requests.len() as u64);

let truncate_to = match br.truncate_input_tokens {
0 => u32::MAX,
Expand Down Expand Up @@ -378,8 +375,8 @@ impl GenerationService for GenerationServicer {
.await?;

let token_total: u32 = responses.iter().map(|tr| tr.token_count).sum();
metrics::histogram!("tgi_tokenize_request_tokens", token_total as f64);
metrics::histogram!(
observe_histogram("tgi_tokenize_request_tokens", token_total as f64);
observe_histogram(
"tgi_tokenize_request_duration",
start_time.elapsed().as_secs_f64()
);
Expand Down Expand Up @@ -428,12 +425,12 @@ impl GenerationServicer {
Err(err) => Err(err),
}
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
increment_labeled_counter("tgi_request_failure", &[("err", "validation")], 1);
tracing::error!("{err}");
Status::invalid_argument(err.to_string())
})
.map(|requests| {
metrics::histogram!(
observe_histogram(
"tgi_request_validation_duration",
start_time.elapsed().as_secs_f64()
);
Expand Down Expand Up @@ -474,27 +471,27 @@ fn log_response(
span.record("total_time", format!("{total_time:?}"));
span.record("input_toks", input_tokens);

metrics::histogram!(
observe_histogram(
"tgi_request_inference_duration",
inference_time.as_secs_f64()
);
metrics::histogram!(
observe_histogram(
"tgi_request_mean_time_per_token_duration",
time_per_token.as_secs_f64()
);
}

// Metrics
match reason {
Error => metrics::increment_counter!("tgi_request_failure", "err" => "generate"),
Error => increment_labeled_counter("tgi_request_failure", &[("err", "generate")], 1),
Cancelled => (), // recorded where cancellation is detected
_ => {
metrics::increment_counter!(
"tgi_request_success", "stop_reason" => reason.as_str_name(), "kind" => kind
increment_labeled_counter(
"tgi_request_success", &[("stop_reason", reason.as_str_name()), ("kind", kind)], 1
);
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens", generated_tokens as f64);
metrics::histogram!(
observe_histogram("tgi_request_duration", total_time.as_secs_f64());
observe_histogram("tgi_request_generated_tokens", generated_tokens as f64);
observe_histogram(
"tgi_request_total_tokens",
(generated_tokens as usize + input_tokens) as f64
);
Expand Down
1 change: 1 addition & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod server;
mod tokenizer;
mod validation;
mod tracing;
mod metrics;

use batcher::Batcher;
use serde::{Deserialize, Serialize};
Expand Down
34 changes: 34 additions & 0 deletions router/src/metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Small helpers for using the metrics crate.
// This module collects all usages of the metrics crate in one place, so that future API-breaking changes can be handled here alone.


// These counter helper methods will actually increment a second counter with `_total` appended to the name.
// This is for compatibility with other runtimes that use prometheus directly, which is very
// opinionated that all counters should end with the suffix _total.
// Cite: https://prometheus.github.io/client_python/instrumenting/counter/

/// Increment the named counter by `value`.
///
/// The same increment is mirrored onto a twin counter named `{name}_total`,
/// so prometheus-style consumers (e.g. vLLM dashboards) see both spellings.
pub fn increment_counter(name: &'static str, value: u64) {
    let total_name = format!("{name}_total");
    metrics::counter!(name).increment(value);
    metrics::counter!(total_name).increment(value);
}


/// Increment the named counter, tagged with `labels`, by `value`.
///
/// As with [`increment_counter`], the increment is duplicated onto a
/// `{name}_total` counter carrying the same labels for prometheus parity.
pub fn increment_labeled_counter(
    name: &'static str,
    labels: &[(&'static str, &'static str)],
    value: u64,
) {
    let total_name = format!("{name}_total");
    metrics::counter!(name, labels).increment(value);
    metrics::counter!(total_name, labels).increment(value);
}


/// Set the named gauge to `value` (no `_total` twin: gauges are not counters).
pub fn set_gauge(name: &'static str, value: f64) {
    let gauge = metrics::gauge!(name);
    gauge.set(value);
}


/// Record a single observation of `value` into the named histogram.
pub fn observe_histogram(name: &'static str, value: f64) {
    let histogram = metrics::histogram!(name);
    histogram.record(value);
}


/// Record a single observation of `value` into the named histogram, tagged with `labels`.
pub fn observe_labeled_histogram(
    name: &'static str,
    labels: &[(&'static str, &'static str)],
    value: f64,
) {
    let histogram = metrics::histogram!(name, labels);
    histogram.record(value);
}
Loading

0 comments on commit 9b4aea8

Please sign in to comment.