fix #5542
This commit is contained in:
parent
45841bb646
commit
fe7ffccdb9
|
@ -15,7 +15,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
|
@ -308,20 +308,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
|
|||
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
|
||||
arg_dict = old_arg.to_dict()
|
||||
arg_dict.update(**kwargs)
|
||||
for attr in fields(cls):
|
||||
if not attr.init:
|
||||
arg_dict.pop(attr.name)
|
||||
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
|
||||
init_args, lazy_args = {}, {}
|
||||
for attr in fields(source):
|
||||
if attr.init:
|
||||
init_args[attr.name] = getattr(source, attr.name)
|
||||
else:
|
||||
lazy_args[attr.name] = getattr(source, attr.name)
|
||||
|
||||
new_arg = cls(**arg_dict)
|
||||
new_arg.compute_dtype = old_arg.compute_dtype
|
||||
new_arg.device_map = old_arg.device_map
|
||||
new_arg.model_max_length = old_arg.model_max_length
|
||||
new_arg.block_diag_attn = old_arg.block_diag_attn
|
||||
return new_arg
|
||||
init_args.update(kwargs)
|
||||
result = cls(**init_args)
|
||||
for name, value in lazy_args.items():
|
||||
setattr(result, name, value)
|
||||
|
||||
return result
|
||||
|
|
|
@ -21,6 +21,7 @@ from trl import AutoModelForCausalLMWithValueHead
|
|||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
from .model_utils.misc import register_autoclass
|
||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||
|
@ -128,6 +129,7 @@ def load_model(
|
|||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = load_config(model_args)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||
|
||||
model = None
|
||||
lazy_load = False
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
@ -26,7 +27,12 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
def apply_liger_kernel(
|
||||
config: "PretrainedConfig",
|
||||
model_args: "ModelArguments",
|
||||
is_trainable: bool,
|
||||
require_logits: bool,
|
||||
) -> None:
|
||||
if not is_trainable or not model_args.enable_liger_kernel:
|
||||
return
|
||||
|
||||
|
@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
|
|||
logger.warning("Current model does not support liger kernel.")
|
||||
return
|
||||
|
||||
apply_liger_kernel()
|
||||
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
|
||||
logger.info("Current training stage does not support chunked cross entropy.")
|
||||
kwargs = {"fused_linear_cross_entropy": False}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
apply_liger_kernel(**kwargs)
|
||||
logger.info("Liger kernel has been applied to the model.")
|
||||
|
|
|
@ -27,7 +27,6 @@ from ..extras.misc import infer_optim_dtype
|
|||
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
|
||||
from .model_utils.checkpointing import prepare_model_for_training
|
||||
from .model_utils.embedding import resize_embedding_layer
|
||||
from .model_utils.liger_kernel import configure_liger_kernel
|
||||
from .model_utils.longlora import configure_longlora
|
||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .model_utils.packing import configure_packing
|
||||
|
@ -93,7 +92,6 @@ def patch_config(
|
|||
|
||||
configure_attn_implementation(config, model_args, is_trainable)
|
||||
configure_rope(config, model_args, is_trainable)
|
||||
configure_liger_kernel(config, model_args, is_trainable)
|
||||
configure_longlora(config, model_args, is_trainable)
|
||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
configure_moe(config, model_args, is_trainable)
|
||||
|
|
Loading…
Reference in New Issue