Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def initResultsPageConfig(st):
st.set_page_config(
page_title=PAGE_TITLE,
page_icon=FAVICON,
# layout="wide",
layout="wide",
# initial_sidebar_state="collapsed",
)

Expand Down
117 changes: 117 additions & 0 deletions vectordb_bench/frontend/components/qps_recall/charts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from vectordb_bench.frontend.components.check_results.expanderStyle import (
initMainExpanderStyle,
)
from vectordb_bench.metric import metric_order, isLowerIsBetterMetric, metric_unit_map
from vectordb_bench.frontend.config.styles import *
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
import matplotlib.pyplot as plt


def drawCharts(st, allData, caseNames: list[str]):
    """Render one expander per case name, each holding that case's chart.

    Args:
        st: Streamlit module or container the expanders are created on.
        allData: Iterable of record dicts; each must carry a ``"case_name"`` key.
        caseNames: Case names to draw, in display order.
    """
    initMainExpanderStyle(st)
    for caseName in caseNames:
        # True is passed positionally as the expander's second argument
        # (Streamlit's ``expanded`` flag).
        expander = st.expander(caseName, True)
        caseRows = list(filter(lambda row: row["case_name"] == caseName, allData))
        drawChart(caseRows, expander, key_prefix=caseName)


def drawChart(data, st, key_prefix: str):
    """Draw the line chart for one case, if a supported metric is present.

    The metric is the first of the two leading entries of ``metric_order``
    that appears in any record's ``"metricsSet"``. Nothing is drawn when
    neither is available.

    Args:
        data: Record dicts for a single case; each must have ``"metricsSet"``.
        st: Streamlit container to draw into.
        key_prefix: Prefix used to build the chart's widget key.
    """
    availableMetrics = set().union(*(row["metricsSet"] for row in data))
    candidates = [name for name in metric_order[:2] if name in availableMetrics]
    if not candidates:
        return

    chosen = candidates[0]
    drawlinechart(st, data, chosen, key=f"{key_prefix}-{chosen}")


def drawBestperformance(data, y, group):
    """Split records into each group's best-performance frontier and the rest.

    Within every ``group`` value the rows are scanned in their given order:
    the row with the largest ``y`` is selected, then the row with the largest
    ``y`` among the rows after it, and so on until the group is exhausted.
    Ties resolve to the earliest row, so equal values are all selected.
    Despite its name, this function draws nothing.

    Args:
        data: Records to split (a list of dicts, or anything accepted by
            ``pd.DataFrame``). Row order within each group is significant.
        y: Name of the column whose value is maximised.
        group: Column name (or names) to group the rows by.

    Returns:
        A ``(new_data, remain_data)`` tuple of record-dict lists: the selected
        frontier rows (group by group, in scan order) and every other row.
        Both lists are empty when ``data`` is empty.
    """
    frame = pd.DataFrame(data)
    if frame.empty:
        # groupby() on a frame without the group column would raise KeyError.
        return [], []

    frontier_labels = []
    for _, group_df in frame.groupby(group):
        values = group_df[y]
        start = 0
        while start < len(values):
            best_label = values.iloc[start:].idxmax()
            frontier_labels.append(best_label)
            # Continue the scan just after the row that was picked.
            start = group_df.index.get_loc(best_label) + 1

    # Split by index label instead of comparing cell values, so columns that
    # hold sets or lists never need to be compared element-wise.
    on_frontier = frame.index.isin(frontier_labels)
    new_data = frame.loc[frontier_labels].to_dict(orient="records")
    remain_data = frame[~on_frontier].to_dict(orient="records")
    return new_data, remain_data


def drawlinechart(st, data: list[object], metric, key: str):
    """Plot QPS against recall for one case, one colored series per database.

    Records are sorted by recall and split with ``drawBestperformance``. The
    frontier rows of each database are drawn as a connected line with markers;
    all other rows are drawn as unconnected markers in the same color and are
    kept out of the legend.

    Args:
        st: Streamlit container the chart is rendered into.
        data: Record dicts carrying ``"recall"``, ``"qps"`` and ``"db_name"``.
            The list is not modified.
        metric: Metric name used to look up the display unit and to compute
            the y-axis range.
        key: Unique Streamlit key for the chart element.
    """
    if not data:
        # Nothing to plot; min()/max() below would raise on an empty list.
        return

    x = "recall"
    y = "qps"
    group = "db_name"

    unit = metric_unit_map.get(metric, "")
    # NOTE(review): the y-range is derived from `metric`, while the plotted
    # y values are always "qps"; the two agree only when metric == "qps".
    values = [d.get(metric, 0) for d in data]
    minV, maxV = min(values), max(values)
    padding = maxV - minV
    xrange = [0.8, 1.01]
    yrange = [minV - padding * 0.1, maxV + padding * 0.1]

    # Sort a copy so the caller's list is not reordered as a side effect.
    ordered = sorted(data, key=lambda row: row[x])
    new_data, new_remain_data = drawBestperformance(ordered, y, group)

    # Sorted so every database keeps the same color and legend position from
    # one run to the next (set iteration order is not stable for strings).
    unique_db_names = sorted({item[group] for item in new_data + new_remain_data})

    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
    # Matplotlib 3.9; it takes the same (name, lut) arguments.
    colors = plt.get_cmap("tab10", len(unique_db_names))
    color_map = {}
    for i, db in enumerate(unique_db_names):
        red, green, blue = (int(channel * 255) for channel in colors(i)[:3])
        color_map[db] = f"rgb({red}, {green}, {blue})"

    fig = go.Figure()

    # Frontier rows: one connected line per database, listed in the legend.
    new_data_df = pd.DataFrame(new_data)
    for db in unique_db_names:
        db_data = new_data_df[new_data_df[group] == db]
        fig.add_trace(
            go.Scatter(
                x=db_data[x],
                y=db_data[y],
                mode="lines+markers",
                name=db,
                line=dict(color=color_map[db]),
                marker=dict(color=color_map[db]),
                showlegend=True,
            )
        )

    # Remaining rows: individual markers in the database's color, no legend entry.
    for item in new_remain_data:
        fig.add_trace(
            go.Scatter(
                x=[item[x]],
                y=[item[y]],
                mode="markers",
                name=item[group],
                marker=dict(color=color_map[item[group]]),
                showlegend=False,
            )
        )

    fig.update_xaxes(range=xrange)
    fig.update_yaxes(range=yrange)
    fig.update_traces(textposition="bottom right", texttemplate="%{y:,.4~r}" + unit)
    fig.update_layout(
        margin=dict(l=0, r=0, t=40, b=0, pad=8),
        legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""),
    )
    st.plotly_chart(fig, use_container_width=True, key=key)
58 changes: 58 additions & 0 deletions vectordb_bench/frontend/components/qps_recall/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from collections import defaultdict
from dataclasses import asdict
from vectordb_bench.backend.filter import FilterOp
from vectordb_bench.frontend.components.check_results.data import getFilterTasks
from vectordb_bench.frontend.components.check_results.filters import getShowDbsAndCases, getshownResults
from vectordb_bench.models import CaseResult, ResultLabel, TestResult


def getshownData(st, results: list[TestResult], filter_type: FilterOp = FilterOp.NonFilter, **kwargs):
    """Draw the filter widgets and return the chart data they select.

    Args:
        st: Streamlit container the filter widgets are drawn into.
        results: Test results to choose from.
        filter_type: Filter operation forwarded to ``getShowDbsAndCases``.
        **kwargs: Extra options forwarded to ``getshownResults``.

    Returns:
        A ``(shownData, failedTasks, showCaseNames)`` tuple: the chart records
        and failed-task mapping produced by ``getChartData``, plus the selected
        case names.
    """
    # hide the nav
    hideNavCss = "<style> div[data-testid='stSidebarNav'] {display: none;} </style>"
    st.markdown(hideNavCss, unsafe_allow_html=True)
    st.header("Filters")

    selectedResults = getshownResults(st, results, **kwargs)
    showDBNames, showCaseNames = getShowDbsAndCases(st, selectedResults, filter_type)
    chartRows, failed = getChartData(selectedResults, showDBNames, showCaseNames)
    return chartRows, failed, showCaseNames


def getChartData(
    tasks: list[CaseResult],
    dbNames: list[str],
    caseNames: list[str],
):
    """Flatten the selected case results into chart records.

    Tasks are first narrowed with ``getFilterTasks``. Each task labelled
    ``ResultLabel.NORMAL`` becomes one flat dict of identifying fields, the set
    of metric names (``"metricsSet"``) and the metric values themselves. Any
    other label is recorded as a failure instead.

    Returns:
        A ``(records, failedTasks)`` tuple, where ``failedTasks`` is a nested
        defaultdict mapping case name -> db name -> result label.
    """
    failures = defaultdict(lambda: defaultdict(str))
    normalRows = []
    for result in getFilterTasks(tasks, dbNames, caseNames):
        config = result.task_config
        dbName = config.db_name
        dbValue = config.db.value
        dbConfig = config.db_config
        dbLabel = dbConfig.db_label or ""
        dbVersion = dbConfig.version or ""
        benchCase = config.case_config.case
        caseName = benchCase.name
        datasetName = benchCase.dataset.data.full_name
        filterRate = benchCase.filter_rate
        metricValues = asdict(result.metrics)
        resultLabel = result.label

        if resultLabel != ResultLabel.NORMAL:
            failures[caseName][dbName] = resultLabel
            continue

        row = {
            "db_name": dbName,
            "db": dbValue,
            "db_label": dbLabel,
            "dataset_name": datasetName,
            "filter_rate": filterRate,
            "version": dbVersion,
            "case_name": caseName,
            "metricsSet": set(metricValues),
        }
        # Metric values go in last, so they win on any key collision.
        row.update(metricValues)
        normalRows.append(row)

    return normalRows, failures
59 changes: 59 additions & 0 deletions vectordb_bench/frontend/pages/qps_recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import streamlit as st
from vectordb_bench.frontend.components.check_results.footer import footer
from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
from vectordb_bench.frontend.components.check_results.nav import (
NavToQuriesPerDollar,
NavToRunTest,
NavToPages,
)
from vectordb_bench.frontend.components.qps_recall.charts import drawCharts
from vectordb_bench.frontend.components.qps_recall.data import getshownData
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults

from vectordb_bench.frontend.config.styles import FAVICON
from vectordb_bench.interface import benchmark_runner


def main():
    """Render the QPS & Recall results page.

    Lays out, in order: page config, header icon, page navigation, the page
    title, the sidebar result filters, sidebar navigation buttons, the
    save/share controls, one chart per selected case, and the footer.
    """
    # set page config
    # The tab title matches this page's heading; it previously read
    # "Label Filter", carried over from another page.
    st.set_page_config(
        page_title="Qps & Recall",
        page_icon=FAVICON,
        layout="wide",
    )

    # header
    drawHeaderIcon(st)

    # navigate
    NavToPages(st)

    allResults = benchmark_runner.get_results()

    st.title("Vector Database Benchmark (Qps & Recall)")

    # results selector and filter
    resultSelectorContainer = st.sidebar.container()
    shownData, failedTasks, showCaseNames = getshownData(resultSelectorContainer, allResults)

    resultSelectorContainer.divider()

    # nav
    navContainer = st.sidebar.container()
    NavToRunTest(navContainer)
    NavToQuriesPerDollar(navContainer)

    # save or share
    resultsContainer = st.sidebar.container()
    getResults(resultsContainer, "vectordb_bench")

    # charts
    drawCharts(st, shownData, showCaseNames)

    # footer
    footer(st.container())


if __name__ == "__main__":
    main()
Loading