tiny fix
parent 3b9eee8cd2
commit eac9921e5c
@@ -146,7 +146,7 @@ def load_pretrained(
         finetuning_args = FinetuningArguments(finetuning_type="none")
 
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with LoRA method."
+        "RM and PPO training can only be performed with the LoRA method."
 
     config_kwargs = {
         "trust_remote_code": True,
@@ -183,7 +183,7 @@ def load_pretrained(
         config_kwargs["load_in_4bit"] = True
         config_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
-            bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+            bnb_4bit_compute_dtype=model_args.compute_dtype,
             bnb_4bit_use_double_quant=model_args.double_quantization,
             bnb_4bit_quant_type=model_args.quantization_type
         )
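For context, a minimal sketch of how a quantization config built this way is typically handed to `from_pretrained`; the checkpoint name and the concrete dtype/quant-type values below are placeholders, not taken from this commit.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Placeholder values; in the patched code these come from model_args.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,      # model_args.compute_dtype
    bnb_4bit_use_double_quant=True,            # model_args.double_quantization
    bnb_4bit_quant_type="nf4"                  # model_args.quantization_type
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                     # placeholder checkpoint
    quantization_config=quantization_config,
    trust_remote_code=True
)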
@@ -261,6 +261,9 @@ def prepare_args(
     if training_args.do_predict and (not training_args.predict_with_generate):
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
@ -275,11 +278,11 @@ def prepare_args(
|
|||
|
||||
if model_args.quantization_bit is not None:
|
||||
if training_args.fp16:
|
||||
finetuning_args.compute_dtype = torch.float16
|
||||
model_args.compute_dtype = torch.float16
|
||||
elif training_args.bf16:
|
||||
finetuning_args.compute_dtype = torch.bfloat16
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
else:
|
||||
finetuning_args.compute_dtype = torch.float32
|
||||
model_args.compute_dtype = torch.float32
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(
|
||||
|
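A standalone sketch of the dtype-selection logic above, assuming the usual fp16/bf16/fp32 fallback; the helper name is made up for illustration.

import torch

def resolve_compute_dtype(fp16: bool, bf16: bool) -> torch.dtype:
    # Mirrors the patched branch: reuse the precision chosen for training,
    # otherwise fall back to full precision.
    if fp16:
        return torch.float16
    if bf16:
        return torch.bfloat16
    return torch.float32

# e.g. model_args.compute_dtype = resolve_compute_dtype(training_args.fp16, training_args.bf16)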
@@ -303,6 +306,9 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
     else:
         model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
 
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
     return model_args, data_args, finetuning_args
 
 
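The same guard now appears in both the training and the inference argument paths; a minimal sketch of the check in isolation (the helper name and the bare argument objects are stand-ins for the parsed dataclasses):

def check_quantization_args(model_args, finetuning_args):
    # bitsandbytes keeps the base weights in a frozen, quantized form,
    # so only adapter-style (LoRA) fine-tuning can be trained on top of them.
    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")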
@@ -62,6 +62,10 @@ class ModelArguments:
         default=True,
         metadata={"help": "Compress the quantization statistics through double quantization."}
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
     checkpoint_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
@@ -208,10 +212,6 @@ class FinetuningArguments:
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
     )
-    compute_dtype: Optional[torch.dtype] = field(
-        default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
-    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):
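Taken together, the two dataclass hunks move `compute_dtype` from FinetuningArguments into ModelArguments; a reduced sketch of the receiving class, limited to the fields visible in this diff (the real class defines many more):

from dataclasses import dataclass, field
from typing import Optional

import torch

@dataclass
class ModelArguments:
    # Only the fields shown in this diff; defaults copied from the hunks above.
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
    )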