Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/aiconfigurator/sdk/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,16 @@ class BlockConfig:
]


class SOLMode(Enum):
class DatabaseMode(Enum):
"""
SOL mode for database.
Database mode.
"""

NON_SOL = 0
SOL = 1
SOL_FULL = 2
SILICON = 0 # default mode using silicon data
HYBRID = 1 # use silicon data when available, otherwise use SOL+empirical factor
EMPIRICAL = 2 # SOL+empirical factor
SOL = 3 # Provide SOL time only
SOL_FULL = 4 # Provide SOL time and details


class BackendName(Enum):
Expand Down
1,076 changes: 768 additions & 308 deletions src/aiconfigurator/sdk/perf_database.py

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions src/aiconfigurator/webapp/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,20 @@ def create_system_config(app_config):
backend = gr.Dropdown(choices=backend_choices, label="Backend", value=default_backend, interactive=True)
version = gr.Dropdown(choices=version_choices, label="Version", value=default_version, interactive=True)

sol_mode = gr.Checkbox(label="SOL Mode", value=False, interactive=True, visible=app_config["experimental"])
database_mode = gr.Dropdown(
choices=[
common.DatabaseMode.SILICON.name,
common.DatabaseMode.HYBRID.name,
common.DatabaseMode.EMPIRICAL.name,
common.DatabaseMode.SOL.name,
],
label="Database Mode",
value=common.DatabaseMode.SILICON.name,
interactive=True,
visible=app_config["experimental"],
)

return {"system": system, "backend": backend, "version": version, "sol_mode": sol_mode}
return {"system": system, "backend": backend, "version": version, "database_mode": database_mode}


def create_model_quant_config(app_config):
Expand Down
36 changes: 18 additions & 18 deletions src/aiconfigurator/webapp/events/event_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def run_estimation_static(
system_name,
backend_name,
version,
sol_mode,
database_mode,
batch_size,
isl,
osl,
Expand Down Expand Up @@ -184,7 +184,7 @@ def run_estimation_static(
try:
database = copy.deepcopy(get_database(system_name, backend_name, version))
assert database is not None
database.set_default_sol_mode(common.SOLMode(int(sol_mode)))
database.set_default_database_mode(common.DatabaseMode[database_mode])
nextn_accept_rates = [float(x) for x in nextn_accept_rates.split(",")]
model_config = config.ModelConfig(
tp_size=tp_size,
Expand Down Expand Up @@ -249,7 +249,7 @@ def run_estimation_agg(
system_name,
backend_name,
version,
sol_mode,
database_mode,
isl,
osl,
prefix,
Expand Down Expand Up @@ -280,7 +280,7 @@ def run_estimation_agg(
try:
# Work on a private copy, as the other estimation handlers do, so that changing
# the default database mode below does not modify the object returned by
# get_database (NOTE(review): presumably shared between requests - confirm).
database = copy.deepcopy(get_database(system_name, backend_name, version))
assert database is not None
database.set_default_sol_mode(common.SOLMode(int(sol_mode)))
database.set_default_database_mode(common.DatabaseMode[database_mode])
nextn_accept_rates = [float(x) for x in nextn_accept_rates.split(",")]
model_config = config.ModelConfig(
tp_size=tp_size,
Expand Down Expand Up @@ -359,7 +359,7 @@ def run_estimation_agg_pareto(
system_name,
backend_name,
version,
sol_mode,
database_mode,
isl,
osl,
prefix,
Expand Down Expand Up @@ -391,7 +391,7 @@ def run_estimation_agg_pareto(
try:
database = copy.deepcopy(get_database(system_name, backend_name, version))
assert database is not None
database.set_default_sol_mode(common.SOLMode(int(sol_mode)))
database.set_default_database_mode(common.DatabaseMode[database_mode])
nextn_accept_rates = [float(x) for x in nextn_accept_rates.split(",")]
model_config = config.ModelConfig(
gemm_quant_mode=common.GEMMQuantMode[gemm_quant_mode],
Expand Down Expand Up @@ -499,7 +499,7 @@ def run_estimation_disagg_pareto(
prefill_system_name,
prefill_backend_name,
prefill_version,
prefill_sol_mode,
prefill_database_mode,
prefill_num_worker,
prefill_num_gpus,
prefill_tp_size,
Expand All @@ -516,7 +516,7 @@ def run_estimation_disagg_pareto(
decode_system_name,
decode_backend_name,
decode_version,
decode_sol_mode,
decode_database_mode,
decode_num_worker,
decode_num_gpus,
decode_tp_size,
Expand Down Expand Up @@ -553,8 +553,8 @@ def run_estimation_disagg_pareto(
decode_database = copy.deepcopy(get_database(decode_system_name, decode_backend_name, decode_version))
assert prefill_database is not None
assert decode_database is not None
prefill_database.set_default_sol_mode(common.SOLMode(int(prefill_sol_mode)))
decode_database.set_default_sol_mode(common.SOLMode(int(decode_sol_mode)))
prefill_database.set_default_database_mode(common.DatabaseMode[prefill_database_mode])
# The decode side has its own mode selection; use decode_database_mode here,
# not the prefill value.
decode_database.set_default_database_mode(common.DatabaseMode[decode_database_mode])
nextn_accept_rates = [float(x) for x in nextn_accept_rates.split(",")]
prefill_model_config = config.ModelConfig(
tp_size=prefill_tp_size,
Expand Down Expand Up @@ -689,10 +689,10 @@ def run_estimation_disagg_pareto(
title = (
f"{model_name}_isl{runtime_config.isl}_osl{runtime_config.osl}_prefix{runtime_config.prefix}_ttft"
f"{runtime_config.ttft}_prefill_{prefill_system_name}_{prefill_backend_name}_"
f"{prefill_version}_{prefill_sol_mode}_{prefill_gemm_quant_mode}_"
f"{prefill_version}_{prefill_database_mode}_{prefill_gemm_quant_mode}_"
f"{prefill_kvcache_quant_mode}_{prefill_fmha_quant_mode}_{prefill_moe_quant_mode}_"
f"{prefill_comm_quant_mode}_decode_{decode_system_name}_{decode_backend_name}_"
f"{decode_version}_{decode_sol_mode}_{decode_gemm_quant_mode}_"
f"{decode_version}_{decode_database_mode}_{decode_gemm_quant_mode}_"
f"{decode_kvcache_quant_mode}_{decode_fmha_quant_mode}_{decode_moe_quant_mode}_"
f"{decode_comm_quant_mode}_Disagg_Pareto"
)
Expand Down Expand Up @@ -733,7 +733,7 @@ def run_estimation_disagg_pd_ratio(
prefill_system_name,
prefill_backend_name,
prefill_version,
prefill_sol_mode,
prefill_database_mode,
prefill_tp_size,
prefill_pp_size,
prefill_dp_size,
Expand All @@ -747,7 +747,7 @@ def run_estimation_disagg_pd_ratio(
decode_system_name,
decode_backend_name,
decode_version,
decode_sol_mode,
decode_database_mode,
decode_tp_size,
decode_pp_size,
decode_dp_size,
Expand Down Expand Up @@ -868,7 +868,7 @@ def create_scatter_plot(df, x_col, y_col, target_x, x_label, title):
get_database(prefill_system_name, prefill_backend_name, prefill_version)
)
assert prefill_database is not None
prefill_database.set_default_sol_mode(common.SOLMode(int(prefill_sol_mode)))
prefill_database.set_default_database_mode(common.DatabaseMode[prefill_database_mode])
prefill_backend = get_backend(prefill_backend_name)
prefill_session = InferenceSession(prefill_model, prefill_database, prefill_backend)
prefill_results_df = pd.DataFrame(columns=common.ColumnsStatic)
Expand All @@ -886,7 +886,7 @@ def create_scatter_plot(df, x_col, y_col, target_x, x_label, title):
prefill_results_df = prefill_results_df.reset_index(drop=True).reset_index()
title = (
f"{model_name}_isl{isl}_osl{osl}_prefix{prefix}_prefill_{prefill_system_name}_"
f"{prefill_backend_name}_{prefill_version}_{prefill_sol_mode}_"
f"{prefill_backend_name}_{prefill_version}_{prefill_database_mode}_"
f"{prefill_gemm_quant_mode}_{prefill_kvcache_quant_mode}_"
f"{prefill_fmha_quant_mode}_{prefill_moe_quant_mode}_{prefill_comm_quant_mode}_"
"Throughput"
Expand All @@ -904,7 +904,7 @@ def create_scatter_plot(df, x_col, y_col, target_x, x_label, title):
decode_model = get_model(model_name, decode_model_config, decode_backend_name)
decode_database = copy.deepcopy(get_database(decode_system_name, decode_backend_name, decode_version))
assert decode_database is not None
decode_database.set_default_sol_mode(common.SOLMode(int(decode_sol_mode)))
decode_database.set_default_database_mode(common.DatabaseMode[decode_database_mode])
decode_backend = get_backend(decode_backend_name)
decode_session = InferenceSession(decode_model, decode_database, decode_backend)
decode_results_df = pd.DataFrame(columns=common.ColumnsStatic)
Expand Down Expand Up @@ -934,7 +934,7 @@ def create_scatter_plot(df, x_col, y_col, target_x, x_label, title):
decode_results_df = decode_results_df.reset_index(drop=True).reset_index()
title = (
f"{model_name}_isl{isl}_osl{osl}_decode_{decode_system_name}_"
f"{decode_backend_name}_{decode_version}_{decode_sol_mode}_"
f"{decode_backend_name}_{decode_version}_{decode_database_mode}_"
f"{decode_gemm_quant_mode}_{decode_kvcache_quant_mode}_"
f"{decode_fmha_quant_mode}_{decode_moe_quant_mode}_{decode_comm_quant_mode}_"
"Throughput"
Expand Down
14 changes: 7 additions & 7 deletions src/aiconfigurator/webapp/events/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def setup_static_events(components):
components["model_system_components"]["system"],
components["model_system_components"]["backend"],
components["model_system_components"]["version"],
components["model_system_components"]["sol_mode"],
components["model_system_components"]["database_mode"],
components["runtime_config_components"]["batch_size"],
components["runtime_config_components"]["isl"],
components["runtime_config_components"]["osl"],
Expand Down Expand Up @@ -75,7 +75,7 @@ def setup_agg_events(components):
components["model_system_components"]["system"],
components["model_system_components"]["backend"],
components["model_system_components"]["version"],
components["model_system_components"]["sol_mode"],
components["model_system_components"]["database_mode"],
components["runtime_config_components"]["isl"],
components["runtime_config_components"]["osl"],
components["runtime_config_components"]["prefix"],
Expand Down Expand Up @@ -125,7 +125,7 @@ def setup_agg_pareto_events(components):
components["model_system_components"]["system"],
components["model_system_components"]["backend"],
components["model_system_components"]["version"],
components["model_system_components"]["sol_mode"],
components["model_system_components"]["database_mode"],
components["runtime_config_components"]["isl"],
components["runtime_config_components"]["osl"],
components["runtime_config_components"]["prefix"],
Expand Down Expand Up @@ -187,7 +187,7 @@ def setup_disagg_pareto_events(components):
components["prefill_model_system_components"]["system"], # prefill
components["prefill_model_system_components"]["backend"],
components["prefill_model_system_components"]["version"],
components["prefill_model_system_components"]["sol_mode"],
components["prefill_model_system_components"]["database_mode"],
components["prefill_model_parallel_components"]["num_worker"],
components["prefill_model_parallel_components"]["num_gpus"],
components["prefill_model_parallel_components"]["tp_size"],
Expand All @@ -204,7 +204,7 @@ def setup_disagg_pareto_events(components):
components["decode_model_system_components"]["system"], # decode
components["decode_model_system_components"]["backend"],
components["decode_model_system_components"]["version"],
components["decode_model_system_components"]["sol_mode"],
components["decode_model_system_components"]["database_mode"],
components["decode_model_parallel_components"]["num_worker"],
components["decode_model_parallel_components"]["num_gpus"],
components["decode_model_parallel_components"]["tp_size"],
Expand Down Expand Up @@ -280,7 +280,7 @@ def setup_disagg_pd_ratio_events(components):
components["prefill_model_system_components"]["system"], # prefill
components["prefill_model_system_components"]["backend"],
components["prefill_model_system_components"]["version"],
components["prefill_model_system_components"]["sol_mode"],
components["prefill_model_system_components"]["database_mode"],
components["prefill_model_parallel_components"]["tp_size"],
components["prefill_model_parallel_components"]["pp_size"],
components["prefill_model_parallel_components"]["dp_size"],
Expand All @@ -294,7 +294,7 @@ def setup_disagg_pd_ratio_events(components):
components["decode_model_system_components"]["system"], # decode
components["decode_model_system_components"]["backend"],
components["decode_model_system_components"]["version"],
components["decode_model_system_components"]["sol_mode"],
components["decode_model_system_components"]["database_mode"],
components["decode_model_parallel_components"]["tp_size"],
components["decode_model_parallel_components"]["pp_size"],
components["decode_model_parallel_components"]["dp_size"],
Expand Down
2 changes: 1 addition & 1 deletion tests/sdk/database/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def patch_all_loaders_and_yaml(request, monkeypatch):
# These two values are used in many "SOL"-mode formulas:
"float16_tc_flops": 1_000.0,
"mem_bw": 100.0,
# For query_nccl NON-SOL branch:
# For query_nccl SILICON branch:
"mem_empirical_constant_latency": 1.0,
},
"node": {
Expand Down
Loading