commit fe7ffccdb9
parent 45841bb646
Author: hiyouga
Date:   2024-09-30 23:28:55 +08:00

4 changed files with 30 additions and 20 deletions

View File

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union

 import torch

@@ -308,20 +308,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")

-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
     @classmethod
-    def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
-        arg_dict = old_arg.to_dict()
-        arg_dict.update(**kwargs)
-        for attr in fields(cls):
-            if not attr.init:
-                arg_dict.pop(attr.name)
-
-        new_arg = cls(**arg_dict)
-        new_arg.compute_dtype = old_arg.compute_dtype
-        new_arg.device_map = old_arg.device_map
-        new_arg.model_max_length = old_arg.model_max_length
-        new_arg.block_diag_attn = old_arg.block_diag_attn
-        return new_arg
+    def copyfrom(cls, source: "Self", **kwargs) -> "Self":
+        init_args, lazy_args = {}, {}
+        for attr in fields(source):
+            if attr.init:
+                init_args[attr.name] = getattr(source, attr.name)
+            else:
+                lazy_args[attr.name] = getattr(source, attr.name)
+
+        init_args.update(kwargs)
+        result = cls(**init_args)
+        for name, value in lazy_args.items():
+            setattr(result, name, value)
+
+        return result
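The rewritten `copyfrom` no longer round-trips through `asdict()` (which recurses into nested values and discards which fields are constructor arguments); it splits fields by their `init` flag and copies the non-init ones onto the new instance with `setattr`, so lazily-set attributes no longer have to be listed by hand. A minimal, self-contained sketch of the same pattern, using an illustrative `Args` dataclass rather than the project's `ModelArguments`:

```python
from dataclasses import dataclass, field, fields
from typing import Optional


# Illustrative dataclass (not the project's ModelArguments): fields declared with
# init=False cannot be passed to the constructor, so they must be copied afterwards.
@dataclass
class Args:
    model_name_or_path: str
    learning_rate: float = 1e-4
    compute_dtype: Optional[str] = field(default=None, init=False)  # set lazily elsewhere

    @classmethod
    def copyfrom(cls, source: "Args", **kwargs) -> "Args":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            if attr.init:
                init_args[attr.name] = getattr(source, attr.name)
            else:
                lazy_args[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)          # overrides only apply to constructor arguments
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)  # carry over the non-init fields

        return result


old = Args("meta-llama/Llama-2-7b-hf")
old.compute_dtype = "bfloat16"
new = Args.copyfrom(old, learning_rate=5e-5)
assert new.compute_dtype == "bfloat16" and new.learning_rate == 5e-5
```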

View File

@@ -21,6 +21,7 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras.logging import get_logger
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
 from .adapter import init_adapter
+from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model

@@ -128,6 +129,7 @@ def load_model(
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))

     model = None
     lazy_load = False
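`load_model` now applies the liger kernel itself, right after `patch_config`, because the decision depends on `finetuning_args.stage`, which `patch_config` does not receive. A small sketch of the stage check (the helper name is illustrative, not part of the repository; the stage strings follow the values used in the diff):

```python
def needs_full_logits(stage: str) -> bool:
    """Stages other than pretraining ("pt") and SFT ("sft") read the model's
    logits directly (e.g. to score preference pairs), so losses that skip
    materializing the logits must be turned off for them."""
    return stage not in ["pt", "sft"]


assert not needs_full_logits("sft")
assert needs_full_logits("dpo")
```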

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from typing import TYPE_CHECKING

 from ...extras.logging import get_logger

@@ -26,7 +27,12 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def apply_liger_kernel(
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    is_trainable: bool,
+    require_logits: bool,
+) -> None:
     if not is_trainable or not model_args.enable_liger_kernel:
         return

@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
         logger.warning("Current model does not support liger kernel.")
         return

-    apply_liger_kernel()
+    if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
+        logger.info("Current training stage does not support chunked cross entropy.")
+        kwargs = {"fused_linear_cross_entropy": False}
+    else:
+        kwargs = {}
+
+    apply_liger_kernel(**kwargs)
     logger.info("Liger kernel has been applied to the model.")

View File

@@ -27,7 +27,6 @@ from ..extras.misc import infer_optim_dtype
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
-from .model_utils.liger_kernel import configure_liger_kernel
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing

@@ -93,7 +92,6 @@ def patch_config(
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
-    configure_liger_kernel(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)