fix ci
This commit is contained in:
parent
e7f6a9a925
commit
b8e616183c
|
@ -19,6 +19,7 @@ import pytest
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from llamafactory.data import get_template_and_fix_tokenizer
|
from llamafactory.data import get_template_and_fix_tokenizer
|
||||||
|
from llamafactory.data.template import _get_jinja_template
|
||||||
from llamafactory.hparams import DataArguments
|
from llamafactory.hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,7 +118,8 @@ def test_encode_multiturn(use_fast: bool):
|
||||||
def test_jinja_template(use_fast: bool):
|
def test_jinja_template(use_fast: bool):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||||
get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||||
|
tokenizer.chat_template = _get_jinja_template(template, tokenizer) # llama3 template no replace
|
||||||
assert tokenizer.chat_template != ref_tokenizer.chat_template
|
assert tokenizer.chat_template != ref_tokenizer.chat_template
|
||||||
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue