add adam_mini to readme
This commit is contained in:
parent
ef482394f0
commit
e2a28f51c6
14
README.md
14
README.md
|
@ -49,7 +49,7 @@ Choose your path:
|
||||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
||||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
- **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
||||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||||
|
@ -71,14 +71,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
|
[24/08/09] We support **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thank @relic-yuexi for PR.
|
||||||
|
|
||||||
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
||||||
|
|
||||||
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
|
||||||
|
|
||||||
<details><summary>Full Changelog</summary>
|
<details><summary>Full Changelog</summary>
|
||||||
|
|
||||||
|
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
|
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
|
||||||
|
|
||||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||||
|
@ -91,7 +93,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||||
|
|
||||||
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
|
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
||||||
|
|
||||||
|
@ -103,7 +105,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||||
|
|
||||||
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
|
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
|
[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
|
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
|
||||||
|
|
||||||
|
@ -342,7 +344,7 @@ cd LLaMA-Factory
|
||||||
pip install -e ".[torch,metrics]"
|
pip install -e ".[torch,metrics]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
|
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
||||||
|
|
14
README_zh.md
14
README_zh.md
|
@ -49,7 +49,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
||||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
- **先进算法**:GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
||||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
|
@ -71,14 +71,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
|
[24/08/09] 我们支持了 **[Adam-mini](https://arxiv.org/abs/2406.16793)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 @relic-yuexi 的 PR。
|
||||||
|
|
||||||
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
|
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
|
||||||
|
|
||||||
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
|
||||||
|
|
||||||
<details><summary>展开日志</summary>
|
<details><summary>展开日志</summary>
|
||||||
|
|
||||||
|
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
|
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
|
||||||
|
|
||||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
@ -91,7 +93,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||||
|
|
||||||
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||||
|
|
||||||
|
@ -103,7 +105,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||||
|
|
||||||
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
|
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
|
||||||
|
|
||||||
|
@ -342,7 +344,7 @@ cd LLaMA-Factory
|
||||||
pip install -e ".[torch,metrics]"
|
pip install -e ".[torch,metrics]"
|
||||||
```
|
```
|
||||||
|
|
||||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality
|
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
||||||
|
|
|
@ -189,6 +189,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Full-Parameter Fine-Tuning using Adam-mini
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/adam_mini/llama3_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### LoRA+ Fine-Tuning
|
#### LoRA+ Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
@ -189,6 +189,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 使用 Adam-mini 进行全参数训练
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/adam_mini/llama3_full_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### LoRA+ 微调
|
#### LoRA+ 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
### model
|
||||||
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
use_adam_mini: true
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: identity,alpaca_en_demo
|
||||||
|
template: llama3
|
||||||
|
cutoff_len: 1024
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/llama3-8b/full/sft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
val_size: 0.1
|
||||||
|
per_device_eval_batch_size: 1
|
||||||
|
eval_strategy: steps
|
||||||
|
eval_steps: 500
|
|
@ -34,6 +34,7 @@ num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
pure_bf16: true
|
pure_bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|
|
@ -2,5 +2,5 @@
|
||||||
|
|
||||||
python scripts/llama_pro.py \
|
python scripts/llama_pro.py \
|
||||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||||
--output_dir models/llama3-8b-instruct-pro \
|
--output_dir models/llama3-8b-pro \
|
||||||
--num_expand 8
|
--num_expand 8
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
### model
|
### model
|
||||||
model_name_or_path: models/llama3-8b-instruct-pro
|
model_name_or_path: models/llama3-8b-pro
|
||||||
|
|
||||||
### method
|
### method
|
||||||
stage: sft
|
stage: sft
|
||||||
|
@ -18,7 +18,7 @@ overwrite_cache: true
|
||||||
preprocessing_num_workers: 16
|
preprocessing_num_workers: 16
|
||||||
|
|
||||||
### output
|
### output
|
||||||
output_dir: saves/llama3-8b-instruct-pro/freeze/sft
|
output_dir: saves/llama3-8b-pro/freeze/sft
|
||||||
logging_steps: 10
|
logging_steps: 10
|
||||||
save_steps: 500
|
save_steps: 500
|
||||||
plot_loss: true
|
plot_loss: true
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -47,6 +47,7 @@ extra_require = {
|
||||||
"vllm": ["vllm>=0.4.3"],
|
"vllm": ["vllm>=0.4.3"],
|
||||||
"galore": ["galore-torch"],
|
"galore": ["galore-torch"],
|
||||||
"badam": ["badam>=1.2.1"],
|
"badam": ["badam>=1.2.1"],
|
||||||
|
"adam-mini": ["adam-mini"],
|
||||||
"qwen": ["transformers_stream_generator"],
|
"qwen": ["transformers_stream_generator"],
|
||||||
"modelscope": ["modelscope"],
|
"modelscope": ["modelscope"],
|
||||||
"dev": ["ruff", "pytest"],
|
"dev": ["ruff", "pytest"],
|
||||||
|
|
|
@ -326,6 +326,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
||||||
)
|
)
|
||||||
|
use_adam_mini: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
|
||||||
|
)
|
||||||
freeze_vision_tower: bool = field(
|
freeze_vision_tower: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
|
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
|
||||||
|
@ -342,10 +346,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to save the training loss curves."},
|
metadata={"help": "Whether or not to save the training loss curves."},
|
||||||
)
|
)
|
||||||
use_adammini: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Whether or not to use AdamMini optimizer."},
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
def split_arg(arg):
|
def split_arg(arg):
|
||||||
|
|
|
@ -128,6 +128,9 @@ def _check_extra_dependencies(
|
||||||
if finetuning_args.use_badam:
|
if finetuning_args.use_badam:
|
||||||
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
|
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
|
||||||
|
|
||||||
|
if finetuning_args.use_adam_mini:
|
||||||
|
require_version("adam-mini", "To fix: pip install adam-mini")
|
||||||
|
|
||||||
if finetuning_args.plot_loss:
|
if finetuning_args.plot_loss:
|
||||||
require_version("matplotlib", "To fix: pip install matplotlib")
|
require_version("matplotlib", "To fix: pip install matplotlib")
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.modeling_utils import is_fsdp_enabled
|
||||||
from transformers.optimization import get_scheduler
|
from transformers.optimization import get_scheduler
|
||||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
@ -365,18 +366,16 @@ def _create_badam_optimizer(
|
||||||
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
def _create_adammini_optimizer(
|
|
||||||
|
def _create_adam_mini_optimizer(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
|
||||||
) -> "torch.optim.Optimizer":
|
) -> "torch.optim.Optimizer":
|
||||||
from adam_mini import Adam_mini
|
from adam_mini import Adam_mini
|
||||||
|
|
||||||
n_embd = model.config.hidden_size
|
hidden_size = getattr(model.config, "hidden_size", None)
|
||||||
n_head = model.config.num_attention_heads
|
num_q_head = getattr(model.config, "num_attention_heads", None)
|
||||||
n_query_groups = getattr(model.config, "num_key_value_heads", n_head)
|
num_kv_head = getattr(model.config, "num_key_value_heads", None)
|
||||||
|
|
||||||
print("n_embd", n_embd, "n_head", n_head, "n_query_groups", n_query_groups)
|
|
||||||
|
|
||||||
optimizer = Adam_mini(
|
optimizer = Adam_mini(
|
||||||
named_parameters=model.named_parameters(),
|
named_parameters=model.named_parameters(),
|
||||||
|
@ -384,14 +383,15 @@ def _create_adammini_optimizer(
|
||||||
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
||||||
eps=training_args.adam_epsilon,
|
eps=training_args.adam_epsilon,
|
||||||
weight_decay=training_args.weight_decay,
|
weight_decay=training_args.weight_decay,
|
||||||
model_sharding=False,
|
model_sharding=is_fsdp_enabled() or is_deepspeed_zero3_enabled(),
|
||||||
dim=n_embd,
|
dim=hidden_size,
|
||||||
n_heads=n_head,
|
n_heads=num_q_head,
|
||||||
n_kv_heads=n_query_groups,
|
n_kv_heads=num_kv_head,
|
||||||
)
|
)
|
||||||
|
logger.info("Using Adam-mini optimizer.")
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def create_custom_optimizer(
|
def create_custom_optimizer(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
@ -406,8 +406,8 @@ def create_custom_optimizer(
|
||||||
if finetuning_args.use_badam:
|
if finetuning_args.use_badam:
|
||||||
return _create_badam_optimizer(model, training_args, finetuning_args)
|
return _create_badam_optimizer(model, training_args, finetuning_args)
|
||||||
|
|
||||||
if finetuning_args.use_adammini:
|
if finetuning_args.use_adam_mini:
|
||||||
return _create_adammini_optimizer(model, training_args, finetuning_args)
|
return _create_adam_mini_optimizer(model, training_args)
|
||||||
|
|
||||||
|
|
||||||
def create_custom_scheduler(
|
def create_custom_scheduler(
|
||||||
|
|
Loading…
Reference in New Issue