diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 2a6454be75..2b8c7afb6c 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -728,6 +728,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
     generation_config.trust_remote_code = args.trust_remote_code
     generation_config.valid_sequence_lengths = None
     generation_config.attn_batch_split = args.attn_batch_split
+    generation_config.fp8 = bool(args.quant_config)
 
     return generation_config
 
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index a8b1858e99..bdd34dbc31 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1503,6 +1503,9 @@ def generate(
             generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
         )
 
+        # determine if the model is fp8
+        model_kwargs["fp8"] = generation_config.fp8
+
         # determine whether introduce trim_logits feature
         model_kwargs["trim_logits"] = generation_config.trim_logits
 
diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py
index b50c3547ac..25d65f161c 100644
--- a/optimum/habana/transformers/models/phi/modeling_phi.py
+++ b/optimum/habana/transformers/models/phi/modeling_phi.py
@@ -37,6 +37,9 @@
 from transformers.processing_utils import Unpack
 from transformers.utils import logging
 
+import habana_frameworks.torch.core as htcore
+import habana_frameworks.torch.hpu as hthpu
+
 from ...modeling_attn_mask_utils import (
     _gaudi_prepare_4d_causal_attention_mask,
 )
@@ -89,6 +92,7 @@ def gaudi_eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
+    attn_softmax_bf16: Optional[bool] = False,
     **kwargs,
 ):
     bsz, q_len = kwargs["input_shape"]
@@ -100,8 +104,12 @@ def gaudi_eager_attention_forward(
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
+
+    if attn_softmax_bf16:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=value_states.dtype)
+    else:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
 
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = module.matmul_av(attn_weights, value_states)
     attn_output = attn_output.reshape(bsz, -1, q_len, module.head_dim)
@@ -138,6 +146,7 @@ def forward(
         token_idx: Optional[torch.Tensor] = None,
         reuse_cache: Optional[bool] = False,
         cache_idx: Optional[int] = None,
+        attn_softmax_bf16: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
@@ -234,6 +243,7 @@ def forward(
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             input_shape=input_shape,
+            attn_softmax_bf16=attn_softmax_bf16,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -267,6 +277,7 @@ def forward(
         token_idx: Optional[torch.Tensor] = None,
         reuse_cache: Optional[bool] = False,
         cache_idx: Optional[int] = None,
+        attn_softmax_bf16: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -294,6 +305,7 @@ def forward(
             token_idx=token_idx,
             reuse_cache=reuse_cache,
             cache_idx=cache_idx,
+            attn_softmax_bf16=attn_softmax_bf16,
         )
         attn_outputs = self.resid_dropout(attn_outputs)
 
@@ -329,6 +341,9 @@ def forward(
         token_idx: Optional[torch.Tensor] = None,
         reuse_cache: Optional[bool] = False,
         cache_idx: Optional[int] = None,
+        lazy_mode: Optional[bool] = True,
+        attn_softmax_bf16: Optional[bool] = False,
+        fp8: Optional[bool] = False,
         **kwargs,
     ) -> BaseModelOutputWithPast:
         """
@@ -405,6 +420,15 @@ def forward(
         next_decoder_cache = () if not use_new_cache else None
 
         for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            if (
+                lazy_mode
+                and not self.training
+                and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
+                and not fp8  # mark_step impacts fp8 performance
+                and hthpu.get_device_name() != "GAUDI3"  # mark_step improves gaudi 2 performance but impacts gaudi 3
+            ):
+                htcore.mark_step()
+
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -419,6 +443,7 @@ def forward(
                     use_cache,
                     cache_position,
                     None,
+                    attn_softmax_bf16=attn_softmax_bf16,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -432,6 +457,7 @@ def forward(
                     token_idx=token_idx,
                     reuse_cache=reuse_cache,
                     cache_idx=cache_idx,
+                    attn_softmax_bf16=attn_softmax_bf16,
                 )
 
             hidden_states = layer_outputs[0]
@@ -483,6 +509,7 @@ def forward(
         reuse_cache: Optional[bool] = False,
         trim_logits: Optional[bool] = False,
         cache_idx: Optional[int] = None,
+        attn_softmax_bf16: Optional[bool] = False,
        **kwargs: Unpack[KwargsForCausalLM],
     ) -> CausalLMOutputWithPast:
         """
@@ -511,6 +538,8 @@ def forward(
             token_idx=token_idx,
             reuse_cache=reuse_cache,
             cache_idx=cache_idx,
+            attn_softmax_bf16=attn_softmax_bf16,
+            fp8=kwargs.get("fp8"),
         )
 
         hidden_states = outputs.last_hidden_state
@@ -609,6 +638,9 @@ def prepare_inputs_for_generation(
                 "reuse_cache": kwargs.get("reuse_cache"),
                 "trim_logits": kwargs.get("trim_logits"),
                 "cache_idx": kwargs.get("cache_idx"),
+                "lazy_mode": kwargs.get("lazy_mode"),
+                "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
+                "fp8": kwargs.get("fp8"),
             }
         )
 
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index b058506a85..894a5e93d8 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -231,6 +231,9 @@ def _test_text_generation(
     if "moonlight-16b-a3b" in model_name.lower():
         command += ["--trim_logits", "--trust_remote_code_tokenizer"]
 
+    if "phi-2" in model_name.lower():
+        command += ["--attn_softmax_bf16"]
+
     if (reuse_cache or torch_compile) and not parallel_strategy == "tp" and not is_starcoder_first_gen_model:
         command += ["--reuse_cache"]
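For reference, below is a minimal standalone sketch (not part of the patch; tensor shapes are arbitrary) of the two softmax paths that the new `attn_softmax_bf16` flag selects between in `gaudi_eager_attention_forward`. It runs with plain PyTorch on any device.

```python
import torch

# Toy bf16 attention scores and values, standing in for attn_weights / value_states.
attn_weights = torch.randn(1, 32, 8, 8, dtype=torch.bfloat16)
value_states = torch.randn(1, 32, 8, 64, dtype=torch.bfloat16)

# attn_softmax_bf16=True: softmax runs directly in the activation dtype (bf16),
# skipping the fp32 upcast and the cast back.
softmax_bf16 = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=value_states.dtype)

# attn_softmax_bf16=False (default): softmax in fp32 for accuracy, then cast back.
softmax_fp32 = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)

print(softmax_bf16.dtype, softmax_fp32.dtype)  # both torch.bfloat16; values differ slightly
```

The bf16 path trades a little numerical headroom around the softmax for throughput; the test change opts Phi-2 into it by appending `--attn_softmax_bf16` to the example command.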