fix #4379
This commit is contained in:
parent
095fab58d3
commit
cc016461e6
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
@ -19,6 +21,7 @@ from transformers import PreTrainedModel
|
|||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.callbacks import LogCallback
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.logging import get_logger
|
||||
from ..hparams import get_infer_args, get_train_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
|
@ -98,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
|||
safe_serialization=(not model_args.export_legacy_format),
|
||||
)
|
||||
|
||||
if finetuning_args.stage == "rm":
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
vhead_path = model_args.adapter_name_or_path[-1]
|
||||
else:
|
||||
vhead_path = model_args.model_name_or_path
|
||||
|
||||
if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
|
||||
shutil.copy(
|
||||
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||
)
|
||||
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
|
||||
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
|
||||
shutil.copy(
|
||||
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
|
||||
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
|
||||
)
|
||||
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
|
||||
|
||||
try:
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
tokenizer.init_kwargs["padding_side"] = "left"
|
||||
|
|
Loading…
Reference in New Issue