From d2f18197e379601a60fa878af975c68d7c8b9648 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 21 Jul 2023 14:09:07 +0800 Subject: [PATCH] fix save function --- src/llmtuner/extras/save_and_load.py | 6 +++--- src/llmtuner/tuner/core/trainer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/extras/save_and_load.py b/src/llmtuner/extras/save_and_load.py index fd4a8165..781b9bb7 100644 --- a/src/llmtuner/extras/save_and_load.py +++ b/src/llmtuner/extras/save_and_load.py @@ -1,6 +1,6 @@ import os import torch -from typing import Dict +from typing import Dict, Optional from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME from transformers.modeling_utils import load_sharded_checkpoint @@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger logger = get_logger(__name__) -def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters +def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]: state_dict = model.state_dict() filtered_state_dict = {} for k, v in model.named_parameters(): - if v.requires_grad: + if (not trainable_only) or v.requires_grad: filtered_state_dict[k] = state_dict[k].cpu().clone().detach() return filtered_state_dict diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py index 2a025180..c9bb7043 100644 --- a/src/llmtuner/tuner/core/trainer.py +++ b/src/llmtuner/tuner/core/trainer.py @@ -56,7 +56,7 @@ class PeftTrainer(Seq2SeqTrainer): backbone_model.config.use_cache = True backbone_model.save_pretrained( output_dir, - state_dict=get_state_dict(backbone_model), + state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")), safe_serialization=self.args.save_safetensors ) backbone_model.config.use_cache = False