diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 52a170c49d19..b4cd68015909 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -219,12 +219,13 @@ class TransformerBlock(torch.nn.Module): def __init__( self, config: GptOssConfig, + cache_config: CacheConfig, quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() self.layer_idx = extract_layer_index(prefix) - self.attn = OAIAttention(config, prefix=f"{prefix}.attn") + self.attn = OAIAttention(config, cache_config=cache_config, prefix=f"{prefix}.attn") self.mlp = MLPBlock(config, self.layer_idx, quant_config=quant_config, @@ -248,6 +249,7 @@ def __init__( ): super().__init__() self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config self.quant_config = vllm_config.quant_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( @@ -257,6 +259,7 @@ def __init__( self.layers = torch.nn.ModuleList([ TransformerBlock( self.config, + cache_config=self.cache_config, quant_config=self.quant_config, prefix=maybe_prefix(prefix, f"block.{layer_idx}"), ) for layer_idx in range(self.config.num_hidden_layers)