diff --git a/README.md b/README.md index 4b42edd7..443c8cf7 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Choose your path: - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. -- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. +- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. - **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. @@ -341,7 +341,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality +Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality > [!TIP] > Use `pip install --no-deps -e .` to resolve package conflicts. diff --git a/README_zh.md b/README_zh.md index 3926c09d..d5172a7d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 -- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。 +- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 - **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。 @@ -341,7 +341,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality +可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality > [!TIP] > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 diff --git a/setup.py b/setup.py index 8254b6d4..d43c311c 100644 --- a/setup.py +++ b/setup.py @@ -39,12 +39,14 @@ extra_require = { "metrics": ["nltk", "jieba", "rouge-chinese"], "deepspeed": ["deepspeed>=0.10.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"], - "vllm": ["vllm>=0.4.3"], - "galore": ["galore-torch"], - "badam": ["badam>=1.2.1"], + "hqq": ["hqq"], + "eetq": ["eetq"], "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "awq": ["autoawq"], "aqlm": ["aqlm[gpu]>=1.1.0"], + "vllm": ["vllm>=0.4.3"], + "galore": ["galore-torch"], + "badam": ["badam>=1.2.1"], "qwen": ["transformers_stream_generator"], "modelscope": ["modelscope"], "dev": ["ruff", "pytest"], diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py index ab387231..14876048 100644 --- a/src/llamafactory/extras/env.py +++ b/src/llamafactory/extras/env.py @@ -1,4 +1,7 @@ -# Copyright 2024 the LlamaFactory team. +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 3f21145d..087c8c38 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -77,6 +77,10 @@ class ModelArguments: default=True, metadata={"help": "Whether or not to use memory-efficient model loading."}, ) + quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field( + default="bitsandbytes", + metadata={"help": "Quantization method to use for on-the-fly quantization."}, + ) quantization_bit: Optional[int] = field( default=None, metadata={"help": "The number of bits to quantize the model using bitsandbytes."}, @@ -235,9 +239,6 @@ class ModelArguments: if self.new_special_tokens is not None: # support multiple special tokens self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] - assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." - assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." - if self.export_quantization_bit is not None and self.export_quantization_dataset is None: raise ValueError("Quantization dataset is necessary for exporting.") diff --git a/src/llamafactory/model/__init__.py b/src/llamafactory/model/__init__.py index 4abbaa1b..48cfe76c 100644 --- a/src/llamafactory/model/__init__.py +++ b/src/llamafactory/model/__init__.py @@ -14,10 +14,12 @@ from .loader import load_config, load_model, load_tokenizer from .model_utils.misc import find_all_linear_modules +from .model_utils.quantization import QuantizationMethod from .model_utils.valuehead import load_valuehead_params __all__ = [ + "QuantizationMethod", "load_config", "load_model", "load_tokenizer", diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index e1015821..1261d17a 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -186,11 +186,11 @@ def load_model( trainable_params, all_param = count_parameters(model) if is_trainable: - param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( trainable_params, all_param, 100 * trainable_params / all_param ) else: - param_stats = "all params: {:d}".format(all_param) + param_stats = "all params: {:,}".format(all_param) logger.info(param_stats) diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index fab61cb8..3203b4aa 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List import torch from datasets import load_dataset -from transformers import BitsAndBytesConfig, GPTQConfig +from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled from transformers.utils.versions import require_version @@ -59,7 +59,7 @@ class QuantizationMethod(str, Enum): def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]: r""" - Prepares the dataset to perform AutoGPTQ. + Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization. """ if os.path.isfile(model_args.export_quantization_dataset): data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) @@ -93,7 +93,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen] - samples.append({"input_ids": input_ids, "attention_mask": attention_mask}) + samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()}) return samples @@ -105,7 +105,7 @@ def configure_quantization( init_kwargs: Dict[str, Any], ) -> None: r""" - Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) + Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer) """ if getattr(config, "quantization_config", None): # ptq if is_deepspeed_zero3_enabled(): @@ -131,6 +131,9 @@ def configure_quantization( logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) elif model_args.export_quantization_bit is not None: # auto-gptq + if model_args.export_quantization_bit not in [8, 4, 3, 2]: + raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.") + require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0") require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") from accelerate.utils import get_max_memory @@ -146,30 +149,48 @@ def configure_quantization( init_kwargs["max_memory"] = get_max_memory() logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) - elif model_args.quantization_bit is not None: # bnb - if model_args.quantization_bit == 8: - require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + elif model_args.quantization_bit is not None: # on-the-fly + if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora + ) + else: + raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.") - elif model_args.quantization_bit == 4: - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=model_args.compute_dtype, - bnb_4bit_use_double_quant=model_args.double_quantization, - bnb_4bit_quant_type=model_args.quantization_type, - bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora - ) + # Do not assign device map if: + # 1. deepspeed zero3 or fsdp (train) + # 2. auto quantization device map (inference) + if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": + if model_args.quantization_bit != 4: + raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") - # Do not assign device map if: - # 1. deepspeed zero3 or fsdp (train) - # 2. auto quantization device map (inference) - if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": - if model_args.quantization_bit != 4: - raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + else: + init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference - require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") - else: - init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference + logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.HQQ.value: + if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]: + raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.") - logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) + require_version("hqq", "To fix: pip install hqq") + init_kwargs["quantization_config"] = HqqConfig( + nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 + ) # use ATEN kernel (axis=0) for performance + logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.EETQ.value: + if model_args.quantization_bit != 8: + raise ValueError("EETQ only accepts 8-bit quantization.") + + require_version("eetq", "To fix: pip install eetq") + init_kwargs["quantization_config"] = EetqConfig() + logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit)) diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 652c341c..8abef920 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -23,7 +23,7 @@ from ..data import Role from ..extras.constants import PEFT_METHODS from ..extras.misc import torch_gc from ..extras.packages import is_gradio_available -from .common import get_save_dir +from .common import QUANTIZATION_BITS, get_save_dir from .locales import ALERTS @@ -76,11 +76,17 @@ class WebChatModel(ChatModel): yield error return + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + yield ALERTS["info_loading"][lang] args = dict( model_name_or_path=model_path, finetuning_type=finetuning_type, - quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), template=get("top.template"), flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 980428a4..bced18f0 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -47,6 +47,8 @@ DEFAULT_CONFIG_DIR = "config" DEFAULT_DATA_DIR = "data" DEFAULT_SAVE_DIR = "saves" USER_CONFIG = "user_config.yaml" +QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"] +GPTQ_BITS = ["8", "4", "3", "2"] def get_save_dir(*paths: str) -> os.PathLike: diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index 14257949..0a938f02 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -18,7 +18,7 @@ from ...extras.constants import PEFT_METHODS from ...extras.misc import torch_gc from ...extras.packages import is_gradio_available from ...train.tuner import export_model -from ..common import get_save_dir +from ..common import GPTQ_BITS, get_save_dir from ..locales import ALERTS @@ -32,9 +32,6 @@ if TYPE_CHECKING: from ..engine import Engine -GPTQ_BITS = ["8", "4", "3", "2"] - - def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: return gr.Dropdown(value="none", interactive=False) diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index 18b9a7d2..e331d5e4 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -18,7 +18,7 @@ from ...data import TEMPLATES from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.packages import is_gradio_available from ..common import get_model_info, list_checkpoints, save_config -from ..utils import can_quantize +from ..utils import can_quantize, can_quantize_to if is_gradio_available(): @@ -43,10 +43,11 @@ def create_top() -> Dict[str, "Component"]: with gr.Accordion(open=False) as advanced_tab: with gr.Row(): - quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2) - template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2) - rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3) - booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3) + quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1) + quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1) + template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1) + rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2) + booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2) visual_inputs = gr.Checkbox(scale=1) model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then( @@ -58,6 +59,7 @@ def create_top() -> Dict[str, "Component"]: list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False ) checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) + quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False) return dict( lang=lang, @@ -67,6 +69,7 @@ def create_top() -> Dict[str, "Component"]: checkpoint_path=checkpoint_path, advanced_tab=advanced_tab, quantization_bit=quantization_bit, + quantization_method=quantization_method, template=template, rope_scaling=rope_scaling, booster=booster, diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index cd166584..435876e7 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -85,15 +85,29 @@ LOCALES = { "quantization_bit": { "en": { "label": "Quantization bit", - "info": "Enable 4/8-bit model quantization (QLoRA).", + "info": "Enable quantization (QLoRA).", }, "ru": { "label": "Уровень квантования", - "info": "Включить 4/8-битное квантование модели (QLoRA).", + "info": "Включить квантование (QLoRA).", }, "zh": { "label": "量化等级", - "info": "启用 4/8 比特模型量化(QLoRA)。", + "info": "启用量化(QLoRA)。", + }, + }, + "quantization_method": { + "en": { + "label": "Quantization method", + "info": "Quantization algorithm to use.", + }, + "ru": { + "label": "Метод квантования", + "info": "Алгоритм квантования, который следует использовать.", + }, + "zh": { + "label": "量化方法", + "info": "使用的量化算法。", }, }, "template": { diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py index 7e9b801a..ebe9f1b9 100644 --- a/src/llamafactory/webui/manager.py +++ b/src/llamafactory/webui/manager.py @@ -71,6 +71,7 @@ class Manager: self._id_to_elem["top.finetuning_type"], self._id_to_elem["top.checkpoint_path"], self._id_to_elem["top.quantization_bit"], + self._id_to_elem["top.quantization_method"], self._id_to_elem["top.template"], self._id_to_elem["top.rope_scaling"], self._id_to_elem["top.booster"], diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 549ec765..f7fbac30 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -22,7 +22,7 @@ from transformers.trainer import TRAINING_ARGS_NAME from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.packages import is_gradio_available -from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config from .locales import ALERTS, LOCALES from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd @@ -104,6 +104,11 @@ class Runner: model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + args = dict( stage=TRAINING_STAGES[get("train.training_stage")], do_train=True, @@ -111,7 +116,8 @@ class Runner: cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, finetuning_type=finetuning_type, - quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", @@ -234,13 +240,19 @@ class Runner: model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + args = dict( stage="sft", model_name_or_path=get("top.model_path"), cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, finetuning_type=finetuning_type, - quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index a616bcba..4f313e4e 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -25,6 +25,7 @@ from yaml import safe_dump, safe_load from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.ploting import gen_loss_plot +from ..model import QuantizationMethod from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir from .locales import ALERTS @@ -55,6 +56,18 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown": return gr.Dropdown(interactive=True) +def can_quantize_to(quantization_method: str) -> "gr.Dropdown": + r""" + Returns the available quantization bits. + """ + if quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + return gr.Dropdown(choices=["none", "8", "4"]) + elif quantization_method == QuantizationMethod.HQQ.value: + return gr.Dropdown(choices=["none", "8", "6", "5", "4", "3", "2", "1"]) + elif quantization_method == QuantizationMethod.EETQ.value: + return gr.Dropdown(choices=["none", "8"]) + + def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: r""" Modifys states after changing the training stage.