1717from typing import Any , Literal
1818
1919import yaml
20- from pydantic import ConfigDict , Field , model_serializer
20+ from pydantic import (
21+ AliasChoices ,
22+ AliasGenerator ,
23+ ConfigDict ,
24+ Field ,
25+ ValidationError ,
26+ ValidatorFunctionWrapHandler ,
27+ field_validator ,
28+ model_serializer ,
29+ )
2130from torch .utils .data import Sampler
2231from transformers import PreTrainedTokenizerBase
2332
@@ -101,9 +110,8 @@ def create(
101110 scenario_data = scenario_data ["args" ]
102111 constructor_kwargs .update (scenario_data )
103112
104- for key , value in kwargs .items ():
105- if value != cls .get_default (key ):
106- constructor_kwargs [key ] = value
113+ # Apply overrides from kwargs
114+ constructor_kwargs .update (kwargs )
107115
108116 return cls .model_validate (constructor_kwargs )
109117
@@ -138,6 +146,14 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
138146 use_enum_values = True ,
139147 from_attributes = True ,
140148 arbitrary_types_allowed = True ,
149+ validate_by_alias = True ,
150+ validate_by_name = True ,
151+ alias_generator = AliasGenerator (
152+ # Support field names with hyphens
153+ validation_alias = lambda field_name : AliasChoices (
154+ field_name , field_name .replace ("_" , "-" )
155+ ),
156+ ),
141157 )
142158
143159 # Required
@@ -151,7 +167,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
151167 profile : StrategyType | ProfileType | Profile = Field (
152168 default = "sweep" , description = "Benchmark profile or scheduling strategy type"
153169 )
154- rate : float | list [float ] | None = Field (
170+ rate : list [float ] | None = Field (
155171 default = None , description = "Request rate(s) for rate-based scheduling"
156172 )
157173 # Backend configuration
@@ -187,6 +203,12 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
187203 data_request_formatter : RequestFormatter | dict [str , str ] | str = Field (
188204 default = "chat_completions" ,
189205 description = "Request formatting preprocessor or template name" ,
206+ validation_alias = AliasChoices (
207+ "data_request_formatter" ,
208+ "data-request-formatter" ,
209+ "request_type" ,
210+ "request-type" ,
211+ ),
190212 )
191213 data_collator : Callable | Literal ["generative" ] | None = Field (
192214 default = "generative" , description = "Data collator for batch processing"
@@ -243,6 +265,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
243265 default = None , description = "Maximum global error rate (0-1) before stopping"
244266 )
245267
268+ @field_validator ("data" , "data_args" , "rate" , mode = "wrap" )
269+ @classmethod
270+ def single_to_list (
271+ cls , value : Any , handler : ValidatorFunctionWrapHandler
272+ ) -> list [Any ]:
273+ """
274+ Ensures field is always a list.
275+
276+ :param value: Input value for the 'data' field
277+ :return: List of data sources
278+ """
279+ try :
280+ return handler (value )
281+ except ValidationError as err :
282+ # If validation fails, try wrapping the value in a list
283+ if err .errors ()[0 ]["type" ] == "list_type" :
284+ return handler ([value ])
285+ else :
286+ raise
287+
246288 @model_serializer
247289 def serialize_model (self ) -> dict [str , Any ]:
248290 """
0 commit comments