use low_cpu_mem_usage to speed up loading
This commit is contained in:
parent
dca27b4412
commit
771f454ff1
|
@ -13,7 +13,7 @@ def main():
|
|||
model_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
print("model and tokenizer have been saved at:", training_args.output_dir)
|
||||
|
|
|
@ -143,15 +143,24 @@ def load_pretrained(
|
|||
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
|
||||
"RM and PPO training can only be performed with LoRA method."
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
padding_side="left"
|
||||
padding_side="left",
|
||||
**config_kwargs
|
||||
)
|
||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
config_kwargs = {}
|
||||
if model_args.quantization_bit is not None:
|
||||
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
|
||||
|
@ -162,23 +171,19 @@ def load_pretrained(
|
|||
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
|
||||
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
# Load and prepare pretrained models (without valuehead).
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
config=config,
|
||||
torch_dtype=torch.float16, # the model weights are float16 type
|
||||
low_cpu_mem_usage=True,
|
||||
**config_kwargs
|
||||
)
|
||||
model = prepare_model_for_training(model) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
||||
|
@ -194,6 +199,9 @@ def load_pretrained(
|
|||
if model_args.quantization_bit is not None:
|
||||
model._is_int8_training_enabled = True
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
|
||||
print_trainable_params(model)
|
||||
|
||||
return model, tokenizer
|
||||
|
|
|
@ -38,13 +38,17 @@ class ModelArguments:
|
|||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
||||
)
|
||||
use_fast_tokenizer: Optional[bool] = field(
|
||||
default=True,
|
||||
default=False,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
)
|
||||
use_auth_token: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
|
||||
)
|
||||
model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model."}
|
||||
|
@ -59,7 +63,7 @@ class ModelArguments:
|
|||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
|
||||
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None,
|
||||
|
@ -75,7 +79,7 @@ class ModelArguments:
|
|||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.checkpoint_dir is not None: # support merging lora weights
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue