update liger kernel
This commit is contained in:
parent
aa1afdc756
commit
a7dd7d325e
|
@ -26,7 +26,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
|||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.training_args import ParallelMode
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
|
@ -215,14 +215,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
):
|
||||
raise ValueError("Please specify dataset for evaluation.")
|
||||
|
||||
if training_args.predict_with_generate and data_args.eval_dataset is None:
|
||||
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
|
||||
if training_args.predict_with_generate:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if training_args.predict_with_generate and finetuning_args.compute_accuracy:
|
||||
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
|
||||
if data_args.eval_dataset is None:
|
||||
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
|
||||
|
||||
if training_args.predict_with_generate and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
|
||||
if finetuning_args.compute_accuracy:
|
||||
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
|
||||
|
||||
if training_args.do_train and model_args.quantization_device_map == "auto":
|
||||
raise ValueError("Cannot use device map for quantized models in training.")
|
||||
|
@ -231,7 +232,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
|
||||
|
||||
if finetuning_args.pure_bf16:
|
||||
if not is_torch_bf16_gpu_available():
|
||||
if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
|
||||
raise ValueError("This device does not support `pure_bf16`.")
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
|
|
|
@ -32,12 +32,16 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
|
|||
|
||||
if getattr(config, "model_type", None) == "gemma":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "gemma2":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "llama":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "mistral":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "mixtral":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "phi3":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "qwen2":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue