fix #5542
parent 45841bb646
commit fe7ffccdb9
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union
 
 import torch
@@ -308,20 +308,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")
 
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
     @classmethod
-    def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
-        arg_dict = old_arg.to_dict()
-        arg_dict.update(**kwargs)
-        for attr in fields(cls):
-            if not attr.init:
-                arg_dict.pop(attr.name)
-
-        new_arg = cls(**arg_dict)
-        new_arg.compute_dtype = old_arg.compute_dtype
-        new_arg.device_map = old_arg.device_map
-        new_arg.model_max_length = old_arg.model_max_length
-        new_arg.block_diag_attn = old_arg.block_diag_attn
-        return new_arg
+    def copyfrom(cls, source: "Self", **kwargs) -> "Self":
+        init_args, lazy_args = {}, {}
+        for attr in fields(source):
+            if attr.init:
+                init_args[attr.name] = getattr(source, attr.name)
+            else:
+                lazy_args[attr.name] = getattr(source, attr.name)
+
+        init_args.update(kwargs)
+        result = cls(**init_args)
+        for name, value in lazy_args.items():
+            setattr(result, name, value)
+
+        return result
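For context, a minimal, self-contained sketch of the field-splitting pattern the new copyfrom follows (the Args class, field names, and copy_args helper below are hypothetical, not part of this commit): init=True fields are collected and passed to the constructor, while init=False fields are carried over afterwards with setattr, so overrides in **kwargs only affect constructor arguments.

from dataclasses import dataclass, field, fields

@dataclass
class Args:
    model: str = "demo"                                        # regular init field
    compute_dtype: str = field(default="float32", init=False)  # filled in after __init__

def copy_args(cls, source, **kwargs):
    init_args, lazy_args = {}, {}
    for attr in fields(source):
        if attr.init:
            init_args[attr.name] = getattr(source, attr.name)
        else:
            lazy_args[attr.name] = getattr(source, attr.name)

    init_args.update(kwargs)              # overrides apply to init fields only
    result = cls(**init_args)
    for name, value in lazy_args.items():
        setattr(result, name, value)      # preserve post-init state as-is
    return result

old = Args(model="llama")
old.compute_dtype = "bfloat16"
new = copy_args(Args, old, model="qwen")
assert new.model == "qwen" and new.compute_dtype == "bfloat16"

Compared with the removed asdict-based version, this avoids round-tripping values through a dict conversion and no longer needs a hand-maintained list of non-init attributes (compute_dtype, device_map, model_max_length, block_diag_attn) to restore by hand.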
@@ -21,6 +21,7 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras.logging import get_logger
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
 from .adapter import init_adapter
+from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
@@ -128,6 +129,7 @@ def load_model(
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
 
     model = None
     lazy_load = False
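The require_logits value is computed at the call site from finetuning_args.stage; a rough sketch of that predicate in isolation (stage strings other than "pt" and "sft" are illustrative):

def needs_full_logits(stage: str) -> bool:
    # "pt" and "sft" compute their loss from labels, so a fused/chunked
    # cross-entropy that never materializes the full logits tensor is fine;
    # other stages are assumed to read the logits directly
    return stage not in ["pt", "sft"]

assert needs_full_logits("sft") is False
assert needs_full_logits("dpo") is True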
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 from typing import TYPE_CHECKING
 
 from ...extras.logging import get_logger
@@ -26,7 +27,12 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def apply_liger_kernel(
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    is_trainable: bool,
+    require_logits: bool,
+) -> None:
     if not is_trainable or not model_args.enable_liger_kernel:
         return
 
@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
         logger.warning("Current model does not support liger kernel.")
         return
 
-    apply_liger_kernel()
+    if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
+        logger.info("Current training stage does not support chunked cross entropy.")
+        kwargs = {"fused_linear_cross_entropy": False}
+    else:
+        kwargs = {}
+
+    apply_liger_kernel(**kwargs)
     logger.info("Liger kernel has been applied to the model.")
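The inspect.signature check only forwards the fused_linear_cross_entropy keyword when the resolved apply function actually accepts it, which keeps the call compatible with liger-kernel releases that lack the flag. A generic sketch of that pattern (the two apply_* functions below are stand-ins, not the real liger-kernel API):

import inspect

def call_with_supported_kwargs(func, **kwargs):
    # drop any keyword the callable's signature does not declare
    supported = inspect.signature(func).parameters
    return func(**{k: v for k, v in kwargs.items() if k in supported})

def apply_old():                                  # older API: no tunable knob
    return "applied"

def apply_new(fused_linear_cross_entropy=True):   # newer API: exposes the knob
    return f"applied, fused CE = {fused_linear_cross_entropy}"

print(call_with_supported_kwargs(apply_old, fused_linear_cross_entropy=False))  # applied
print(call_with_supported_kwargs(apply_new, fused_linear_cross_entropy=False))  # applied, fused CE = False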
@@ -27,7 +27,6 @@ from ..extras.misc import infer_optim_dtype
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
-from .model_utils.liger_kernel import configure_liger_kernel
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -93,7 +92,6 @@ def patch_config(
 
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
-    configure_liger_kernel(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)