fix stop words

This commit is contained in:
hiyouga 2023-12-20 19:06:43 +08:00
parent 5af8841c4f
commit dec360d5ae
1 changed files with 7 additions and 5 deletions

View File

@ -1,4 +1,5 @@
import tiktoken
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@ -223,19 +224,20 @@ def get_template_and_fix_tokenizer(
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
stop_words = deepcopy(template.stop_words)
if template.replace_eos:
if not template.stop_words:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
tokenizer.eos_token = template.stop_words.pop(0)
tokenizer.eos_token = stop_words.pop(0)
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if template.stop_words:
if stop_words:
tokenizer.add_special_tokens(
dict(additional_special_tokens=template.stop_words),
dict(additional_special_tokens=stop_words),
replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(template.stop_words)))
logger.info("Add {} to stop words.".format(",".join(stop_words)))
return template