37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
|
# Test Template Encode
|
|||
|
# Usage: python .\tests\template_encode.py --model_name_and_path D:\llm\chinese-alpaca-2-7b
|
|||
|
# --template llama2_zh --query 'how are you?'
|
|||
|
# --history '[[\"Hello!\",\"Hi,I am llama2.\"]]'
|
|||
|
|
|||
|
import sys
|
|||
|
import fire
|
|||
|
from typing import List, Optional, Tuple
|
|||
|
from transformers import AutoTokenizer
|
|||
|
sys.path.append("./src")
|
|||
|
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
|||
|
|
|||
|
|
|||
|
def encode(
|
|||
|
model_name_and_path: str,
|
|||
|
template: str,
|
|||
|
query: str,
|
|||
|
resp: Optional[str] = "",
|
|||
|
history: Optional[List[Tuple[str, str]]] = None,
|
|||
|
system: Optional[str] = None):
|
|||
|
tokenizer = AutoTokenizer.from_pretrained(
|
|||
|
model_name_and_path,
|
|||
|
trust_remote_code=True
|
|||
|
)
|
|||
|
|
|||
|
template = get_template_and_fix_tokenizer(template, tokenizer)
|
|||
|
|
|||
|
encoded_pairs = template.encode_multiturn(tokenizer, query, resp, history, system)
|
|||
|
for prompt_ids, answer_ids in encoded_pairs:
|
|||
|
print("="*50)
|
|||
|
print("prompt_ids: {}, answer_ids: {}".format(prompt_ids, answer_ids))
|
|||
|
print("prompt decode: {}".format(tokenizer.decode(prompt_ids)))
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
fire.Fire(encode)
|