This commit is contained in:
hiyouga 2024-06-25 02:31:44 +08:00
parent 095fab58d3
commit cc016461e6
1 changed files with 22 additions and 0 deletions

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
@ -19,6 +21,7 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras.callbacks import LogCallback
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
@ -98,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
safe_serialization=(not model_args.export_legacy_format),
)
if finetuning_args.stage == "rm":
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
shutil.copy(
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
)
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
shutil.copy(
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
)
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"