LLaMA-Factory/scripts/cal_lr.py

77 lines
2.8 KiB
Python
Raw Normal View History

2023-11-14 12:58:37 +00:00
# coding=utf-8
# Estimates a suitable learning rate for 7B/13B models by scaling LLaMA's hyper-parameters with the square root of the effective token batch size.
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import math
2024-05-04 15:05:17 +00:00
from typing import Literal
2024-01-20 12:15:56 +00:00
import fire
import torch
2023-11-14 12:58:37 +00:00
from torch.utils.data import DataLoader
2024-01-20 12:15:56 +00:00
from tqdm import tqdm
2024-02-18 18:09:13 +00:00
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
2023-11-14 12:58:37 +00:00
2024-05-16 10:39:08 +00:00
from llamafactory.data import get_dataset
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
2023-11-14 12:58:37 +00:00
2024-01-20 12:15:56 +00:00
# Reference point from LLaMA pre-training: a peak learning rate of 3.0e-4 for the
# 7B/13B models (1.5e-4 for the ~30B-70B models) at a batch of ~4M tokens.
BASE_LR = 3.0e-4
BASE_BS = 4 * 10**6  # reference batch size, measured in tokens
2023-11-14 12:58:37 +00:00
def calculate_lr(
    model_name_or_path: str,
    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
    stage: Literal["pt", "sft"] = "sft",
    dataset: str = "alpaca_en",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,  # i.e. maximum input length during training
    is_mistral: bool = False,  # mistral models use a smaller learning rate
):
    """Estimate a learning rate by rescaling LLaMA's (BASE_LR, BASE_BS) to this run.

    The dataset is tokenized exactly as for training. The fraction of label tokens that
    count towards the loss (labels != IGNORE_INDEX) is then measured. The effective token
    batch size is approximated as ``cutoff_len * batch_size * valid_ratio``, and the
    learning rate is scaled as ``BASE_LR * sqrt(effective / BASE_BS)``. It is divided by 6
    when ``is_mistral`` is set.

    The result is printed rather than returned. Returning it would make ``fire`` echo the
    value a second time.

    Args:
        model_name_or_path: model/tokenizer path forwarded to ``get_train_args``.
        batch_size: total batch size (per-device batch size * gradient accumulation * world size).
        stage: "pt" or "sft"; selects the data collator.
        dataset: dataset name(s) forwarded to ``get_train_args``.
        dataset_dir: directory holding the dataset definitions.
        template: prompt template name forwarded to ``get_train_args``.
        cutoff_len: maximum input length during training.
        is_mistral: divide the resulting learning rate by 6.

    Raises:
        ValueError: if ``batch_size`` or ``cutoff_len`` is not positive, or if the
            dataset produces no label tokens at all.
        NotImplementedError: if ``stage`` is neither "pt" nor "sft".
    """
    # Fail fast on bad arguments, before loading the tokenizer and tokenizing the dataset.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")
    if cutoff_len <= 0:
        raise ValueError(f"cutoff_len must be positive, got {cutoff_len}.")
    if stage not in ("pt", "sft"):
        raise NotImplementedError(f"Stage {stage!r} is not supported, expected 'pt' or 'sft'.")

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            output_dir="dummy_dir",  # placeholder value; this script never runs training
            overwrite_cache=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)

    if stage == "pt":
        # Causal-LM collator (mlm=False): labels mirror input_ids with padding masked out.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    else:  # stage == "sft", guaranteed by the check above
        # Pad labels with IGNORE_INDEX so padding is excluded from the valid-token count.
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)

    # Batches are only inspected on the CPU, so pinned (page-locked) memory buys nothing here.
    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator)

    valid_tokens, total_tokens = 0, 0
    for batch in tqdm(dataloader):
        labels = batch["labels"]
        valid_tokens += torch.sum(labels != IGNORE_INDEX).item()
        total_tokens += torch.numel(labels)

    if total_tokens == 0:
        # Avoid a bare ZeroDivisionError when the dataset is empty or fully filtered out.
        raise ValueError(f"Dataset {dataset!r} produced no tokens; cannot estimate a learning rate.")

    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
    valid_ratio = valid_tokens / total_tokens
    batch_valid_len = batch_max_len * valid_ratio  # approximate loss-bearing tokens per batch
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
    if is_mistral:
        lr /= 6.0

    print(
        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
            lr, valid_ratio * 100, batch_valid_len
        )
    )
2023-11-14 12:58:37 +00:00
if __name__ == "__main__":
    # python-fire maps command-line flags (e.g. --batch_size 16) onto calculate_lr's parameters.
    fire.Fire(component=calculate_lr)