From df29793dc73a83f3c86c19de967adffda1a28a93 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Mon, 29 Apr 2024 11:01:26 +0900 Subject: [PATCH] [mypy][5/N] Support all typing on model executor (#4427) --- .github/workflows/mypy.yaml | 2 +- format.sh | 2 +- .../lm_format_enforcer_decoding.py | 1 + vllm/model_executor/layers/linear.py | 12 ++++- .../layers/quantization/__init__.py | 4 +- .../layers/quantization/base_config.py | 14 ++++-- .../layers/quantization/squeezellm.py | 5 +- .../model_executor/layers/rotary_embedding.py | 4 +- vllm/model_executor/layers/sampler.py | 47 +++++++++++-------- .../model_executor/model_loader/tensorizer.py | 4 +- 10 files changed, 61 insertions(+), 34 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 089c7d18ad6f2..a19be8525f902 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -43,8 +43,8 @@ jobs: mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml # TODO(sang): Fix nested dir - mypy vllm/model_executor/*.py --config-file pyproject.toml mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/format.sh b/format.sh index 4ac1842daef0a..bd12e61d77806 100755 --- a/format.sh +++ b/format.sh @@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml -mypy vllm/model_executor/*.py --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index 0d74a5f8e81ff..d0a5ca5592f9d 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: return schema if isinstance(schema, BaseModel): return schema.model_json_schema() + raise AssertionError(f"Unsupported schema type {schema}") @lru_cache diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1bd6c42ab3fd8..4d43ed4c5f14a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -128,7 +128,8 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype if quant_config is None: - self.quant_method = UnquantizedLinearMethod() + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() else: self.quant_method = quant_config.get_quant_method(self) @@ -160,6 +161,8 @@ def __init__( super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) + # All the linear layer supports quant method. + assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size, [self.output_size], self.input_size, self.output_size, self.params_dtype) @@ -173,6 +176,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -221,6 +225,8 @@ def __init__( self.output_size_per_partition = divide(output_size, tp_size) if output_sizes is None: output_sizes = [output_size] + # All the linear layer supports quant method. + assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size, [x // tp_size for x in output_sizes], @@ -255,6 +261,7 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. + assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. @@ -579,6 +586,8 @@ def __init__( # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) + # All the linear layer supports quant method. + assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size_per_partition, [self.output_size], @@ -624,6 +633,7 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. + assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 0820f17c5c50d..70e0a7cfe3e3b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Dict, Type from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig -QUANTIZATION_METHODS = { +QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, "fp8": Fp8Config, diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index b755b1328504a..ff5cf0b2bd61a 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -76,8 +76,16 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase: - """Get the quantize method to use for the quantized layer.""" + def get_quant_method( + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + Returns: + The quantize method. None if the given layer doesn't support quant + method. + """ raise NotImplementedError @abstractmethod diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 971078fe25a9b..207dbcee8afc5 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -52,11 +52,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": return cls(weight_bits) def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]: + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: if isinstance(layer, LinearBase): return SqueezeLLMLinearMethod(self) - return + return None def get_scaled_act_names(self) -> List[str]: return [] diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b8361af61ae3f..25365a9b50a1f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -431,8 +431,8 @@ def forward( torch.full_like(positions, k)).long() idx = (torch.add(positions, long_prompt_offset) if long_prompt_offset is not None else positions) - self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to( - idx.device) + self.long_short_cos_sin_cache: torch.Tensor = ( + self.long_short_cos_sin_cache.to(idx.device)) idx = torch.add(idx, offsets) if offsets is not None else idx cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 4ef25edecfd24..d79c99e5d0a45 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -13,6 +13,9 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceGroupOutput, SequenceOutput) +# (num_token_ids, num_parent_ids) per sequence group. +SampleResultType = List[Tuple[List[int], List[int]]] + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -155,7 +158,7 @@ def _apply_min_tokens_penalty( have not been generated yet """ # list of indices in logits that will be set to -inf - logits_to_penalize = [] + logits_to_penalize: List[Tuple[int, int]] = [] logits_applied = 0 for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -269,7 +272,7 @@ def _apply_min_p( def _greedy_sample( selected_seq_groups: List[SequenceGroupToSample], samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run greedy sampling on a given samples. Args: @@ -284,7 +287,7 @@ def _greedy_sample( """ samples = samples.tolist() sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -304,7 +307,7 @@ def _greedy_sample( def _random_sample( selected_seq_groups: List[SequenceGroupToSample], random_samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run random sampling on a given samples. Args: @@ -320,7 +323,7 @@ def _random_sample( # Find the maximum best_of value of the prompt phase requests. random_samples = random_samples.cpu() sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -348,7 +351,7 @@ def _random_sample( def _beam_search_sample( selected_seq_groups: List[SequenceGroupToSample], logprobs: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run beam sampling on a given samples. Args: @@ -370,7 +373,7 @@ def _beam_search_sample( # NOTE: Beam search is not vectorized, so its speed can be slower than # other sampling methods. sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -391,16 +394,16 @@ def _beam_search_sample( next_token_ids = next_token_ids.tolist() else: # Generation phase. - cumulative_logprobs = [ + cumulative_logprobs: List[int] = [ seq_group.seq_data[seq_id].cumulative_logprob for seq_id in seq_ids ] - cumulative_logprobs = torch.tensor( + cumulative_logprobs_tensor = torch.tensor( cumulative_logprobs, dtype=torch.float, device=seq_group_logprobs.device) seq_group_logprobs = (seq_group_logprobs + - cumulative_logprobs.unsqueeze(dim=1)) + cumulative_logprobs_tensor.unsqueeze(dim=1)) _, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width) topk_ids = topk_ids.tolist() @@ -452,8 +455,10 @@ def _sample_with_torch( sampling_metadata: SamplingMetadata, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): sampling_params = seq_group.sampling_params @@ -555,8 +560,10 @@ def _sample_with_triton_kernel( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, -) -> List[Tuple[List[int], List[int]]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} +) -> SampleResultType: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): sampling_params = seq_group.sampling_params @@ -632,7 +639,7 @@ def _sample( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool -) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: """ Args: probs: (num_query_tokens_in_batch, num_vocab) @@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, - sample_results: List[Tuple[List[int], List[int]]], + sample_results: SampleResultType, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: """Return sample lobprobs and prompt logprobs. @@ -751,8 +758,8 @@ def _get_logprobs( assert len(next_token_ids) == len(query_indices) if len(query_indices) == 0: - empty_sampled_logprob = [] - empty_prompt_logprob = None + empty_sampled_logprob: SampleLogprobs = [] + empty_prompt_logprob: Optional[PromptLogprobs] = None return [empty_prompt_logprob], [empty_sampled_logprob] query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) @@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( - sample_results: List[Tuple[List[int], List[int]]], + sample_results: SampleResultType, sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], @@ -1009,7 +1016,7 @@ def _build_sampler_output( ) -def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]: +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: """Get a list of next prompt tokens to compute logprob from a given sequence group. diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 8fc6d16672117..2d654b2fefb8d 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -64,7 +64,7 @@ def _construct_tensorizer_args(self) -> "TensorizerArgs": "s3_secret_access_key": self.s3_secret_access_key, "s3_endpoint": self.s3_endpoint, } - return TensorizerArgs(**tensorizer_args) + return TensorizerArgs(**tensorizer_args) # type: ignore def verify_with_parallel_config( self, @@ -270,8 +270,10 @@ def __init__(self, tensorizer_config: TensorizerConfig, self.model = self._init_model() def _init_model(self): + assert self.tensorizer_config.hf_config is not None model_args = self.tensorizer_config.hf_config model_args.torch_dtype = self.tensorizer_config.dtype + assert self.tensorizer_config.model_class is not None with no_init_or_tensor(): return self.tensorizer_config.model_class( config=model_args,