diff --git a/scripts/run_evaluation.py b/scripts/run_evaluation.py
index bbd44e25c..dde162bb8 100644
--- a/scripts/run_evaluation.py
+++ b/scripts/run_evaluation.py
@@ -57,6 +57,11 @@ def parse_and_validate_args():
         action="store_true",
     )
     parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction)
+    parser.add_argument(
+        "--use_flash_attn",
+        help="Whether to load the model using Flash Attention 2",
+        action="store_true",
+    )
     parsed_args = parser.parse_args()
 
     print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}")
@@ -441,7 +446,7 @@ def export_experiment_info(
 
 if __name__ == "__main__":
     args = parse_and_validate_args()
-    tuned_model = TunedCausalLM.load(args.model)
+    tuned_model = TunedCausalLM.load(args.model, use_flash_attn=args.use_flash_attn)
     eval_data = datasets.load_dataset(
         "json", data_files=args.data_path, split=args.split
     )
diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 5d40f1cc4..70820049e 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -142,7 +142,10 @@ def __init__(self, model, tokenizer, device):
 
     @classmethod
     def load(
-        cls, checkpoint_path: str, base_model_name_or_path: str = None
+        cls,
+        checkpoint_path: str,
+        base_model_name_or_path: str = None,
+        use_flash_attn: bool = False,
     ) -> "TunedCausalLM":
         """Loads an instance of this model.
 
@@ -152,6 +155,8 @@ def load(
                 adapter_config.json.
             base_model_name_or_path: str [Default: None]
                 Override for the base model to be used.
+            use_flash_attn: bool [Default: False]
+                Whether to load the model using flash attention.
 
         By default, the paths for the base model and tokenizer are contained within the adapter
         config of the tuned model. Note that in this context, a path may refer to a model to be
@@ -173,14 +178,24 @@ def load(
         try:
             with AdapterConfigPatcher(checkpoint_path, overrides):
                 try:
-                    model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
+                    model = AutoPeftModelForCausalLM.from_pretrained(
+                        checkpoint_path,
+                        attn_implementation="flash_attention_2"
+                        if use_flash_attn
+                        else None,
+                        torch_dtype=torch.bfloat16 if use_flash_attn else None,
+                    )
                 except OSError as e:
                     print("Failed to initialize checkpoint model!")
                     raise e
         except FileNotFoundError:
             print("No adapter config found! Loading as a merged model...")
             # Unable to find the adapter config; fall back to loading as a merged model
-            model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
+            model = AutoModelForCausalLM.from_pretrained(
+                checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=torch.bfloat16 if use_flash_attn else None,
+            )
 
         device = "cuda" if torch.cuda.is_available() else None
         print(f"Inferred device: {device}")
@@ -246,6 +261,11 @@ def main():
         type=int,
         default=20,
     )
+    parser.add_argument(
+        "--use_flash_attn",
+        help="Whether to load the model using Flash Attention 2",
+        action="store_true",
+    )
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument("--text", help="Text to run inference on")
     group.add_argument(
@@ -261,6 +281,7 @@
     loaded_model = TunedCausalLM.load(
         checkpoint_path=args.model,
         base_model_name_or_path=args.base_model_name_or_path,
+        use_flash_attn=args.use_flash_attn,
     )
 
     # Run inference on the text; if multiple were provided, process them all
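
The flag is exposed on the command line of both scripts as --use_flash_attn. Below is a minimal sketch of the new keyword argument from the Python side; it is not part of the patch, the checkpoint path is a placeholder, and the import assumes the scripts/ directory is on PYTHONPATH:

    # Illustrative only. Flash Attention 2 requires a CUDA device with the
    # flash-attn package installed; as shown in the diff above, setting
    # use_flash_attn=True also loads the weights in torch.bfloat16.
    from run_inference import TunedCausalLM

    model = TunedCausalLM.load(
        checkpoint_path="path/to/tuned/checkpoint",  # placeholder path
        use_flash_attn=True,
    )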