
Commit 4a951c4

Fixes from reviews
1 parent b3e167f commit 4a951c4

9 files changed: +82 additions, -48 deletions


src/guidellm/benchmark/benchmarker.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 import uuid
 from abc import ABC
 from collections.abc import AsyncIterator, Iterable
-from typing import Any, Generic
+from typing import Generic

 from guidellm.benchmark.profile import Profile
 from guidellm.benchmark.progress import BenchmarkerProgress

src/guidellm/benchmark/entrypoints.py

Lines changed: 2 additions & 3 deletions
@@ -309,7 +309,7 @@ async def resolve_request_loader(

 async def resolve_profile(
     profile: StrategyType | ProfileType | Profile,
-    rate: float | list[float] | None,
+    rate: list[float] | None,
     random_seed: int,
     constraints: MutableMapping[str, ConstraintInitializer | Any],
     max_seconds: int | float | None,
@@ -355,10 +355,9 @@ async def resolve_profile(
         if val is not None:
             constraints[key] = val
     if not isinstance(profile, Profile):
-        rate_list: list[float] | None = [rate] if isinstance(rate, float) else rate
         profile = Profile.create(
             rate_type=profile,
-            rate=rate_list,
+            rate=rate,
             random_seed=random_seed,
             constraints={**constraints},
         )

src/guidellm/benchmark/outputs/console.py

Lines changed: 6 additions & 6 deletions
@@ -94,9 +94,9 @@ def add_stats(
         precision: int = 1,
     ):
         """
-        Add statistical summary columns (mean and standard deviation) for a metric.
+        Add statistical summary columns (mean and p95) for a metric.

-        Creates paired mean/stddev columns automatically and appends values from the
+        Creates paired mean/p95 columns automatically and appends values from the
         specified status category of the distribution summary.

         :param stats: Distribution summary containing status-specific statistics
@@ -111,16 +111,16 @@ def add_stats(
         self[f"{key}_mean"] = ConsoleTableColumn(
             group=group, name=name, units="Mean", precision=precision
         )
-        self[f"{key}_stddev"] = ConsoleTableColumn(
-            group=group, name=name, units="Std", precision=precision
+        self[f"{key}_p95"] = ConsoleTableColumn(
+            group=group, name=name, units="p95", precision=precision
         )

         status_stats: DistributionSummary | None = (
             getattr(stats, status) if stats else None
         )
         self[f"{key}_mean"].values.append(status_stats.mean if status_stats else None)
-        self[f"{key}_stddev"].values.append(
-            status_stats.std_dev if status_stats else None
+        self[f"{key}_p95"].values.append(
+            status_stats.percentiles.p95 if status_stats else None
         )

     def get_table_data(self) -> tuple[list[list[str]], list[list[str]]]:
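For context, a minimal sketch of the mean/p95 pair the console table now reports per metric. It uses plain numpy rather than guidellm's DistributionSummary, so the helper name and call here are illustrative only:

import numpy as np

def mean_and_p95(samples: list[float]) -> tuple[float, float]:
    # Return the mean and 95th percentile of a metric's samples,
    # mirroring the "Mean" and "p95" columns added above.
    values = np.asarray(samples, dtype=float)
    return float(values.mean()), float(np.percentile(values, 95))

# Example: request latencies in seconds
print(mean_and_p95([0.8, 0.9, 1.1, 1.4, 2.3]))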

src/guidellm/benchmark/progress.py

Lines changed: 8 additions & 16 deletions
@@ -348,9 +348,9 @@ class _GenerativeProgressTaskState:
     request_concurrency: float = 0.0
     requests_per_second: float = 0.0
     request_latency: float = 0.0
-    output_tokens: int = 0
+    output_tokens: float = 0
     output_tokens_rate: float = 0.0
-    prompt_tokens: int = 0
+    prompt_tokens: float = 0
     total_tokens_rate: float = 0.0
     time_to_first_token: float = 0.0
     inter_token_latency: float = 0.0
@@ -588,13 +588,9 @@ def update(
             request_latency=accumulator.completed_metrics.request_latency.mean,
         )
         self._update_token_stats(
-            output_tokens=int(
-                accumulator.completed_metrics.total_tokens.value_sum or 0
-            ),
+            output_tokens=accumulator.completed_metrics.total_tokens.mean,
             output_tokens_rate=accumulator.completed_metrics.output_tokens.rate_per_second,
-            prompt_tokens=int(
-                accumulator.completed_metrics.input_tokens.value_sum or 0
-            ),
+            prompt_tokens=accumulator.completed_metrics.input_tokens.mean,
             total_tokens_rate=accumulator.completed_metrics.total_tokens.rate_per_second,
             time_to_first_token=accumulator.completed_metrics.time_to_first_token_ms.mean,
             inter_token_latency=accumulator.completed_metrics.inter_token_latency_ms.mean,
@@ -621,13 +617,9 @@ def complete(self, benchmark: GenerativeBenchmark):
             request_latency=benchmark.metrics.request_latency.successful.mean,
         )
         self._update_token_stats(
-            output_tokens=int(
-                benchmark.metrics.output_token_count.successful.mean or 0
-            ),
+            output_tokens=benchmark.metrics.output_token_count.successful.mean,
             output_tokens_rate=benchmark.metrics.output_tokens_per_second.successful.mean,
-            prompt_tokens=int(
-                benchmark.metrics.prompt_token_count.successful.mean or 0
-            ),
+            prompt_tokens=benchmark.metrics.prompt_token_count.successful.mean,
             total_tokens_rate=benchmark.metrics.tokens_per_second.successful.mean,
             time_to_first_token=(
                 benchmark.metrics.time_to_first_token_ms.successful.mean
@@ -682,9 +674,9 @@ def _update_request_stats(

     def _update_token_stats(
         self,
-        output_tokens: int | None = None,
+        output_tokens: float | None = None,
         output_tokens_rate: float | None = None,
-        prompt_tokens: int | None = None,
+        prompt_tokens: float | None = None,
         total_tokens_rate: float | None = None,
         time_to_first_token: float | None = None,
         inter_token_latency: float | None = None,

src/guidellm/benchmark/schemas/generative/entrypoints.py

Lines changed: 47 additions & 5 deletions
@@ -17,7 +17,16 @@
 from typing import Any, Literal

 import yaml
-from pydantic import ConfigDict, Field, model_serializer
+from pydantic import (
+    AliasChoices,
+    AliasGenerator,
+    ConfigDict,
+    Field,
+    ValidationError,
+    ValidatorFunctionWrapHandler,
+    field_validator,
+    model_serializer,
+)
 from torch.utils.data import Sampler
 from transformers import PreTrainedTokenizerBase

@@ -101,9 +110,8 @@ def create(
             scenario_data = scenario_data["args"]
             constructor_kwargs.update(scenario_data)

-        for key, value in kwargs.items():
-            if value != cls.get_default(key):
-                constructor_kwargs[key] = value
+        # Apply overrides from kwargs
+        constructor_kwargs.update(kwargs)

         return cls.model_validate(constructor_kwargs)

@@ -138,6 +146,14 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
         use_enum_values=True,
         from_attributes=True,
         arbitrary_types_allowed=True,
+        validate_by_alias=True,
+        validate_by_name=True,
+        alias_generator=AliasGenerator(
+            # Support field names with hyphens
+            validation_alias=lambda field_name: AliasChoices(
+                field_name, field_name.replace("_", "-")
+            ),
+        ),
     )

     # Required
@@ -151,7 +167,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
     profile: StrategyType | ProfileType | Profile = Field(
         default="sweep", description="Benchmark profile or scheduling strategy type"
     )
-    rate: float | list[float] | None = Field(
+    rate: list[float] | None = Field(
         default=None, description="Request rate(s) for rate-based scheduling"
     )
     # Backend configuration
@@ -187,6 +203,12 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
     data_request_formatter: RequestFormatter | dict[str, str] | str = Field(
         default="chat_completions",
         description="Request formatting preprocessor or template name",
+        validation_alias=AliasChoices(
+            "data_request_formatter",
+            "data-request-formatter",
+            "request_type",
+            "request-type",
+        ),
     )
     data_collator: Callable | Literal["generative"] | None = Field(
         default="generative", description="Data collator for batch processing"
@@ -243,6 +265,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
         default=None, description="Maximum global error rate (0-1) before stopping"
     )

+    @field_validator("data", "data_args", "rate", mode="wrap")
+    @classmethod
+    def single_to_list(
+        cls, value: Any, handler: ValidatorFunctionWrapHandler
+    ) -> list[Any]:
+        """
+        Ensures field is always a list.
+
+        :param value: Input value for the 'data' field
+        :return: List of data sources
+        """
+        try:
+            return handler(value)
+        except ValidationError as err:
+            # If validation fails, try wrapping the value in a list
+            if err.errors()[0]["type"] == "list_type":
+                return handler([value])
+            else:
+                raise
+
     @model_serializer
     def serialize_model(self) -> dict[str, Any]:
         """

src/guidellm/scheduler/worker_group.py

Lines changed: 2 additions & 2 deletions
@@ -228,11 +228,11 @@ async def create_processes(self):

             worker = WorkerProcess[RequestT, ResponseT](
                 worker_index=rank,
-                messaging=self.messaging.create_worker_copy(
+                messaging=self.messaging.create_worker_copy(  # type: ignore[arg-type]
                     worker_index=rank,
                     max_buffer_send_size=None,
                     max_buffer_receive_size=per_proc_max_buffer_size,
-                ),  # The non-group worker lacks the SchedulerState type. Type err.
+                ),
                 backend=self.backend,
                 strategy=self.strategy,
                 async_limit=async_limit,

src/guidellm/schemas/request_stats.py

Lines changed: 5 additions & 7 deletions
@@ -266,12 +266,10 @@ def prompt_tokens_timing(self) -> tuple[float, float] | None:
             # no end time, can't compute
             return None

-        return [
-            (
-                self.first_token_iteration or self.request_end_time,
-                self.prompt_tokens or 0.0,
-            )
-        ]
+        return (
+            self.first_token_iteration or self.request_end_time,
+            self.prompt_tokens or 0.0,
+        )

     @property
     def output_tokens_timings(self) -> list[tuple[float, float]]:
@@ -332,4 +330,4 @@ def total_tokens_timings(self) -> list[tuple[float, float]]:
         prompt_timings = self.prompt_tokens_timing
         output_timings = self.output_tokens_timings

-        return (prompt_timings or []) + output_timings
+        return ([prompt_timings] if prompt_timings else []) + output_timings
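To illustrate why the list wrapping is needed: prompt_tokens_timing now returns a single (time, tokens) tuple instead of a one-element list, so total_tokens_timings must wrap it before concatenating. A tiny sketch with made-up values:

# Illustrative values only; mirrors the expression used in total_tokens_timings.
prompt_timing: tuple[float, float] | None = (12.5, 128.0)   # (first token time, prompt tokens)
output_timings: list[tuple[float, float]] = [(12.5, 1.0), (12.6, 1.0)]

# The old `(prompt_timing or []) + output_timings` would now add a tuple to a list;
# wrapping the tuple in a list keeps the concatenation well-typed.
total = ([prompt_timing] if prompt_timing else []) + output_timings
print(total)  # [(12.5, 128.0), (12.5, 1.0), (12.6, 1.0)]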

src/guidellm/schemas/statistics.py

Lines changed: 10 additions & 7 deletions
@@ -714,7 +714,7 @@ def from_values_function(
         def _extract_values(
             _objs: Sequence[FunctionObjT],
         ) -> Sequence[float | tuple[float, float]]:
-            _outputs = []
+            _outputs: list[float | tuple[float, float]] = []
             for _obj in _objs:
                 if (_result := function(_obj)) is None:
                     continue
@@ -830,7 +830,7 @@ def rate_distribution_from_timings_function(
         def _extract_values(
             _objs: Sequence[FunctionObjT],
         ) -> Sequence[float | tuple[float, float]]:
-            _outputs = []
+            _outputs: list[float | tuple[float, float]] = []
             for _obj in _objs:
                 if (_result := function(_obj)) is None:
                     continue
@@ -955,7 +955,7 @@ def concurrency_distribution_from_timings_function(
         def _extract_values(
             _objs: Sequence[FunctionObjT],
         ) -> Sequence[tuple[float, float] | tuple[float, float, float]]:
-            _outputs = []
+            _outputs: list[tuple[float, float] | tuple[float, float, float]] = []
             for _obj in _objs:
                 if (_result := function(_obj)) is None:
                     continue
@@ -979,10 +979,13 @@ def _extract_values(
     @classmethod
     def _combine_status_arrays(
         cls,
-        successful: Sequence[float] | np.ndarray,
-        incomplete: Sequence[float] | np.ndarray,
-        errored: Sequence[float] | np.ndarray,
-        num_values_per_item: int,
+        successful: Sequence[float | tuple[float, float] | tuple[float, float, float]]
+        | np.ndarray,
+        incomplete: Sequence[float | tuple[float, float] | tuple[float, float, float]]
+        | np.ndarray,
+        errored: Sequence[float | tuple[float, float] | tuple[float, float, float]]
+        | np.ndarray,
+        num_values_per_item: Literal[2, 3],
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         successful_array = DistributionSummary._to_weighted_ndarray(  # noqa: SLF001
             successful, num_values_per_item=num_values_per_item

src/guidellm/settings.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ class Settings(BaseSettings):
     preferred_output_tokens_source: Literal["request", "response"] = "response"
     preferred_backend: Literal["openai"] = "openai"
     preferred_route: Literal["text_completions", "chat_completions"] = (
-        "text_completions"
+        "chat_completions"
    )
     openai: OpenAISettings = OpenAISettings()

