use low_cpu_mem_usage to speed up loading
parent dca27b4412
commit 771f454ff1
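A note on the headline change: passing low_cpu_mem_usage=True to from_pretrained makes Transformers build the model with empty weights and load the checkpoint shards directly into the final tensors, instead of first materializing a randomly initialized copy, so peak host RAM during loading stays close to one copy of the weights. A minimal standalone sketch of that call (the model id is a placeholder, not something this commit references; the option relies on the accelerate package being installed):

import torch
from transformers import AutoModelForCausalLM

# Checkpoint shards are loaded straight into the final tensors instead of
# overwriting a randomly initialized copy, keeping peak CPU RAM close to
# one copy of the fp16 weights. Relies on accelerate being installed.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",      # placeholder model id, not from this commit
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)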
@@ -13,7 +13,7 @@ def main():
     model_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
+    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)

     print("model and tokenizer have been saved at:", training_args.output_dir)
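On this hunk: save_pretrained splits the state dict into shards no larger than max_shard_size, so raising the limit from 1GB to 10GB simply means fewer checkpoint files. Back-of-envelope sketch, assuming an illustrative 7B-parameter fp16 model and the decimal "GB" = 10**9 bytes convention:

# Rough shard counts for a 7B-parameter model saved in fp16 (assumed size).
bytes_fp16 = 7_000_000_000 * 2           # about 14e9 bytes of weights
print(-(-bytes_fp16 // (1 * 10**9)))     # ~14 shards with max_shard_size="1GB"
print(-(-bytes_fp16 // (10 * 10**9)))    # 2 shards with max_shard_size="10GB"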
@@ -143,15 +143,24 @@ def load_pretrained(
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."

+    config_kwargs = {
+        "trust_remote_code": True,
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
-        padding_side="left"
+        padding_side="left",
+        **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token

+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+
     # Quantization configurations (using bitsandbytes library).
-    config_kwargs = {}
     if model_args.quantization_bit is not None:
         assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
         require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
@@ -162,23 +171,19 @@ def load_pretrained(

         config_kwargs["load_in_8bit"] = True
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
-        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
+        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

-    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
-
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
         torch_dtype=torch.float16, # the model weights are float16 type
+        low_cpu_mem_usage=True,
         **config_kwargs
     )
     model = prepare_model_for_training(model) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)

-    if not is_trainable:
-        model.requires_grad_(False) # fix all model params
-
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

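The 8-bit branch above is unchanged apart from the log message; the two kwargs it puts into config_kwargs end up in the same from_pretrained call that now also receives low_cpu_mem_usage=True. A hedged, standalone sketch of what that quantized load amounts to (model id is again a placeholder):

import torch
from transformers import AutoModelForCausalLM

# bitsandbytes int8 weights, dispatched by accelerate; device_map is only
# set because load_in_8bit is enabled, mirroring the branch above.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",   # placeholder model id
    torch_dtype=torch.float16,
    load_in_8bit=True,       # requires bitsandbytes>=0.39.0
    device_map="auto",
    low_cpu_mem_usage=True,
)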
@@ -194,6 +199,9 @@ def load_pretrained(
     if model_args.quantization_bit is not None:
         model._is_int8_training_enabled = True

+    if not is_trainable:
+        model.requires_grad_(False) # fix all model params
+
     print_trainable_params(model)

     return model, tokenizer
@@ -38,13 +38,17 @@ class ModelArguments:
         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
     )
     use_fast_tokenizer: Optional[bool] = field(
-        default=True,
+        default=False,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
     )
     use_auth_token: Optional[bool] = field(
         default=False,
         metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
     )
+    model_revision: Optional[str] = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
+    )
     quantization_bit: Optional[int] = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model."}
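The new model_revision field feeds the "revision" entry of config_kwargs, so a branch, tag, or commit of the Hub repo can be pinned from the command line. A rough sketch of how the dataclass gets filled via HfArgumentParser, as in main() in the first hunk; the import path for ModelArguments is an assumption about this project's layout:

from transformers import HfArgumentParser

# Import path is assumed; ModelArguments is the dataclass patched above.
from utils.config import ModelArguments

parser = HfArgumentParser(ModelArguments)
# e.g. --model_name_or_path <repo> --model_revision <tag-or-commit> --quantization_bit 8
(model_args,) = parser.parse_args_into_dataclasses()
print(model_args.model_revision)  # "main" unless overridden on the CLI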
@@ -59,7 +63,7 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
     reward_model: Optional[str] = field(
         default=None,
@@ -75,7 +79,7 @@ class ModelArguments:
     )

     def __post_init__(self):
-        if self.checkpoint_dir is not None: # support merging lora weights
+        if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

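The comment change in __post_init__ only spells out what the split below it already enables: --checkpoint_dir accepts a comma-separated list, so several LoRA delta checkpoints can be merged in order. A tiny illustration with placeholder paths:

# Mirrors the __post_init__ logic above; the paths are placeholders.
checkpoint_dir = "output/sft_lora, output/rm_lora"
checkpoint_dirs = [cd.strip() for cd in checkpoint_dir.split(",")]
print(checkpoint_dirs)  # ['output/sft_lora', 'output/rm_lora']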