update liger kernel

This commit is contained in:
hiyouga 2024-08-29 20:46:08 +08:00
parent aa1afdc756
commit a7dd7d325e
2 changed files with 13 additions and 8 deletions

View File

@@ -26,7 +26,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras.constants import CHECKPOINT_NAMES
@@ -215,14 +215,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("Please specify dataset for evaluation.")
if training_args.predict_with_generate and data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if training_args.predict_with_generate:
if is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if training_args.predict_with_generate and finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if training_args.predict_with_generate and is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
@@ -231,7 +232,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
raise ValueError("This device does not support `pure_bf16`.")
if is_deepspeed_zero3_enabled():

View File

@@ -32,12 +32,16 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
if getattr(config, "model_type", None) == "gemma":
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
elif getattr(config, "model_type", None) == "gemma2":
from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
elif getattr(config, "model_type", None) == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif getattr(config, "model_type", None) == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif getattr(config, "model_type", None) == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif getattr(config, "model_type", None) == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif getattr(config, "model_type", None) == "qwen2":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
else: