This commit is contained in:
hiyouga 2023-06-04 12:55:40 +08:00
parent 3b9eee8cd2
commit eac9921e5c
2 changed files with 15 additions and 9 deletions

View File

@ -146,7 +146,7 @@ def load_pretrained(
finetuning_args = FinetuningArguments(finetuning_type="none")
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
"RM and PPO training can only be performed with LoRA method."
"RM and PPO training can only be performed with the LoRA method."
config_kwargs = {
"trust_remote_code": True,
@ -183,7 +183,7 @@ def load_pretrained(
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
@ -261,6 +261,9 @@ def prepare_args(
if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@ -275,11 +278,11 @@ def prepare_args(
if model_args.quantization_bit is not None:
if training_args.fp16:
finetuning_args.compute_dtype = torch.float16
model_args.compute_dtype = torch.float16
elif training_args.bf16:
finetuning_args.compute_dtype = torch.bfloat16
model_args.compute_dtype = torch.bfloat16
else:
finetuning_args.compute_dtype = torch.float32
model_args.compute_dtype = torch.float32
# Log on each process the small summary:
logger.info(
@ -303,6 +306,9 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
else:
model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
return model_args, data_args, finetuning_args

View File

@ -62,6 +62,10 @@ class ModelArguments:
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
@ -208,10 +212,6 @@ class FinetuningArguments:
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
)
def __post_init__(self):
if isinstance(self.lora_target, str):