Skip to content

Commit 5e8862e

Browse files
vrdn-23 and hmellor authored
[Feature] Pydantic validation for scheduler.py and structured_outputs.py (vllm-project#26519)
Signed-off-by: Vinay Damodaran <[email protected]> Signed-off-by: Harry Mellor <[email protected]> Co-authored-by: Harry Mellor <[email protected]>
1 parent 9e5bd30 commit 5e8862e

File tree

4 files changed

+39
-35
lines changed

4 files changed

+39
-35
lines changed

vllm/config/scheduler.py

Lines changed: 31 additions & 31 deletions
Original file line number · Diff line number · Diff line change
@@ -2,10 +2,11 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import hashlib
5-
from dataclasses import InitVar, field
5+
from collections.abc import Callable
6+
from dataclasses import InitVar
67
from typing import Any, Literal
78

8-
from pydantic import SkipValidation, model_validator
9+
from pydantic import Field, field_validator, model_validator
910
from pydantic.dataclasses import dataclass
1011
from typing_extensions import Self
1112

@@ -31,28 +32,28 @@ class SchedulerConfig:
3132
runner_type: RunnerType = "generate"
3233
"""The runner type to launch for the model."""
3334

34-
max_num_batched_tokens: SkipValidation[int] = None # type: ignore
35+
max_num_batched_tokens: int = Field(default=None, ge=1)
3536
"""Maximum number of tokens to be processed in a single iteration.
3637
3738
This config has no static default. If left unspecified by the user, it will
3839
be set in `EngineArgs.create_engine_config` based on the usage context."""
3940

40-
max_num_seqs: SkipValidation[int] = None # type: ignore
41+
max_num_seqs: int = Field(default=None, ge=1)
4142
"""Maximum number of sequences to be processed in a single iteration.
4243
4344
This config has no static default. If left unspecified by the user, it will
4445
be set in `EngineArgs.create_engine_config` based on the usage context."""
4546

46-
max_model_len: SkipValidation[int] = None # type: ignore
47+
max_model_len: int = Field(default=None, ge=1)
4748
"""Maximum length of a sequence (including prompt and generated text). This
4849
is primarily set in `ModelConfig` and that value should be manually
4950
duplicated here."""
5051

51-
max_num_partial_prefills: int = 1
52+
max_num_partial_prefills: int = Field(default=1, ge=1)
5253
"""For chunked prefill, the maximum number of sequences that can be
5354
partially prefilled concurrently."""
5455

55-
max_long_partial_prefills: int = 1
56+
max_long_partial_prefills: int = Field(default=1, ge=1)
5657
"""For chunked prefill, the maximum number of prompts longer than
5758
long_prefill_token_threshold that will be prefilled concurrently. Setting
5859
this less than max_num_partial_prefills will allow shorter prompts to jump
@@ -62,7 +63,7 @@ class SchedulerConfig:
6263
"""For chunked prefill, a request is considered long if the prompt is
6364
longer than this number of tokens."""
6465

65-
num_lookahead_slots: int = 0
66+
num_lookahead_slots: int = Field(default=0, ge=0)
6667
"""The number of slots to allocate per sequence per
6768
step, beyond the known token ids. This is used in speculative
6869
decoding to store KV activations of tokens which may or may not be
@@ -71,7 +72,7 @@ class SchedulerConfig:
7172
NOTE: This will be replaced by speculative config in the future; it is
7273
present to enable correctness tests until then."""
7374

74-
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
75+
enable_chunked_prefill: bool = Field(default=None)
7576
"""If True, prefill requests can be chunked based
7677
on the remaining max_num_batched_tokens."""
7778

@@ -86,14 +87,14 @@ class SchedulerConfig:
8687
"""
8788

8889
# TODO (ywang96): Make this configurable.
89-
max_num_encoder_input_tokens: int = field(init=False)
90+
max_num_encoder_input_tokens: int = Field(init=False)
9091
"""Multimodal encoder compute budget, only used in V1.
9192
9293
NOTE: This is not currently configurable. It will be overridden by
9394
max_num_batched_tokens in case max multimodal embedding size is larger."""
9495

9596
# TODO (ywang96): Make this configurable.
96-
encoder_cache_size: int = field(init=False)
97+
encoder_cache_size: int = Field(init=False)
9798
"""Multimodal encoder cache size, only used in V1.
9899
99100
NOTE: This is not currently configurable. It will be overridden by
@@ -106,7 +107,7 @@ class SchedulerConfig:
106107
- "priority" means requests are handled based on given priority (lower
107108
value means earlier handling) and time of arrival deciding any ties)."""
108109

109-
chunked_prefill_enabled: bool = field(init=False)
110+
chunked_prefill_enabled: bool = Field(init=False)
110111
"""True if chunked prefill is enabled."""
111112

112113
disable_chunked_mm_input: bool = False
@@ -155,6 +156,20 @@ def compute_hash(self) -> str:
155156
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
156157
return hash_str
157158

159+
@field_validator(
160+
"max_num_batched_tokens",
161+
"max_num_seqs",
162+
"max_model_len",
163+
"enable_chunked_prefill",
164+
mode="wrap",
165+
)
166+
@classmethod
167+
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
168+
"""Skip validation if the value is `None` when initialisation is delayed."""
169+
if value is None:
170+
return value
171+
return handler(value)
172+
158173
def __post_init__(self, is_encoder_decoder: bool) -> None:
159174
if self.max_model_len is None:
160175
self.max_model_len = 8192
@@ -260,19 +275,7 @@ def _verify_args(self) -> Self:
260275
self.max_num_seqs * self.max_model_len,
261276
)
262277

263-
if self.num_lookahead_slots < 0:
264-
raise ValueError(
265-
"num_lookahead_slots "
266-
f"({self.num_lookahead_slots}) must be greater than or "
267-
"equal to 0."
268-
)
269-
270-
if self.max_num_partial_prefills < 1:
271-
raise ValueError(
272-
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
273-
"must be greater than or equal to 1."
274-
)
275-
elif self.max_num_partial_prefills > 1:
278+
if self.max_num_partial_prefills > 1:
276279
if not self.chunked_prefill_enabled:
277280
raise ValueError(
278281
"Chunked prefill must be enabled to set "
@@ -286,13 +289,10 @@ def _verify_args(self) -> Self:
286289
f"than the max_model_len ({self.max_model_len})."
287290
)
288291

289-
if (self.max_long_partial_prefills < 1) or (
290-
self.max_long_partial_prefills > self.max_num_partial_prefills
291-
):
292+
if self.max_long_partial_prefills > self.max_num_partial_prefills:
292293
raise ValueError(
293-
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
294-
"must be greater than or equal to 1 and less than or equal to "
295-
f"max_num_partial_prefills ({self.max_num_partial_prefills})."
294+
f"{self.max_long_partial_prefills=} must be less than or equal to "
295+
f"{self.max_num_partial_prefills=}."
296296
)
297297

298298
return self

vllm/config/structured_outputs.py

Lines changed: 5 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -2,8 +2,9 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import hashlib
5-
from typing import Any, Literal
5+
from typing import Any, Literal, Self
66

7+
from pydantic import model_validator
78
from pydantic.dataclasses import dataclass
89

910
from vllm.config.utils import config
@@ -56,7 +57,8 @@ def compute_hash(self) -> str:
5657
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
5758
return hash_str
5859

59-
def __post_init__(self):
60+
@model_validator(mode="after")
61+
def _validate_structured_output_config(self) -> Self:
6062
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
6163
raise ValueError(
6264
"disable_any_whitespace is only supported for "
@@ -67,3 +69,4 @@ def __post_init__(self):
6769
"disable_additional_properties is only supported "
6870
"for the guidance backend."
6971
)
72+
return self

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1807,7 +1807,7 @@ def _set_default_args(
18071807
incremental_prefill_supported = (
18081808
pooling_type is not None
18091809
and pooling_type.lower() == "last"
1810-
and is_causal
1810+
and bool(is_causal)
18111811
)
18121812

18131813
action = "Enabling" if incremental_prefill_supported else "Disabling"

vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -2,11 +2,12 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import json
5-
import re
65
import uuid
76
from collections.abc import Sequence
87
from typing import Any
98

9+
import regex as re
10+
1011
from vllm.entrypoints.openai.protocol import (
1112
ChatCompletionRequest,
1213
DeltaFunctionCall,

0 commit comments

Comments
 (0)