diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 96e2c8a9..dfb42a9f 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -37,7 +37,10 @@ def configure_attn_implementation(
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                 require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
-                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+
+                if model_args.flash_attn != "fa2":
+                    logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                    model_args.flash_attn = "fa2"
             else:
                 logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
                 model_args.flash_attn = "disabled"
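
For context, the change above only emits the Gemma-2 warning (and coerces `flash_attn` to "fa2") when the user has not already selected fa2; previously the warning fired unconditionally whenever FlashAttention-2 was available. Below is a minimal, self-contained sketch of the patched behavior, using hypothetical stand-ins (FakeModelArgs, _coerce_gemma2_attn) rather than the real ModelArguments and configure_attn_implementation:

    import logging
    from dataclasses import dataclass

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)


    @dataclass
    class FakeModelArgs:
        # Hypothetical stand-in for LLaMA-Factory's ModelArguments; only the
        # field touched by this patch is modeled here.
        flash_attn: str = "auto"


    def _coerce_gemma2_attn(model_args: FakeModelArgs, fa2_available: bool) -> None:
        # Mirrors the patched branch: warn and switch to fa2 only when the user
        # has not already requested it; otherwise fall back to eager attention.
        if fa2_available:
            if model_args.flash_attn != "fa2":
                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                model_args.flash_attn = "fa2"
        else:
            logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
            model_args.flash_attn = "disabled"


    args = FakeModelArgs(flash_attn="fa2")
    _coerce_gemma2_attn(args, fa2_available=True)   # no warning: value is already "fa2"

    args = FakeModelArgs(flash_attn="auto")
    _coerce_gemma2_attn(args, fa2_available=True)   # warns once, then sets "fa2"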