17 changes: 16 additions & 1 deletion docs/advance/grafana_prometheus.md
@@ -2,7 +2,7 @@

**Author:** `https://github.com/meituan-search`

Last updated: 12/05/2025.
Last updated: 11/02/2026.

Monitor the rollout computation process using Prometheus and Grafana when using verl to enhance system observability and facilitate further performance optimization.

@@ -184,6 +184,21 @@ After task execution, log in to Grafana to view and customize monitoring dashboards
- [vLLM Grafana Dashboard style 2](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/dashboards/grafana/query_statistics.json)
- [SGLang Grafana Dashboard](https://github.com/sgl-project/sglang/blob/main/examples/monitoring/grafana/dashboards/json/sglang-dashboard.json)

## Logging Prometheus Metrics to Experiment Tracking

You can automatically log Prometheus metrics to your experiment tracking backends (WandB, TensorBoard, MLflow, etc.) during training:

```yaml
actor_rollout_ref:
  rollout:
    prometheus:
      enable: True
      metrics_to_log:
        - "vllm:generation_tokens_total"
```

Metrics are queried every training step and logged with the `rollout/` prefix; colons in metric names are replaced with underscores (e.g., `vllm:generation_tokens_total` becomes `rollout/vllm_generation_tokens_total`).
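
To sanity-check that a metric is actually being scraped before enabling logging, you can query the Prometheus HTTP API directly. The snippet below is a minimal sketch assuming Prometheus is reachable at the default port 9090 on the Ray head node; replace `localhost` with the head node IP when running from another machine.

```python
import requests

# Standard Prometheus instant-query endpoint; port 9090 matches the default rollout config.
PROMETHEUS_URL = "http://localhost:9090"  # replace with the Ray head node IP if needed

resp = requests.get(
    f"{PROMETHEUS_URL}/api/v1/query",
    params={"query": "vllm:generation_tokens_total"},
    timeout=5,
)
resp.raise_for_status()
result = resp.json()["data"]["result"]
if result:
    # Instant queries return [timestamp, value]; the value is a string.
    print(float(result[0]["value"][1]))
else:
    print("No data yet; check that rollout workers are emitting this metric.")
```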

## Additional Resources

- [Ray Monitoring Documentation](https://docs.ray.io/en/latest/cluster/configure-manage-dashboard.html)
182 changes: 182 additions & 0 deletions verl/experimental/agent_loop/prometheus_utils.py
@@ -15,8 +15,10 @@

import logging
import os
import time

import ray
import requests
import yaml

from verl.workers.config.rollout import PrometheusConfig
@@ -108,3 +110,183 @@ def reload_prometheus(port):

except Exception as e:
logger.error(f"Failed to update Prometheus configuration: {e}")


class PrometheusClient:
"""Client for querying Prometheus metrics during training.

This client queries Prometheus running on the Ray head node to fetch
infrastructure metrics (GPU cache usage, throughput, etc.) and makes
them available for experiment tracking.

Features:
- Automatic head node discovery via Ray
- Retry logic with exponential backoff
- Per-metric error handling (one failure doesn't affect others)
- Result caching to reduce query frequency

Attributes:
host: Prometheus host (Ray head node IP)
port: Prometheus port (default 9090)
metrics_to_log: List of Prometheus metric names or queries
timeout: HTTP request timeout in seconds
max_attempts: Maximum retry attempts per metric
retry_delay: Base delay between retries
cache_duration: How long to cache results (seconds)
"""

DEFAULT_TIMEOUT = 5.0
DEFAULT_MAX_ATTEMPTS = 2
DEFAULT_RETRY_DELAY = 0.5
DEFAULT_CACHE_DURATION = 10.0

def __init__(
self,
prometheus_config: PrometheusConfig,
timeout: float = DEFAULT_TIMEOUT,
max_attempts: int = DEFAULT_MAX_ATTEMPTS,
retry_delay: float = DEFAULT_RETRY_DELAY,
cache_duration: float = DEFAULT_CACHE_DURATION,
):
"""Initialize Prometheus client.

Args:
prometheus_config: PrometheusConfig object from rollout config
timeout: HTTP timeout for queries
max_attempts: Number of retry attempts
retry_delay: Base delay for exponential backoff
cache_duration: Cache results for this many seconds
"""
self.port = prometheus_config.port
self.metrics_to_log = prometheus_config.metrics_to_log or []
self.timeout = timeout
self.max_attempts = max_attempts
self.retry_delay = retry_delay
self.cache_duration = cache_duration

self.host = self._get_ray_head_node()
self.base_url = f"http://{self.host}:{self.port}"

self._cache = {}
self._cache_timestamps = {}

if self.metrics_to_log:
logger.info(f"PrometheusClient initialized: {len(self.metrics_to_log)} metrics from {self.base_url}")

def _get_ray_head_node(self) -> str:
"""Get the IP address of the Ray head node where Prometheus runs.

Returns:
str: IP address of head node

Raises:
RuntimeError: If head node cannot be determined
"""
nodes = ray.nodes()
for node in nodes:
if node.get("Alive") and "node:__internal_head__" in node.get("Resources", {}):
return node["NodeManagerAddress"]

for node in nodes:
if node.get("Alive") and node.get("Resources", {}).get("CPU", 0) > 0:
logger.warning(f"Using non-head node for Prometheus: {node['NodeManagerAddress']}")
return node["NodeManagerAddress"]

raise RuntimeError("No alive Ray nodes found")

def _query_metric(self, metric_name: str) -> float | None:
"""Query a single metric from Prometheus with retry logic.

Args:
metric_name: Prometheus metric name or query expression

Returns:
Metric value as float, or None if query failed
"""
if metric_name in self._cache:
age = time.time() - self._cache_timestamps[metric_name]
if age < self.cache_duration:
return self._cache[metric_name]

url = f"{self.base_url}/api/v1/query"
params = {"query": metric_name}

for attempt in range(self.max_attempts):
try:
response = requests.get(url, params=params, timeout=self.timeout)
response.raise_for_status()

data = response.json()
if data["status"] != "success":
logger.warning(f"Prometheus query failed: {data.get('error', 'unknown')}")
return None

result = data.get("data", {}).get("result", [])
if not result:
logger.debug(f"No data for metric: {metric_name}")
return None

value = float(result[0]["value"][1])

self._cache[metric_name] = value
self._cache_timestamps[metric_name] = time.time()

return value

except requests.exceptions.Timeout:
logger.warning(f"Prometheus query timeout for {metric_name} (attempt {attempt + 1})")
except requests.exceptions.ConnectionError:
logger.warning(f"Prometheus connection error for {metric_name} (attempt {attempt + 1})")
except (ValueError, KeyError, IndexError) as e:
logger.error(f"Failed to parse Prometheus response for {metric_name}: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error querying {metric_name}: {e}")
return None

if attempt < self.max_attempts - 1:
time.sleep(self.retry_delay * (2**attempt))

return None

    def query_all_metrics(self, prefix: str = "rollout/") -> dict[str, float]:
        """Query all configured metrics from Prometheus.

        Args:
            prefix: Namespace prefix prepended to each metric key for experiment tracking.

        Returns:
            Dictionary mapping prefixed metric names to values. Failed queries are omitted,
            and colons in metric names are replaced with underscores.
            Always returns a dict (empty if all queries fail).

        Example:
            {
                "rollout/vllm_gpu_cache_usage_perc": 85.3,
                "rollout/vllm_avg_generation_throughput_toks_per_s": 1247.5
            }
        """
if not self.metrics_to_log:
return {}

metrics = {}
for metric_name in self.metrics_to_log:
try:
value = self._query_metric(metric_name)
if value is not None:
safe_name = metric_name.replace(":", "_")
metrics[f"{prefix}{safe_name}"] = value
except Exception as e:
logger.warning(f"Unexpected error while querying metric '{metric_name}': {e}", exc_info=True)

return metrics

def clear_cache(self):
"""Clear the metrics cache. Useful for testing or forced refresh."""
self._cache.clear()
self._cache_timestamps.clear()


def get_prometheus_client(prometheus_config: PrometheusConfig) -> PrometheusClient | None:
    """Create a PrometheusClient from the rollout config, returning None if initialization fails."""
    try:
        return PrometheusClient(prometheus_config)
    except Exception as e:
        logger.warning(f"Failed to initialize PrometheusClient: {e}")
        return None
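

# Example usage (illustrative sketch; the trainer wires this up automatically when
# `actor_rollout_ref.rollout.prometheus.enable` is set and `metrics_to_log` is non-empty):
#
#     client = get_prometheus_client(prometheus_config)
#     if client is not None:
#         metrics = client.query_all_metrics(prefix="rollout/")
#         # e.g. {"rollout/vllm_generation_tokens_total": 123456.0}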
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -321,6 +321,7 @@ actor_rollout_ref:
port: 9090
file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
metrics_to_log: null
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -298,6 +298,7 @@ actor_rollout_ref:
port: 9090
file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
metrics_to_log: null
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
@@ -301,6 +301,7 @@ actor_rollout_ref:
port: 9090
file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
metrics_to_log: null
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
3 changes: 3 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -378,6 +378,9 @@ prometheus:
# Specify served_model_name to avoid displaying overly long model paths in Grafana
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}

# List of Prometheus metrics to query and log to experiment tracking
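# e.g.
#   metrics_to_log:
#     - "vllm:generation_tokens_total"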
metrics_to_log: null

# type of quantization in vllm, currently support fp8 and torchao
quantization: null

12 changes: 12 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
@@ -35,6 +35,7 @@

from verl import DataProto
from verl.checkpoint_engine import CheckpointEngineManager
from verl.experimental.agent_loop.prometheus_utils import get_prometheus_client
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager
@@ -287,6 +288,12 @@ def __init__(
experiment_name=self.config.trainer.experiment_name,
)

# Initialize Prometheus client if enabled
self.prometheus_client = None
prometheus_config = self.config.actor_rollout_ref.rollout.prometheus
if prometheus_config.enable and prometheus_config.metrics_to_log:
self.prometheus_client = get_prometheus_client(prometheus_config)

# if ref_in_actor is True, the reference policy will be actor without lora applied
lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
if lora_rank <= 0:
@@ -1585,6 +1592,11 @@ def fit(self):
metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm))
# Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation

# Query Prometheus metrics if enabled
if self.prometheus_client is not None:
prometheus_metrics = self.prometheus_client.query_all_metrics(prefix="rollout/")
metrics.update(prometheus_metrics)

# this is experimental and may be changed/removed in the future in favor of a general-purpose one
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
self.train_dataloader.sampler.update(batch=batch)
2 changes: 2 additions & 0 deletions verl/workers/config/rollout.py
@@ -116,6 +116,8 @@ class PrometheusConfig(BaseConfig):
file: str = "/tmp/ray/session_latest/metrics/prometheus/prometheus.yml"
# Specify served_model_name to avoid displaying overly long model paths in Grafana
served_model_name: Optional[str] = None
# List of Prometheus metrics to query and log to experiment tracking
metrics_to_log: Optional[list[str]] = None


@dataclass
Expand Down