Skip to content

Commit

Permalink
simplify logic (#1856)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Aug 24, 2024
1 parent 77a4b9c commit 22f4eaf
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,19 +589,12 @@ def load_model(

# sample packing uses custom FA2 patch
if cfg.flash_attention:
if not cfg.sample_packing:
if cfg.s2_attention:
pass
# most other models support flash attention, we can define exceptions as they come up
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
if not cfg.sample_packing and cfg.s2_attention:
pass
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif cfg.sdp_attention:
model_kwargs["attn_implementation"] = "sdpa"
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
Expand Down

0 comments on commit 22f4eaf

Please sign in to comment.