diff --git a/dashboard/app.py b/dashboard/app.py
index 2945147..5bbbfb2 100644
--- a/dashboard/app.py
+++ b/dashboard/app.py
@@ -65,7 +65,7 @@ def main():
st.markdown("---")
results_dir = st.text_input(
- "测试结果目录", value="../output", help="包含 JSON/CSV 测试结果的目录"
+ "测试结果目录", value="./output", help="包含 JSON/CSV 测试结果的目录"
)
if not use_mongodb and results_dir != str(
@@ -122,12 +122,13 @@ def render_dashboard(run_id_filter: str):
InfiniMetrics Dashboard 用于统一展示
通信(NCCL / 集合通信)、
+ 训练(Training / 分布式训练)、
推理(直接推理 / 服务性能)、
算子(核心算子性能)、
硬件(内存带宽 / 缓存性能)
@@ -135,7 +136,9 @@ def render_dashboard(run_id_filter: str):
测试框架输出 JSON(环境 / 配置 / 标量指标) +
CSV(曲线 / 时序数据),
- Dashboard 自动加载并支持多次运行的对比分析与可视化。
+ Dashboard 自动加载并支持多次运行的
+ 性能对比、趋势分析 与
+ 可视化展示。
""",
unsafe_allow_html=True,
@@ -177,6 +180,7 @@ def _parse_time(t):
# ========== Categorize runs ==========
comm_runs = [r for r in runs if r.get("testcase", "").startswith("comm")]
infer_runs = [r for r in runs if r.get("testcase", "").startswith("infer")]
+ train_runs = [r for r in runs if r.get("testcase", "").startswith("train")]
ops_runs, hw_runs = [], []
for r in runs:
@@ -188,13 +192,14 @@ def _parse_time(t):
hw_runs.append(r)
# ========== KPI ==========
- c1, c2, c3, c4, c5, c6 = st.columns(6)
+ c1, c2, c3, c4, c5, c6, c7 = st.columns(7)
c1.metric("总测试数", total)
c2.metric("成功率", f"{(success/total*100):.1f}%")
c3.metric("通信测试", len(comm_runs))
c4.metric("推理测试", len(infer_runs))
- c5.metric("算子测试", len(ops_runs))
- c6.metric("硬件检测", len(hw_runs))
+ c5.metric("训练测试", len(train_runs))
+ c6.metric("算子测试", len(ops_runs))
+ c7.metric("硬件检测", len(hw_runs))
st.caption(f"失败测试数:{fail}")
st.caption(f"当前筛选:加速卡={','.join(selected_accs) or '全部'}")
@@ -208,8 +213,9 @@ def _latest(lst):
latest_comm = _latest(comm_runs)
latest_infer = _latest(infer_runs)
latest_ops = _latest(ops_runs)
+ latest_train = _latest(train_runs)
- colA, colB, colC = st.columns(3)
+ colA, colB, colC, colD = st.columns(4)
with colA:
st.markdown("#### 🔗 通信(最新)")
@@ -238,6 +244,17 @@ def _latest(lst):
st.write(f"- time: {latest_ops.get('time','')}")
st.write(f"- status: {'✅' if latest_ops.get('success') else '❌'}")
+ with colD:
+ st.markdown("#### 🏋️ 训练(最新)")
+ if not latest_train:
+ st.info("暂无训练结果")
+ else:
+ framework = latest_train.get("config", {}).get("framework", "unknown")
+ model = latest_train.get("config", {}).get("model", "unknown")
+ st.write(f"- 框架/模型: `{framework}/{model}`")
+ st.write(f"- time: {latest_train.get('time','')}")
+ st.write(f"- status: {'✅' if latest_train.get('success') else '❌'}")
+
st.divider()
# ========== Recent runs table ==========
@@ -294,13 +311,15 @@ def _latest(lst):
st.markdown("---")
st.markdown("### 🚀 快速导航")
- col1, col2, col3 = st.columns(3)
+ col1, col2, col3, col4 = st.columns(4)
if col1.button("🔗 通信测试分析", use_container_width=True):
st.switch_page("pages/communication.py")
if col2.button("⚡ 算子测试分析", use_container_width=True):
st.switch_page("pages/operator.py")
- if col3.button("🤖 推理测试分析", use_container_width=True):
+ if col3.button("🚀 推理测试分析", use_container_width=True):
st.switch_page("pages/inference.py")
+ if col4.button("🏋️ 训练测试分析", use_container_width=True):
+ st.switch_page("pages/training.py")
except Exception as e:
st.error(f"Dashboard 加载失败: {e}")
diff --git a/dashboard/pages/communication.py b/dashboard/pages/communication.py
index 27d9508..4685a21 100644
--- a/dashboard/pages/communication.py
+++ b/dashboard/pages/communication.py
@@ -60,7 +60,7 @@ def main():
# Status filter
show_success = st.checkbox("仅显示成功测试", value=True)
- # Apply filters
+ # Apply status filter
filtered_runs = [
r
for r in comm_runs
@@ -123,6 +123,7 @@ def main():
identifier = run_info.get("path") or run_info.get("run_id")
result = st.session_state.data_loader.load_test_result(identifier)
run_info["data"] = result
+
selected_runs.append(run_info)
# Tabs for different views
@@ -183,36 +184,30 @@ def main():
st.plotly_chart(fig, use_container_width=True)
if len(selected_runs) == 1:
- st.markdown("#### 📌 核心指标(最新)")
+ st.markdown("#### 关键指标")
run = selected_runs[0]
core = extract_core_metrics(run)
- c1, c2, c3 = st.columns(3)
-
- c1.metric(
+ # First Line: numerical indicators
+ cols = st.columns(3)
+ cols[0].metric(
"峰值带宽",
- (
- f"{core['bandwidth_gbps']:.2f} GB/s"
- if core["bandwidth_gbps"]
- else "-"
- ),
+ f"{core['bandwidth_gbps']:.2f} GB/s"
+ if core["bandwidth_gbps"]
+ else "-",
)
- c2.metric(
+ cols[1].metric(
"平均延迟",
f"{core['latency_us']:.2f} μs" if core["latency_us"] else "-",
)
- c3.metric(
+ cols[2].metric(
"测试耗时",
f"{core['duration_ms']:.2f} ms" if core["duration_ms"] else "-",
)
- # Gauge charts for key metrics
- if len(selected_runs) == 1:
- st.markdown("#### 关键指标")
- run = selected_runs[0]
- col1, col2, col3 = st.columns(3)
+ cols = st.columns(3)
- with col1:
+ with cols[0]:
# Find max bandwidth
max_bw = 0
for metric in run.get("data", {}).get("metrics", []):
@@ -233,7 +228,7 @@ def main():
st.plotly_chart(fig, use_container_width=True)
break
- with col2:
+ with cols[1]:
# Find average latency
avg_lat = 0
for metric in run.get("data", {}).get("metrics", []):
@@ -254,7 +249,7 @@ def main():
st.plotly_chart(fig, use_container_width=True)
break
- with col3:
+ with cols[2]:
# Extract duration
duration = 0
for metric in run.get("data", {}).get("metrics", []):
diff --git a/dashboard/pages/inference.py b/dashboard/pages/inference.py
index 14e0741..bcc21c3 100644
--- a/dashboard/pages/inference.py
+++ b/dashboard/pages/inference.py
@@ -11,7 +11,7 @@
create_summary_table_infer,
)
-init_page("推理测试分析 | InfiniMetrics", "🤖")
+init_page("推理测试分析 | InfiniMetrics", "🚀")
def main():
@@ -180,7 +180,7 @@ def _plot_metric(metric_name_contains: str, container):
_plot_metric("infer.compute_latency", c1)
_plot_metric("infer.ttft", c2)
- _plot_metric("infer.direct_throughput", c3)
+ _plot_metric("infer.direct_throughput_tps", c3)
# ---------- Tables ----------
with tab2:
diff --git a/dashboard/pages/training.py b/dashboard/pages/training.py
new file mode 100644
index 0000000..1db0a07
--- /dev/null
+++ b/dashboard/pages/training.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+"""Training tests analysis page."""
+
+import streamlit as st
+
+from common import init_page
+from components.header import render_header
+from utils.training_utils import (
+ load_training_runs,
+ filter_runs,
+ create_run_options,
+ load_selected_runs,
+ create_training_summary,
+)
+from utils.training_plots import (
+ render_performance_curves,
+ render_throughput_comparison,
+ render_data_tables,
+ render_config_details,
+)
+
+init_page("训练测试分析 | InfiniMetrics", "🏋️")
+
+
+def main():
+ render_header()
+ st.markdown("## 🏋️ 训练性能测试分析")
+
+ dl = st.session_state.data_loader
+ runs = load_training_runs(dl)
+
+ if not runs:
+ st.info("未找到训练测试结果\n请将训练测试结果放在 output/train/ 或 output/training/ 目录下")
+ return
+
+ # Sidebar Filters
+ with st.sidebar:
+ st.markdown("### 🔍 筛选条件")
+
+ frameworks = sorted(
+ {r.get("config", {}).get("framework", "unknown") for r in runs}
+ )
+ models = sorted({r.get("config", {}).get("model", "unknown") for r in runs})
+ device_counts = sorted({r.get("device_used", 1) for r in runs})
+
+ selected_fw = st.multiselect("框架", frameworks, default=frameworks)
+ selected_models = st.multiselect("模型", models, default=models)
+ selected_dev = st.multiselect("设备数", device_counts, default=device_counts)
+ only_success = st.checkbox("仅显示成功测试", value=True)
+
+ st.markdown("---")
+ st.markdown("### 📈 图表选项")
+ y_log = st.checkbox("Y轴对数刻度", value=False)
+ smoothing = st.slider("平滑窗口", 1, 50, 5, help="对曲线进行移动平均平滑")
+
+ # Apply filters
+ filtered = filter_runs(
+ runs, selected_fw, selected_models, selected_dev, only_success
+ )
+ st.caption(f"找到 {len(filtered)} 个训练测试")
+
+ if not filtered:
+ st.warning("没有符合条件的测试结果")
+ return
+
+ # Run Selection
+ options = create_run_options(filtered)
+ selected = st.multiselect(
+ "选择要分析的测试运行(可多选对比)",
+ list(options.keys()),
+ default=list(options.keys())[: min(3, len(options))],
+ )
+
+ if not selected:
+ return
+
+ # Load selected runs
+ selected_runs = load_selected_runs(dl, filtered, options, selected)
+
+ # Tabs
+ tab1, tab2, tab3, tab4 = st.tabs(["📈 性能曲线", "📊 吞吐量对比", "📋 数据表格", "🔍 详细配置"])
+
+ with tab1:
+ render_performance_curves(selected_runs, smoothing, y_log)
+ with tab2:
+ render_throughput_comparison(selected_runs)
+ with tab3:
+ render_data_tables(selected_runs)
+ with tab4:
+ render_config_details(selected_runs, create_training_summary)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dashboard/utils/data_loader.py b/dashboard/utils/data_loader.py
index b605920..f843617 100644
--- a/dashboard/utils/data_loader.py
+++ b/dashboard/utils/data_loader.py
@@ -28,7 +28,7 @@ class InfiniMetricsDataLoader:
def __init__(
self,
- results_dir: str = "../output",
+ results_dir: str = "./output",
use_mongodb: bool = False,
mongo_config=None,
fallback_to_files: bool = True,
diff --git a/dashboard/utils/data_sources.py b/dashboard/utils/data_sources.py
index f3abf21..7ce2269 100644
--- a/dashboard/utils/data_sources.py
+++ b/dashboard/utils/data_sources.py
@@ -51,7 +51,7 @@ def source_type(self) -> str:
class FileDataSource(DataSource):
"""File-based data source (reads from JSON/CSV files)."""
- def __init__(self, results_dir: str = "../output"):
+ def __init__(self, results_dir: str = "./output"):
self.results_dir = Path(results_dir)
@property
diff --git a/dashboard/utils/training_plots.py b/dashboard/utils/training_plots.py
new file mode 100644
index 0000000..a49453b
--- /dev/null
+++ b/dashboard/utils/training_plots.py
@@ -0,0 +1,261 @@
+"""Training plot functions."""
+
+import streamlit as st
+import pandas as pd
+import plotly.graph_objects as go
+import plotly.express as px
+
+from utils.training_utils import get_metric_dataframe, apply_smoothing
+from utils.visualizations import create_gauge_chart
+
+
+def render_performance_curves(selected_runs, smoothing, y_log):
+ """Render performance curves"""
+ st.markdown("### 训练指标曲线")
+
+ metrics = [
+ ("train.loss", "Loss", "损失值", ""),
+ ("train.ppl", "Perplexity", "困惑度", ""),
+ ("train.throughput", "Throughput", "吞吐量", "tokens/s/GPU"),
+ ]
+
+ cols = st.columns(3)
+
+ for idx, (metric_key, title, ylabel, unit) in enumerate(metrics):
+ with cols[idx]:
+ st.markdown(f"**{title}**")
+
+ if len(selected_runs) == 1:
+ plot_single_metric(
+ selected_runs[0], metric_key, title, ylabel, unit, smoothing, y_log
+ )
+ else:
+ plot_multi_metric_comparison(
+ selected_runs, metric_key, title, ylabel, unit, smoothing, y_log
+ )
+
+ # Memory usage (only for single run)
+ if len(selected_runs) == 1:
+ render_memory_usage(selected_runs[0])
+
+
+def plot_single_metric(run, metric_key, title, ylabel, unit, smoothing, y_log):
+ """Draw a curve for a single indicator"""
+ target_metric = get_metric_dataframe(run, metric_key)
+
+ if not target_metric:
+ st.info(f"无{title}数据")
+ return
+
+ df = target_metric["data"].copy()
+ if df.empty or len(df.columns) < 2:
+ st.info("数据为空")
+ return
+
+ df = apply_smoothing(df, smoothing)
+
+ fig = go.Figure()
+ fig.add_trace(
+ go.Scatter(
+ x=df[df.columns[0]],
+ y=df[df.columns[1]],
+ mode="lines",
+ name=title,
+ line=dict(width=2),
+ )
+ )
+
+ fig.update_layout(
+ title=f"{title} - {run.get('config', {}).get('framework', '')}",
+ xaxis_title="Iteration",
+ yaxis_title=f"{ylabel} {unit}",
+ template="plotly_white",
+ height=350,
+ margin=dict(l=40, r=20, t=40, b=40),
+ )
+
+ if y_log:
+ fig.update_yaxes(type="log")
+
+ st.plotly_chart(fig, use_container_width=True)
+
+
+def plot_multi_metric_comparison(
+ runs, metric_key, title, ylabel, unit, smoothing, y_log
+):
+ """Comparison of multiple operation indexes"""
+ colors = px.colors.qualitative.Set2
+ fig = go.Figure()
+ found_data = False
+
+ for i, run in enumerate(runs):
+ target_metric = get_metric_dataframe(run, metric_key)
+ if not target_metric:
+ continue
+
+ df = target_metric["data"].copy()
+ if df.empty or len(df.columns) < 2:
+ continue
+
+ found_data = True
+ df = apply_smoothing(df, smoothing)
+
+ config = run.get("config", {})
+ framework = config.get("framework", "unknown")
+ model = config.get("model", "unknown")
+ device = run.get("device_used", "?")
+
+ fig.add_trace(
+ go.Scatter(
+ x=df[df.columns[0]],
+ y=df[df.columns[1]],
+ mode="lines",
+ name=f"{framework}/{model} ({device}GPU)",
+ line=dict(color=colors[i % len(colors)], width=2),
+ )
+ )
+
+ if not found_data:
+ st.info(f"无{title}数据")
+ return
+
+ fig.update_layout(
+ title=f"{title} 对比",
+ xaxis_title="Iteration",
+ yaxis_title=f"{ylabel} {unit}",
+ template="plotly_white",
+ height=350,
+ legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
+ margin=dict(l=40, r=20, t=40, b=40),
+ )
+
+ if y_log:
+ fig.update_yaxes(type="log")
+
+ st.plotly_chart(fig, use_container_width=True)
+
+
+def render_memory_usage(run):
+ """Render memory usage gauge"""
+ memory_metric = get_metric_dataframe(run, "train.peak_memory_usage")
+
+ if not memory_metric or not memory_metric.get("value"):
+ return
+
+ st.markdown("#### 💾 显存使用")
+ value = memory_metric["value"]
+ unit = memory_metric.get("unit", "GB")
+
+ # Try to get max memory from environment
+ max_value = value * 1.5
+ try:
+ env = run["data"].get("environment", {})
+ if env and "cluster" in env and len(env["cluster"]) > 0:
+ acc = env["cluster"][0].get("machine", {}).get("accelerators", [])
+ if acc and len(acc) > 0:
+ memory_per_card = float(acc[0].get("memory_gb_per_card", 80))
+ card_count = int(acc[0].get("count", 1))
+ max_value = memory_per_card * card_count
+ else:
+ st.warning("未找到加速卡信息,使用默认显存上限")
+ else:
+ st.warning("未找到环境信息,使用默认显存上限")
+ except Exception as e:
+ st.warning(f"解析环境信息失败: {e},使用默认显存上限")
+
+ col1, col2, col3 = st.columns([1, 2, 1])
+ with col2:
+ fig = create_gauge_chart(value, max_value, "峰值显存使用", "green", unit)
+ st.plotly_chart(fig, use_container_width=True)
+
+
+def render_throughput_comparison(selected_runs):
+ """Render throughput comparison"""
+ st.markdown("### 吞吐量对比")
+
+ throughput_data = []
+ for run in selected_runs:
+ target_metric = get_metric_dataframe(run, "train.throughput")
+ if not target_metric:
+ continue
+
+ df = target_metric["data"]
+ if df.empty or len(df.columns) < 2:
+ continue
+
+ avg_tput = df[df.columns[1]].mean()
+ peak_tput = df[df.columns[1]].max()
+
+ config = run.get("config", {})
+ throughput_data.append(
+ {
+ "运行": f"{config.get('framework', 'unknown')}/{config.get('model', 'unknown')} ({run.get('device_used', '?')}GPU)",
+ "平均吞吐量 (tokens/s/GPU)": round(avg_tput, 2),
+ "峰值吞吐量 (tokens/s/GPU)": round(peak_tput, 2),
+ }
+ )
+
+ if not throughput_data:
+ st.info("无可对比的吞吐量数据")
+ return
+
+ # Display table
+ df = pd.DataFrame(throughput_data)
+ st.dataframe(df, width="stretch", hide_index=True)
+
+ # Bar chart
+ fig = go.Figure()
+ for i, data in enumerate(throughput_data):
+ fig.add_trace(
+ go.Bar(
+ name=data["运行"],
+ x=["平均吞吐量", "峰值吞吐量"],
+ y=[data["平均吞吐量 (tokens/s/GPU)"], data["峰值吞吐量 (tokens/s/GPU)"]],
+ text=[data["平均吞吐量 (tokens/s/GPU)"], data["峰值吞吐量 (tokens/s/GPU)"]],
+ textposition="auto",
+ )
+ )
+
+ fig.update_layout(
+ title="吞吐量对比",
+ barmode="group",
+ template="plotly_white",
+ height=400,
+ yaxis_title="tokens/s/GPU",
+ )
+ st.plotly_chart(fig, use_container_width=True)
+
+
+def render_data_tables(selected_runs):
+ """Render data tables"""
+ for run in selected_runs:
+ with st.expander(f"{run.get('run_id', 'Unknown')} - 原始数据"):
+ for metric in run["data"].get("metrics", []):
+ if metric.get("data") is not None:
+ df = metric["data"].copy()
+ st.markdown(f"**{metric.get('name', 'Unknown')}**")
+ if len(df.columns) == 2:
+ df.columns = ["Iteration", metric.get("name", "Value")]
+ st.dataframe(df, width="stretch", hide_index=True)
+
+
+def render_config_details(selected_runs, summary_func):
+ """Render configuration details"""
+ for run in selected_runs:
+ with st.expander(f"{run.get('run_id', 'Unknown')} - 配置与环境"):
+ summary_df = summary_func(run["data"])
+ if not summary_df.empty:
+ st.markdown("**配置摘要**")
+ st.dataframe(summary_df, width="stretch", hide_index=True)
+
+ col1, col2 = st.columns(2)
+ with col1:
+ st.markdown("**完整配置**")
+ st.json(run["data"].get("config", {}))
+ with col2:
+ st.markdown("**环境信息**")
+ st.json(run["data"].get("environment", {}))
+
+ if run["data"].get("resolved"):
+ st.markdown("**解析信息**")
+ st.json(run["data"].get("resolved", {}))
diff --git a/dashboard/utils/training_utils.py b/dashboard/utils/training_utils.py
new file mode 100644
index 0000000..975777f
--- /dev/null
+++ b/dashboard/utils/training_utils.py
@@ -0,0 +1,161 @@
+"""Training page utilities."""
+
+import streamlit as st
+import pandas as pd
+
+
+def load_training_runs(data_loader):
+ """Load all training-related test runs"""
+ runs = data_loader.list_test_runs("train")
+
+ if not runs:
+ all_runs = data_loader.list_test_runs()
+ runs = [
+ r
+ for r in all_runs
+ if any(
+ keyword in str(r.get("path", "")).lower()
+ or keyword in r.get("testcase", "").lower()
+ for keyword in [
+ "/train/",
+ "/training/",
+ "train.",
+ "megatron",
+ "lora",
+ "sft",
+ ]
+ )
+ ]
+ return runs
+
+
+def filter_runs(runs, selected_fw, selected_models, selected_dev, only_success):
+ """Apply filters to runs"""
+ return [
+ r
+ for r in runs
+ if (
+ not selected_fw
+ or r.get("config", {}).get("framework", "unknown") in selected_fw
+ )
+ and (
+ not selected_models
+ or r.get("config", {}).get("model", "unknown") in selected_models
+ )
+ and (not selected_dev or r.get("device_used", 1) in selected_dev)
+ and (not only_success or r.get("success", False))
+ ]
+
+
+def create_run_options(runs):
+ """Create run selection options"""
+ return {
+ f"{r.get('config', {}).get('framework', 'unknown')}/"
+ f"{r.get('config', {}).get('model', 'unknown')} | "
+ f"{r.get('device_used', '?')}GPU | "
+ f"{r.get('time', '')[:16]} #{i}": i
+ for i, r in enumerate(runs)
+ }
+
+
+def load_selected_runs(data_loader, filtered_runs, options, selected_labels):
+ """Load the selected test run"""
+ selected_runs = []
+ for label in selected_labels:
+ idx = options[label]
+ run_info = filtered_runs[idx].copy()
+ run_info["data"] = data_loader.load_test_result(run_info.get("path") or run_info.get("run_id"))
+ selected_runs.append(run_info)
+ return selected_runs
+
+
+def get_metric_dataframe(run, metric_key):
+ """Get metric dataframe"""
+ metrics = run["data"].get("metrics", [])
+ return next(
+ (
+ m
+ for m in metrics
+ if metric_key in m.get("name", "") and m.get("data") is not None
+ ),
+ None,
+ )
+
+
+def apply_smoothing(df, smoothing):
+ """Apply smoothing to dataframe"""
+ if smoothing > 1 and len(df) > smoothing:
+ df = df.copy()
+ df.iloc[:, 1] = df.iloc[:, 1].rolling(window=smoothing, min_periods=1).mean()
+ return df
+
+
+def create_training_summary(test_result: dict) -> pd.DataFrame:
+ """Create configuration summary for training tests"""
+ rows = []
+
+ # Environment info
+ try:
+ env = test_result.get("environment", {})
+ if "cluster" in env and len(env["cluster"]) > 0:
+ acc = env["cluster"][0]["machine"]["accelerators"][0]
+ rows.extend(
+ [
+ {"指标": "加速卡", "数值": str(acc.get("model", "Unknown"))},
+ {"指标": "卡数", "数值": str(acc.get("count", "Unknown"))},
+ {"指标": "显存/卡", "数值": f"{acc.get('memory_gb_per_card','?')} GB"},
+ ]
+ )
+ except Exception as e:
+ st.warning(f"解析环境信息失败: {e}")
+
+ # Config info
+ cfg = test_result.get("config", {})
+ train_args = cfg.get("train_args", {})
+
+ rows.extend(
+ [
+ {"指标": "框架", "数值": str(cfg.get("framework", "unknown"))},
+ {"指标": "模型", "数值": str(cfg.get("model", "unknown"))},
+ ]
+ )
+
+ # Parallel config
+ parallel = train_args.get("parallel", {})
+ rows.append(
+ {
+ "指标": "并行配置",
+ "数值": f"DP={parallel.get('dp', 1)}, TP={parallel.get('tp', 1)}, PP={parallel.get('pp', 1)}",
+ }
+ )
+
+ # Other configs
+ config_items = [
+ ("MBS/GBS", f"{train_args.get('mbs', '?')}/{train_args.get('gbs', '?')}"),
+ ("序列长度", str(train_args.get("seq_len", "?"))),
+ ("隐藏层大小", str(train_args.get("hidden_size", "?"))),
+ ("层数", str(train_args.get("num_layers", "?"))),
+ ("精度", str(train_args.get("precision", "?"))),
+ ("预热迭代", str(cfg.get("warmup_iterations", "?"))),
+ ("训练迭代", str(train_args.get("train_iters", "?"))),
+ ]
+
+ for label, value in config_items:
+ rows.append({"指标": label, "数值": value})
+
+ # Scalar metrics
+ for m in test_result.get("metrics", []):
+ if m.get("type") == "scalar":
+ value = m.get("value", "")
+ unit = m.get("unit", "")
+ rows.append(
+ {
+ "指标": str(m.get("name")),
+ "数值": f"{value} {unit}" if unit and value != "" else str(value),
+ }
+ )
+
+ df = pd.DataFrame(rows)
+ if not df.empty:
+ df["数值"] = df["数值"].astype(str)
+ return df
diff --git a/dashboard/utils/visualizations.py b/dashboard/utils/visualizations.py
index 30c064a..9846809 100644
--- a/dashboard/utils/visualizations.py
+++ b/dashboard/utils/visualizations.py
@@ -275,18 +275,36 @@ def create_summary_table(test_result: Dict[str, Any]) -> pd.DataFrame:
def create_gauge_chart(
- value: float, max_value: float, title: str, color: str = "blue", unit: str = ""
+ value: float,
+ max_value: float,
+ title: str,
+ color: str = "blue",
+ unit: str = "",
+ decimals: Optional[int] = None, # fixed decimal places; auto-chosen from value when None
) -> go.Figure:
"""Create a gauge chart for single metric visualization."""
+
+ if decimals is None:
+ if value < 10:
+ decimals = 2
+ elif value < 100:
+ decimals = 1
+ else:
+ decimals = 0
+
fig = go.Figure(
go.Indicator(
mode="gauge+number",
value=value,
domain={"x": [0, 1], "y": [0.05, 0.85]},
title={"text": title, "font": {"size": 18}},
- number={"suffix": f" {unit}", "font": {"size": 36}},
+ number={
+ "suffix": f" {unit}",
+ "font": {"size": 36},
+ "valueformat": f".{decimals}f",
+ },
gauge={
- "axis": {"range": [0, max_value]},
+ "axis": {"range": [0, max_value], "tickformat": f".{decimals}f"},
"bar": {"color": color},
"steps": [
{"range": [0, max_value * 0.6], "color": "lightgray"},