1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -728,6 +728,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.trust_remote_code = args.trust_remote_code
generation_config.valid_sequence_lengths = None
generation_config.attn_batch_split = args.attn_batch_split
generation_config.fp8 = bool(args.quant_config)
Collaborator
Not sure we should add this to the generation config, as there could be other ways of running the model in fp8. Maybe we could simply check the value of the QUANT_CONFIG environment variable in the modeling file, since this is how args.quant_config is set up:

args.quant_config = os.getenv("QUANT_CONFIG", "")

WDYT?
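
A minimal sketch of the suggested alternative, assuming the modeling file can read the same environment variable directly (the helper name _quant_config_fp8 is illustrative, not part of this PR):

import os

def _quant_config_fp8() -> bool:
    # Hypothetical helper: mirrors how run_generation sets args.quant_config,
    # so the modeling code could infer fp8 without a new generation_config field.
    return bool(os.getenv("QUANT_CONFIG", ""))

The decoder-layer loop could then call this helper instead of reading an fp8 entry threaded through model_kwargs.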


return generation_config

3 changes: 3 additions & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -1503,6 +1503,9 @@ def generate(
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
)

# determine if the model is fp8
model_kwargs["fp8"] = generation_config.fp8

# determine whether introduce trim_logits feature
model_kwargs["trim_logits"] = generation_config.trim_logits

34 changes: 33 additions & 1 deletion optimum/habana/transformers/models/phi/modeling_phi.py
@@ -37,6 +37,9 @@
from transformers.processing_utils import Unpack
from transformers.utils import logging

import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.hpu as hthpu

from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)
@@ -89,6 +92,7 @@ def gaudi_eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
attn_softmax_bf16: Optional[bool] = False,
**kwargs,
):
bsz, q_len = kwargs["input_shape"]
@@ -100,8 +104,12 @@
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=value_states.dtype)
else:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)

attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = module.matmul_av(attn_weights, value_states)
attn_output = attn_output.reshape(bsz, -1, q_len, module.head_dim)
@@ -138,6 +146,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
@@ -234,6 +243,7 @@ def forward(
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
input_shape=input_shape,
attn_softmax_bf16=attn_softmax_bf16,
)

attn_output = attn_output.transpose(1, 2).contiguous()
@@ -267,6 +277,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -294,6 +305,7 @@ def forward(
token_idx=token_idx,
reuse_cache=reuse_cache,
cache_idx=cache_idx,
attn_softmax_bf16=attn_softmax_bf16,
)
attn_outputs = self.resid_dropout(attn_outputs)

@@ -329,6 +341,9 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
lazy_mode: Optional[bool] = True,
attn_softmax_bf16: Optional[bool] = False,
fp8: Optional[bool] = False,
**kwargs,
) -> BaseModelOutputWithPast:
"""
@@ -405,6 +420,15 @@
next_decoder_cache = () if not use_new_cache else None

for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if (
lazy_mode
and not self.training
and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
and not fp8 # mark_step impacts fp8 performance
and hthpu.get_device_name() != "GAUDI3" # mark_step improves gaudi 2 performance but impacts gaudi 3
):
htcore.mark_step()

if output_hidden_states:
all_hidden_states += (hidden_states,)

@@ -419,6 +443,7 @@
use_cache,
cache_position,
None,
attn_softmax_bf16=attn_softmax_bf16
)
else:
layer_outputs = decoder_layer(
@@ -432,6 +457,7 @@
token_idx=token_idx,
reuse_cache=reuse_cache,
cache_idx=cache_idx,
attn_softmax_bf16=attn_softmax_bf16
)

hidden_states = layer_outputs[0]
@@ -483,6 +509,7 @@ def forward(
reuse_cache: Optional[bool] = False,
trim_logits: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
"""
@@ -511,6 +538,8 @@
token_idx=token_idx,
reuse_cache=reuse_cache,
cache_idx=cache_idx,
attn_softmax_bf16=attn_softmax_bf16,
fp8=kwargs.get("fp8"),
)

hidden_states = outputs.last_hidden_state
@@ -609,6 +638,9 @@ def prepare_inputs_for_generation(
"reuse_cache": kwargs.get("reuse_cache"),
"trim_logits": kwargs.get("trim_logits"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"fp8": kwargs.get("fp8"),
}
)

3 changes: 3 additions & 0 deletions tests/test_text_generation_example.py
@@ -231,6 +231,9 @@ def _test_text_generation(
if "moonlight-16b-a3b" in model_name.lower():
command += ["--trim_logits", "--trust_remote_code_tokenizer"]

if "phi-2" in model_name.lower():
command += ["--attn_softmax_bf16"]

if (reuse_cache or torch_compile) and not parallel_strategy == "tp" and not is_starcoder_first_gen_model:
command += ["--reuse_cache"]
