use low_cpu_mem_usage to speed up loading

hiyouga 2023-06-03 18:19:01 +08:00
parent dca27b4412
commit 771f454ff1
3 changed files with 24 additions and 12 deletions


@@ -13,7 +13,7 @@ def main():
     model_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
+    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)
     print("model and tokenizer have been saved at:", training_args.output_dir)


@@ -143,15 +143,24 @@ def load_pretrained(
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
+    config_kwargs = {
+        "trust_remote_code": True,
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
-        padding_side="left"
+        padding_side="left",
+        **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     # Quantization configurations (using bitsandbytes library).
-    config_kwargs = {}
     if model_args.quantization_bit is not None:
         assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
         require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
@@ -162,23 +171,19 @@ def load_pretrained(
         config_kwargs["load_in_8bit"] = True
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
-        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
         torch_dtype=torch.float16, # the model weights are float16 type
+        low_cpu_mem_usage=True,
         **config_kwargs
     )
     model = prepare_model_for_training(model) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
-    if not is_trainable:
-        model.requires_grad_(False) # fix all model params
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
@@ -194,6 +199,9 @@ def load_pretrained(
     if model_args.quantization_bit is not None:
         model._is_int8_training_enabled = True
+    if not is_trainable:
+        model.requires_grad_(False) # fix all model params
     print_trainable_params(model)
     return model, tokenizer
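The change that gives the commit its title is low_cpu_mem_usage=True: with it, from_pretrained skips materializing a randomly initialized copy of the model in host RAM and loads the checkpoint shards into the final weights directly, which lowers peak CPU memory and speeds up loading. A minimal sketch outside the project's load_pretrained wrapper, again with a placeholder model id:

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("placeholder/model-id")
    model = AutoModelForCausalLM.from_pretrained(
        "placeholder/model-id",
        config=config,
        torch_dtype=torch.float16,  # keep the fp16 weights as stored
        low_cpu_mem_usage=True,     # load shard by shard instead of double-allocating
    )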


@@ -38,13 +38,17 @@ class ModelArguments:
         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
     )
     use_fast_tokenizer: Optional[bool] = field(
-        default=True,
+        default=False,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
     )
     use_auth_token: Optional[bool] = field(
         default=False,
         metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
     )
+    model_revision: Optional[str] = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
+    )
     quantization_bit: Optional[int] = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model."}
@@ -59,7 +63,7 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
     reward_model: Optional[str] = field(
         default=None,
@@ -75,7 +79,7 @@ class ModelArguments:
     )
     def __post_init__(self):
-        if self.checkpoint_dir is not None: # support merging lora weights
+        if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
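The __post_init__ comment change only clarifies existing behavior: checkpoint_dir may hold several comma-separated delta (e.g. LoRA) checkpoint paths, which are split into a list. A tiny sketch of that splitting with hypothetical paths:

    checkpoint_dir = "saves/lora_ckpt_1, saves/lora_ckpt_2"  # hypothetical paths
    checkpoint_dirs = [cd.strip() for cd in checkpoint_dir.split(",")]
    # -> ["saves/lora_ckpt_1", "saves/lora_ckpt_2"]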