Skip to content

Commit

Permalink
[mypy] Pass type checking in vllm/inputs (vllm-project#11680)
Browse files Browse the repository at this point in the history
Signed-off-by: Tobias Pitters <[email protected]>
  • Loading branch information
CloseChoice authored Jan 2, 2025
1 parent 23c1b10 commit b6087a6
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
1 change: 1 addition & 0 deletions tools/mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ run_mypy vllm/compilation
run_mypy vllm/distributed
run_mypy vllm/engine
run_mypy vllm/executor
run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/plugins
Expand Down
21 changes: 11 additions & 10 deletions vllm/inputs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def prompt(self) -> Optional[str]:
if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt")

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def prompt_token_ids(self) -> List[int]:
Expand All @@ -259,7 +259,7 @@ def prompt_token_ids(self) -> List[int]:
if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt_token_ids", [])

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def token_type_ids(self) -> List[int]:
Expand All @@ -268,7 +268,7 @@ def token_type_ids(self) -> List[int]:
if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("token_type_ids", [])

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def prompt_embeds(self) -> Optional[torch.Tensor]:
Expand All @@ -277,7 +277,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
if inputs["type"] == "token" or inputs["type"] == "multimodal":
return None

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_data(self) -> "MultiModalDataDict":
Expand All @@ -289,7 +289,7 @@ def multi_modal_data(self) -> "MultiModalDataDict":
if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {})

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
Expand All @@ -301,7 +301,7 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {})

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_hashes(self) -> List[str]:
Expand All @@ -311,9 +311,10 @@ def multi_modal_hashes(self) -> List[str]:
return inputs.get("multi_modal_hashes", [])

if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", [])
# only the case when we use MultiModalInputsV2
return inputs.get("mm_hashes", []) # type: ignore[return-value]

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
Expand All @@ -325,7 +326,7 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
if inputs["type"] == "multimodal":
return inputs.get("mm_placeholders", {})

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]

@cached_property
def mm_processor_kwargs(self) -> Dict[str, Any]:
Expand All @@ -337,7 +338,7 @@ def mm_processor_kwargs(self) -> Dict[str, Any]:
if inputs["type"] == "multimodal":
return {}

assert_never(inputs)
assert_never(inputs) # type: ignore[arg-type]


ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
Expand Down
6 changes: 3 additions & 3 deletions vllm/inputs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def _build_enc_dec_llm_inputs(
or encoder_inputs["type"] == "multimodal"):
pass
else:
assert_never(encoder_inputs)
assert_never(encoder_inputs) # type: ignore[arg-type]

if decoder_inputs is None:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
Expand All @@ -452,7 +452,7 @@ def _build_enc_dec_llm_inputs(
raise ValueError("Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet")
else:
assert_never(encoder_inputs)
assert_never(encoder_inputs) # type: ignore[arg-type]

return EncoderDecoderInputs(
encoder=encoder_inputs,
Expand Down Expand Up @@ -569,7 +569,7 @@ def _build_decoder_only_llm_inputs(
prompt_adapter_request=prompt_adapter_request,
)
else:
assert_never(prompt_inputs)
assert_never(prompt_inputs) # type: ignore[arg-type]

return prompt_inputs

Expand Down
2 changes: 1 addition & 1 deletion vllm/inputs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def _ensure_mm_kwargs(
# Be more strict in V2
assert "mm_kwargs" in inputs
else:
assert_never(inputs["type"])
assert_never(inputs["type"]) # type: ignore[arg-type]

def process_input(self, model_config: "ModelConfig",
inputs: ProcessorInputs) -> ProcessorInputs:
Expand Down

0 comments on commit b6087a6

Please sign in to comment.