fix galore
This commit is contained in:
parent
57452a4aa1
commit
33a4c24a8a
|
@ -70,7 +70,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
|
||||
## Changelog
|
||||
|
||||
[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** algorithm. Try `--use_galore` to use the memory-efficient optimizer.
|
||||
[24/03/07] We added support for the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
|
||||
|
||||
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported; merge it first.)
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
## 更新日志
|
||||
|
||||
[24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
||||
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
||||
|
||||
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
||||
|
||||
|
|
|
@ -7,9 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
|||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../../data \
|
||||
--template default \
|
||||
--finetuning_type freeze \
|
||||
--name_module_trainable mlp,self_attn \
|
||||
--num_layer_trainable 8 \
|
||||
--finetuning_type full \
|
||||
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../../data \
|
||||
--template default \
|
||||
--finetuning_type full \
|
||||
--optim adamw_8bit \
|
||||
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--warmup_steps 20 \
|
||||
--save_steps 100 \
|
||||
--eval_steps 100 \
|
||||
--evaluation_strategy steps \
|
||||
--load_best_model_at_end \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--plot_loss \
|
||||
--pure_bf16
|
|
@ -7,9 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
|||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../../data \
|
||||
--template default \
|
||||
--finetuning_type freeze \
|
||||
--name_module_trainable mlp,self_attn \
|
||||
--num_layer_trainable 8 \
|
||||
--finetuning_type full \
|
||||
--use_galore \
|
||||
--galore_target mlp,self_attn \
|
||||
--galore_rank 32 \
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||
--dataset_dir ../../../data \
|
||||
--template default \
|
||||
--finetuning_type full \
|
||||
--use_galore \
|
||||
--galore_target mlp,self_attn \
|
||||
--galore_rank 32 \
|
||||
--optim adamw_8bit \
|
||||
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--warmup_steps 20 \
|
||||
--save_steps 100 \
|
||||
--eval_steps 100 \
|
||||
--evaluation_strategy steps \
|
||||
--load_best_model_at_end \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--plot_loss \
|
||||
--pure_bf16
|
14
setup.py
14
setup.py
|
@ -18,6 +18,19 @@ def get_requires():
|
|||
return lines
|
||||
|
||||
|
||||
extra_require = {
|
||||
"deepspeed": ["deepspeed==0.13.1"],
|
||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||
"unsloth": ["unsloth[cu121-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"],
|
||||
"vllm": ["vllm==0.3.3"],
|
||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
|
||||
"awq": ["autoawq"],
|
||||
"aqlm": ["aqlm[gpu,cpu]"],
|
||||
"galore": ["galore_torch @ git+https://github.com/jiaweizzhao/GaLore.git"],
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
setup(
|
||||
|
@ -35,6 +48,7 @@ def main():
|
|||
packages=find_packages("src"),
|
||||
python_requires=">=3.8.0",
|
||||
install_requires=get_requires(),
|
||||
extras_require=extra_require,
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
|
|
|
@ -66,10 +66,6 @@ class LoraArguments:
|
|||
Others choices: the same as LLaMA."""
|
||||
},
|
||||
)
|
||||
lora_bf16_mode: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
|
||||
)
|
||||
use_rslora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
||||
|
@ -194,6 +190,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||
Arguments pertaining to which techniques we are going to fine-tune with.
|
||||
"""
|
||||
|
||||
pure_bf16: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
||||
)
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."},
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch
|
|||
import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies
|
||||
|
@ -156,6 +157,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
if model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
|
||||
if finetuning_args.pure_bf16:
|
||||
if not is_torch_bf16_gpu_available():
|
||||
raise ValueError("This device does not support `pure_bf16`.")
|
||||
|
||||
if training_args.fp16 or training_args.bf16:
|
||||
raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
if (
|
||||
|
@ -226,9 +234,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
)
|
||||
|
||||
# Post-process model arguments
|
||||
model_args.compute_dtype = (
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||
)
|
||||
if training_args.bf16 or finetuning_args.pure_bf16:
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
elif training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
model_args.aqlm_optimization = not training_args.predict_with_generate
|
||||
|
||||
|
|
|
@ -34,7 +34,8 @@ def init_adapter(
|
|||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
if not finetuning_args.pure_bf16:
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
|
@ -78,7 +79,8 @@ def init_adapter(
|
|||
|
||||
for name, param in model.named_parameters():
|
||||
if any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||
param.data = param.data.to(torch.float32)
|
||||
if not finetuning_args.pure_bf16:
|
||||
param.data = param.data.to(torch.float32)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
|
@ -150,8 +152,9 @@ def init_adapter(
|
|||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
|
||||
if not finetuning_args.pure_bf16:
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
|
|
@ -154,14 +154,28 @@ def create_custom_optimzer(
|
|||
},
|
||||
]
|
||||
if training_args.optim == "adamw_torch":
|
||||
optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
|
||||
elif training_args.optim == "adamw_8bit":
|
||||
optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
|
||||
optimizer = GaLoreAdamW(
|
||||
param_groups,
|
||||
lr=training_args.learning_rate,
|
||||
eps=training_args.adam_epsilon,
|
||||
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
||||
)
|
||||
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
|
||||
optimizer = GaLoreAdamW8bit(
|
||||
param_groups,
|
||||
lr=training_args.learning_rate,
|
||||
eps=training_args.adam_epsilon,
|
||||
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
||||
optim_bits=8,
|
||||
is_paged="paged" in training_args.optim,
|
||||
)
|
||||
elif training_args.optim == "adafactor":
|
||||
optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
|
||||
optimizer = GaLoreAdafactor(
|
||||
param_groups,
|
||||
lr=training_args.learning_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
|
||||
|
||||
logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
||||
|
||||
logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
||||
return optimizer
|
||||
|
|
Loading…
Reference in New Issue