optionally replace jinja template

This commit is contained in:
hiyouga 2024-09-25 23:02:02 +08:00
parent 5eb871cbf4
commit ba52103ba7
1 changed files with 12 additions and 4 deletions

View File

@ -49,6 +49,7 @@ class Template:
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
mm_plugin: "BasePlugin"
def encode_oneturn(
@ -214,6 +215,7 @@ def _register_template(
stop_words: Sequence[str] = [],
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None:
r"""
@ -263,6 +265,7 @@ def _register_template(
stop_words=stop_words,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
mm_plugin=mm_plugin,
)
@ -398,10 +401,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError:
logger.info("Cannot add this chat template to tokenizer.")
if template.replace_jinja_template:
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError:
logger.info("Cannot add this chat template to tokenizer.")
return template
@ -664,6 +668,7 @@ _register_template(
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
replace_jinja_template=False,
)
@ -740,6 +745,7 @@ _register_template(
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
)
@ -831,6 +837,7 @@ _register_template(
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
@ -843,6 +850,7 @@ _register_template(
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)