support prometheus metrics (sgl-project#1853)
Co-authored-by: Lianmin Zheng <[email protected]>
Co-authored-by: Byron Hsu <[email protected]>
3 people authored Nov 6, 2024
1 parent f5113e5 commit a146d99
Showing 7 changed files with 526 additions and 3 deletions.
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -31,6 +31,7 @@

import dataclasses
import logging
import time
from typing import List, Optional, Tuple, Union

import torch
@@ -254,6 +255,16 @@ def __init__(
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object

# Lifetime traces
# Time when the request is created and added to the waiting queue
self.created_time = None
# Time when the request is added to a prefill batch
self.queued_time = None
# Time when the request starts being processed (forward pass launched)
self.started_time = None
# Time when the request is finished
self.finished_time = None

# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -1028,6 +1039,9 @@ def __str__(self):
f"#req={(len(self.reqs))})"
)

def mark_reqs_started(self):
for req in self.reqs:
req.started_time = time.time()

@dataclasses.dataclass
class ModelWorkerBatch:
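The four timestamps above give each Req a minimal lifetime trace: created_time is stamped when the request enters the waiting queue, queued_time when the scheduler admits it to a prefill batch (see schedule_policy.py below), started_time when the batch's forward pass is launched via mark_reqs_started(), and finished_time is reserved for request completion. A hedged sketch of how per-request latency values can be derived from these timestamps (the helper below is illustrative, not part of this commit):

import time

def request_latency_metrics(req):
    # Hypothetical helper (not part of this commit): derives the per-request
    # latency values that Scheduler.get_stats() computes from these timestamps.
    now = time.time()
    metrics = {}
    if req.created_time is not None and req.queued_time is not None:
        # Time spent waiting in the queue before admission to a prefill batch.
        metrics["queue_time_s"] = req.queued_time - req.created_time
    if req.started_time is not None:
        # Right after the prefill forward pass, this approximates time-to-first-token.
        metrics["time_to_first_token_s"] = now - req.started_time
    if req.finished() and req.created_time is not None:
        # End-to-end latency from arrival to completion; get_stats() uses the
        # current time for requests that finished in this iteration.
        metrics["e2e_latency_s"] = now - req.created_time
    return metrics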
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_policy.py
@@ -17,6 +17,7 @@

import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
@@ -306,6 +307,7 @@ def add_one_req(self, req: Req):
):
# Non-chunked prefill
self.can_run_list.append(req)
req.queued_time = time.time()
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len,
@@ -324,6 +326,7 @@ def add_one_req(self, req: Req):
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
req.queued_time = time.time()
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
112 changes: 109 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -62,6 +62,8 @@
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -222,7 +224,8 @@ def __init__(
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.last_stats_tic = time.time()  # time of the last metrics/stats logging (every iteration)
self.last_log_tic = time.time()  # time of the last decode-log print
self.stream_interval = server_args.stream_interval

# Init chunked prefill
@@ -291,6 +294,15 @@ def __init__(
],
with_stack=True,
)
# Init metrics stats
self.stats = Stats()
self.metrics_collector = PrometheusMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
max_model_len=self.max_total_num_tokens,
)

def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
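The PrometheusMetricsCollector constructed above lives in sglang.srt.metrics.metrics_collector, one of the changed files not expanded on this page. As a rough illustration only, a collector that takes a labels dict and max_model_len could wrap prometheus_client primitives as sketched below; the metric names, bucket choices, and method bodies here are assumptions, not the actual implementation.

from typing import Dict
from prometheus_client import Counter, Gauge, Histogram


class MetricsCollectorSketch:
    # Illustrative stand-in for PrometheusMetricsCollector; all metric names,
    # buckets, and method bodies here are assumptions.
    def __init__(self, labels: Dict[str, str], max_model_len: int) -> None:
        self.labels = labels
        # max_model_len could, for example, bound histogram buckets for token counts.
        self.max_model_len = max_model_len
        labelnames = list(labels.keys())
        self.num_running_reqs = Gauge(
            "sglang_num_running_reqs", "Number of running requests", labelnames
        )
        self.token_usage = Gauge(
            "sglang_token_usage", "Fraction of KV-cache tokens in use", labelnames
        )
        self.prompt_tokens_total = Counter(
            "sglang_prompt_tokens_total", "Prefill tokens processed", labelnames
        )
        self.time_to_first_token = Histogram(
            "sglang_time_to_first_token_seconds",
            "Time to first token in seconds",
            labelnames,
            buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0],
        )

    def log_stats(self, stats) -> None:
        # Push one Stats snapshot into the Prometheus metric families.
        self.num_running_reqs.labels(**self.labels).set(stats.num_running_req)
        self.token_usage.labels(**self.labels).set(stats.token_usage)
        self.prompt_tokens_total.labels(**self.labels).inc(stats.num_prompt_tokens_iter)
        for ttft in stats.time_to_first_tokens_iter:
            self.time_to_first_token.labels(**self.labels).observe(ttft)

In the scheduler code below, log_stats(stats) is called once per iteration when server_args.enable_metrics is set.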
@@ -338,6 +350,11 @@ def event_loop_normal(self):
else:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# log stats
if self.is_generation and self.server_args.enable_metrics:
stats = self.get_stats(batch)
self.log_stats(stats)
self.last_stats_tic = time.time()

self.last_batch = batch

@@ -476,6 +493,7 @@ def handle_generate_request(
self.max_req_len - len(req.origin_input_ids) - 1,
)

req.created_time = time.time()
self.waiting_queue.append(req)

def handle_embedding_request(
@@ -504,9 +522,11 @@ def print_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.last_log_tic = time.time()
# set system stats
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
@@ -676,6 +696,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
# set system stats
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)

if num_mixed_running > 0:
logger.info(
@@ -770,6 +793,7 @@ def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
batch.mark_reqs_started()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
@@ -789,6 +813,88 @@ def run_batch(self, batch: ScheduleBatch):
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = embeddings, model_worker_batch.bid
return ret

def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill

now = time.time()
# system stats
# Scheduler State
new_seq: int = 0
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
num_waiting_req = len(self.waiting_queue)
# Cache State
cache_hit_rate: float = 0.0
token_usage: float = 0.0

# set stats from prefill
if self.stats is not None:
# new_seq=self.stats.new_seq
cache_hit_rate = self.stats.cache_hit_rate
token_usage = self.stats.token_usage
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []

# Request stats
# Decode
gen_throughput: float = 0.0
# Latency
time_e2e_requests: List[float] = []
time_waiting_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []

# _, next_token_ids, _ = result
if batch is not None:
num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(num_generation_tokens_iter / (now - self.last_stats_tic), 2)

for i, req in enumerate(batch.reqs):
# NOTE: The batch forward mode is extend before decoding starts.
if batch.forward_mode.is_extend():
num_prompt_tokens_iter = len(batch.input_ids) + sum(batch.prefix_lens)
time_to_first_tokens_iter.append(now - req.started_time)
else:
time_per_output_tokens_iter.append(now - self.last_stats_tic)

if req.finished():
time_e2e_requests.append(now - req.created_time)
time_waiting_requests.append(req.queued_time - req.created_time)
num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append(
req.finished_reason.to_json()
if req.finished_reason is not None
else None)

return Stats(
new_seq=new_seq,
num_running_req=num_running_req,
num_waiting_req=num_waiting_req,
cache_hit_rate=cache_hit_rate,
token_usage=token_usage,
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
gen_throughput=gen_throughput,
time_e2e_requests=time_e2e_requests,
time_waiting_requests=time_waiting_requests,
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
finished_reason_requests=finished_reason_requests,
context_len=self.model_config.context_len,
max_total_num_tokens=self.max_total_num_tokens,
max_prefill_tokens=self.max_prefill_tokens,
max_running_requests=self.max_running_requests,
)

def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats)

def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
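get_stats() packs everything into a Stats object from sglang.srt.metrics.metrics_types, another changed file not expanded on this page. A minimal sketch that is consistent with the fields used above; the defaults, ordering, and comments are assumptions:

from dataclasses import dataclass, field
from typing import List


@dataclass
class StatsSketch:
    # Illustrative approximation of sglang.srt.metrics.metrics_types.Stats.
    # System / scheduler state
    new_seq: int = 0
    num_running_req: int = 0
    num_waiting_req: int = 0
    cache_hit_rate: float = 0.0
    token_usage: float = 0.0
    # Per-iteration stats
    num_prompt_tokens_iter: int = 0
    num_generation_tokens_iter: int = 0
    time_to_first_tokens_iter: List[float] = field(default_factory=list)
    time_per_output_tokens_iter: List[float] = field(default_factory=list)
    gen_throughput: float = 0.0
    # Per-request stats (appended when a request finishes)
    time_e2e_requests: List[float] = field(default_factory=list)
    time_waiting_requests: List[float] = field(default_factory=list)
    num_prompt_tokens_requests: List[int] = field(default_factory=list)
    num_generation_tokens_requests: List[int] = field(default_factory=list)
    finished_reason_requests: List[str] = field(default_factory=list)
    # Static configuration attached to each snapshot
    context_len: int = 0
    max_total_num_tokens: int = 0
    max_prefill_tokens: int = 0
    max_running_requests: int = 0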
(The remaining 4 changed files are not expanded here.)
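Only the scheduler-side changes are expanded above; the remaining files wire the collected metrics into the serving layer so they can be scraped. As a generic illustration of how prometheus_client metrics become scrapeable (standard library usage, not the exact wiring this commit adds):

import time
from prometheus_client import Gauge, start_http_server

if __name__ == "__main__":
    # Expose a toy gauge at http://localhost:9000/metrics for Prometheus to scrape.
    token_usage = Gauge("demo_token_usage", "Fraction of KV-cache tokens in use")
    start_http_server(9000)
    while True:
        # In a real server this value would come from the scheduler's Stats snapshots.
        token_usage.set(0.42)
        time.sleep(5)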
