diff --git a/src/utils/common.py b/src/utils/common.py
index 57143195..90828403 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -146,7 +146,7 @@ def load_pretrained(
         finetuning_args = FinetuningArguments(finetuning_type="none")
 
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with LoRA method."
+        "RM and PPO training can only be performed with the LoRA method."
 
     config_kwargs = {
         "trust_remote_code": True,
@@ -183,7 +183,7 @@ def load_pretrained(
             config_kwargs["load_in_4bit"] = True
             config_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
-                bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+                bnb_4bit_compute_dtype=model_args.compute_dtype,
                 bnb_4bit_use_double_quant=model_args.double_quantization,
                 bnb_4bit_quant_type=model_args.quantization_type
             )
@@ -261,6 +261,9 @@ def prepare_args(
     if training_args.do_predict and (not training_args.predict_with_generate):
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
@@ -275,11 +278,11 @@ def prepare_args(
 
     if model_args.quantization_bit is not None:
         if training_args.fp16:
-            finetuning_args.compute_dtype = torch.float16
+            model_args.compute_dtype = torch.float16
         elif training_args.bf16:
-            finetuning_args.compute_dtype = torch.bfloat16
+            model_args.compute_dtype = torch.bfloat16
         else:
-            finetuning_args.compute_dtype = torch.float32
+            model_args.compute_dtype = torch.float32
 
     # Log on each process the small summary:
     logger.info(
@@ -303,6 +306,9 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
     else:
         model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
 
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
     return model_args, data_args, finetuning_args
 
 
diff --git a/src/utils/config.py b/src/utils/config.py
index be03478d..c8747be0 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -62,6 +62,10 @@ class ModelArguments:
         default=True,
         metadata={"help": "Compress the quantization statistics through double quantization."}
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
     checkpoint_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
@@ -208,10 +212,6 @@ class FinetuningArguments:
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
     )
-    compute_dtype: Optional[torch.dtype] = field(
-        default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
-    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):
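
For context, the diff moves `compute_dtype` from `FinetuningArguments` to `ModelArguments`, so the dtype derived from the `fp16`/`bf16` training flags sits next to the other quantization fields and feeds directly into the bitsandbytes config, and it adds a guard that quantization is only used together with LoRA. The sketch below is illustrative only, assuming `transformers` with bitsandbytes support; the standalone helper `build_4bit_config` is hypothetical and not part of the repository, but the dtype selection and `BitsAndBytesConfig` arguments mirror the diff.

```python
import torch
from transformers import BitsAndBytesConfig


def build_4bit_config(fp16: bool, bf16: bool,
                      double_quantization: bool = True,
                      quantization_type: str = "nf4") -> BitsAndBytesConfig:
    # Pick the compute dtype from the training precision flags,
    # mirroring the fp16 / bf16 / float32 branch in prepare_args().
    if fp16:
        compute_dtype = torch.float16
    elif bf16:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float32

    # model_args.compute_dtype now flows into the 4-bit quantization
    # config built in load_pretrained().
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=double_quantization,
        bnb_4bit_quant_type=quantization_type
    )


# Example: bf16 training with NF4 double quantization.
print(build_4bit_config(fp16=False, bf16=True))
```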