hiyouga 2024-09-30 23:28:55 +08:00
parent 45841bb646
commit fe7ffccdb9
4 changed files with 30 additions and 20 deletions

src/llamafactory/hparams/model_args.py

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union

 import torch
@@ -308,20 +308,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")

-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
     @classmethod
-    def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
-        arg_dict = old_arg.to_dict()
-        arg_dict.update(**kwargs)
-        for attr in fields(cls):
-            if not attr.init:
-                arg_dict.pop(attr.name)
+    def copyfrom(cls, source: "Self", **kwargs) -> "Self":
+        init_args, lazy_args = {}, {}
+        for attr in fields(source):
+            if attr.init:
+                init_args[attr.name] = getattr(source, attr.name)
+            else:
+                lazy_args[attr.name] = getattr(source, attr.name)

-        new_arg = cls(**arg_dict)
-        new_arg.compute_dtype = old_arg.compute_dtype
-        new_arg.device_map = old_arg.device_map
-        new_arg.model_max_length = old_arg.model_max_length
-        new_arg.block_diag_attn = old_arg.block_diag_attn
-        return new_arg
+        init_args.update(kwargs)
+        result = cls(**init_args)
+        for name, value in lazy_args.items():
+            setattr(result, name, value)
+
+        return result
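Note: the rewritten copyfrom drops the hand-maintained list of post-init attributes (compute_dtype, device_map, model_max_length, block_diag_attn) in favor of a generic split over dataclasses.fields(): fields declared with init=False cannot be passed to the constructor, so they are copied onto the new instance with setattr after construction. A minimal, self-contained sketch of the same pattern; DemoArgs is a hypothetical stand-in, not a class from this commit:

from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class DemoArgs:
    model_name_or_path: str = "demo"
    lora_rank: int = 8
    # init=False fields are filled in after construction (e.g. in __post_init__),
    # so the constructor rejects them; copyfrom must carry them over separately.
    compute_dtype: Optional[str] = field(default=None, init=False)

    @classmethod
    def copyfrom(cls, source: "DemoArgs", **kwargs) -> "DemoArgs":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            # attr.init records whether __init__ accepts this field
            (init_args if attr.init else lazy_args)[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)          # caller overrides, e.g. copyfrom(old, lora_rank=16)
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)  # restore non-init state on the copy

        return result


old = DemoArgs()
old.compute_dtype = "bfloat16"
new = DemoArgs.copyfrom(old, lora_rank=16)
assert new.lora_rank == 16 and new.compute_dtype == "bfloat16"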

src/llamafactory/model/loader.py

@@ -21,6 +21,7 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras.logging import get_logger
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
 from .adapter import init_adapter
+from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
@@ -128,6 +129,7 @@ def load_model(
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))

     model = None
     lazy_load = False
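Note: require_logits is derived from the training stage at the call site. Stages other than pt and sft (e.g. rm, ppo, dpo, kto) consume the full logits tensor, while pt and sft only need the loss, so only the latter can safely keep a fused cross-entropy that never materializes the logits. A minimal sketch of the gating predicate, assuming LLaMA-Factory's usual stage names:

def needs_full_logits(stage: str) -> bool:
    # "pt" and "sft" only use the loss, so the logits tensor may be fused
    # away; every other stage reads logits and must keep them.
    return stage not in ["pt", "sft"]


assert needs_full_logits("dpo") is True   # preference losses read per-token logits
assert needs_full_logits("sft") is False  # chunked / fused cross-entropy is safe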

src/llamafactory/model/model_utils/liger_kernel.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from typing import TYPE_CHECKING

 from ...extras.logging import get_logger
@@ -26,7 +27,12 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def apply_liger_kernel(
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    is_trainable: bool,
+    require_logits: bool,
+) -> None:
     if not is_trainable or not model_args.enable_liger_kernel:
         return
@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
         logger.warning("Current model does not support liger kernel.")
         return

-    apply_liger_kernel()
+    if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
+        logger.info("Current training stage does not support chunked cross entropy.")
+        kwargs = {"fused_linear_cross_entropy": False}
+    else:
+        kwargs = {}
+
+    apply_liger_kernel(**kwargs)
     logger.info("Liger kernel has been applied to the model.")

src/llamafactory/model/patcher.py

@@ -27,7 +27,6 @@ from ..extras.misc import infer_optim_dtype
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
-from .model_utils.liger_kernel import configure_liger_kernel
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -93,7 +92,6 @@ def patch_config(
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
-    configure_liger_kernel(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)