add export_device in webui #3333

This commit is contained in:
hiyouga 2024-04-25 19:02:32 +08:00
parent a3aad4b4f0
commit 3a7c1286ce
7 changed files with 57 additions and 22 deletions

View File

@ -28,9 +28,9 @@ examples/
│ ├── merge.sh: Merge LoRA weights into the pre-trained models │ ├── merge.sh: Merge LoRA weights into the pre-trained models
│ └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ │ └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
├── inference/ ├── inference/
│ ├── cli_demo.sh: Launch a command line interface with LoRA adapters │ ├── cli_demo.sh: Chat with fine-tuned model in the CLI with LoRA adapters
│ ├── api_demo.sh: Launch an OpenAI-style API with LoRA adapters │ ├── api_demo.sh: Chat with fine-tuned model in an OpenAI-style API with LoRA adapters
│ ├── web_demo.sh: Launch a web interface with LoRA adapters │ ├── web_demo.sh: Chat with fine-tuned model in the Web browser with LoRA adapters
│ └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters │ └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
└── extras/ └── extras/
├── galore/ ├── galore/

View File

@ -8,4 +8,5 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--finetuning_type lora \ --finetuning_type lora \
--export_dir ../../models/llama2-7b-sft \ --export_dir ../../models/llama2-7b-sft \
--export_size 2 \ --export_size 2 \
--export_device cpu \
--export_legacy_format False --export_legacy_format False

View File

@ -139,7 +139,7 @@ class ModelArguments:
) )
export_device: str = field( export_device: str = field(
default="cpu", default="cpu",
metadata={"help": "The device used in model export."}, metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
) )
export_quantization_bit: Optional[int] = field( export_quantization_bit: Optional[int] = field(
default=None, default=None,

View File

@ -12,7 +12,7 @@ from .utils.attention import configure_attn_implementation, print_attn_implement
from .utils.checkpointing import prepare_model_for_training from .utils.checkpointing import prepare_model_for_training
from .utils.embedding import resize_embedding_layer from .utils.embedding import resize_embedding_layer
from .utils.longlora import configure_longlora from .utils.longlora import configure_longlora
from .utils.moe import add_z3_leaf_module from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization from .utils.quantization import configure_quantization
from .utils.rope import configure_rope from .utils.rope import configure_rope
@ -46,17 +46,12 @@ def patch_config(
configure_rope(config, model_args, is_trainable) configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable) configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs) configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
if model_args.use_cache and not is_trainable: if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True) setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.") logger.info("Using KV cache for faster generation.")
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) == "qwen": if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn) setattr(config, "use_flash_attn", model_args.flash_attn)
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
@ -65,9 +60,6 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn: if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable:
setattr(config, "output_router_logits", True)
init_kwargs["torch_dtype"] = model_args.compute_dtype init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage

View File

@ -5,7 +5,9 @@ from transformers.utils.versions import require_version
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def add_z3_leaf_module(model: "PreTrainedModel") -> None: def add_z3_leaf_module(model: "PreTrainedModel") -> None:
@ -37,3 +39,15 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
from transformers.models.dbrx.modeling_dbrx import DbrxFFN from transformers.models.dbrx.modeling_dbrx import DbrxFFN
set_z3_leaf_modules(model, [DbrxFFN]) set_z3_leaf_modules(model, [DbrxFFN])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Dict, Generator, List from typing import TYPE_CHECKING, Dict, Generator, List
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
from ...train import export_model from ...train import export_model
from ..common import get_save_dir from ..common import get_save_dir
@ -26,9 +27,10 @@ def save_model(
adapter_path: List[str], adapter_path: List[str],
finetuning_type: str, finetuning_type: str,
template: str, template: str,
max_shard_size: int, export_size: int,
export_quantization_bit: int, export_quantization_bit: int,
export_quantization_dataset: str, export_quantization_dataset: str,
export_device: str,
export_legacy_format: bool, export_legacy_format: bool,
export_dir: str, export_dir: str,
export_hub_model_id: str, export_hub_model_id: str,
@ -44,6 +46,8 @@ def save_model(
error = ALERTS["err_no_dataset"][lang] error = ALERTS["err_no_dataset"][lang]
elif export_quantization_bit not in GPTQ_BITS and not adapter_path: elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
error = ALERTS["err_no_adapter"][lang] error = ALERTS["err_no_adapter"][lang]
elif export_quantization_bit in GPTQ_BITS and adapter_path:
error = ALERTS["err_gptq_lora"][lang]
if error: if error:
gr.Warning(error) gr.Warning(error)
@ -64,22 +68,25 @@ def save_model(
template=template, template=template,
export_dir=export_dir, export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None, export_hub_model_id=export_hub_model_id or None,
export_size=max_shard_size, export_size=export_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None, export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset, export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format, export_legacy_format=export_legacy_format,
) )
yield ALERTS["info_exporting"][lang] yield ALERTS["info_exporting"][lang]
export_model(args) export_model(args)
torch_gc()
yield ALERTS["info_exported"][lang] yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100, step=1) export_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none") export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json") export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
export_legacy_format = gr.Checkbox() export_legacy_format = gr.Checkbox()
with gr.Row(): with gr.Row():
@ -98,9 +105,10 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.adapter_path"), engine.manager.get_elem_by_id("top.adapter_path"),
engine.manager.get_elem_by_id("top.finetuning_type"), engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.template"), engine.manager.get_elem_by_id("top.template"),
max_shard_size, export_size,
export_quantization_bit, export_quantization_bit,
export_quantization_dataset, export_quantization_dataset,
export_device,
export_legacy_format, export_legacy_format,
export_dir, export_dir,
export_hub_model_id, export_hub_model_id,
@ -109,9 +117,10 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
) )
return dict( return dict(
max_shard_size=max_shard_size, export_size=export_size,
export_quantization_bit=export_quantization_bit, export_quantization_bit=export_quantization_bit,
export_quantization_dataset=export_quantization_dataset, export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format, export_legacy_format=export_legacy_format,
export_dir=export_dir, export_dir=export_dir,
export_hub_model_id=export_hub_model_id, export_hub_model_id=export_hub_model_id,

View File

@ -1150,7 +1150,7 @@ LOCALES = {
"value": "清空历史", "value": "清空历史",
}, },
}, },
"max_shard_size": { "export_size": {
"en": { "en": {
"label": "Max shard size (GB)", "label": "Max shard size (GB)",
"info": "The maximum size for a model file.", "info": "The maximum size for a model file.",
@ -1192,6 +1192,20 @@ LOCALES = {
"info": "量化过程中使用的校准数据集。", "info": "量化过程中使用的校准数据集。",
}, },
}, },
"export_device": {
"en": {
"label": "Export device",
"info": "Which device should be used to export model.",
},
"ru": {
"label": "Экспорт устройство",
"info": "Какое устройство следует использовать для экспорта модели.",
},
"zh": {
"label": "导出设备",
"info": "导出模型使用的设备类型。",
},
},
"export_legacy_format": { "export_legacy_format": {
"en": { "en": {
"label": "Export legacy format", "label": "Export legacy format",
@ -1287,7 +1301,12 @@ ALERTS = {
"err_no_export_dir": { "err_no_export_dir": {
"en": "Please provide export dir.", "en": "Please provide export dir.",
"ru": "Пожалуйста, укажите каталог для экспорта.", "ru": "Пожалуйста, укажите каталог для экспорта.",
"zh": "请填写导出目录", "zh": "请填写导出目录。",
},
"err_gptq_lora": {
"en": "Please merge adapters before quantizing the model.",
"ru": "Пожалуйста, объедините адаптеры перед квантованием модели.",
"zh": "量化模型前请先合并适配器。",
}, },
"err_failed": { "err_failed": {
"en": "Failed.", "en": "Failed.",