diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 8cdc3511..e85e0f35 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -53,6 +53,22 @@ class Template: return convs +class Llama2Template(Template): + def _format_example(self, query, history, prefix): + sys = prefix or self.prefix + if not sys.startswith("<>\n"): + sys = f"<>\n{sys.strip()}\n<>\n\n" + history = history if (history and self.use_history) else [] + history = history + [(query, "")] + convs = [] + for turn_idx, (query_i, resp_i) in enumerate(history): + if turn_idx == 0: + convs.append([self.prompt.format(query=sys+query_i), resp_i]) + else: + convs.append([(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]) + return convs + + templates: Dict[str, Template] = {} @@ -101,8 +117,7 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf https://huggingface.co/meta-llama/Llama-2-13b-chat-hf https://huggingface.co/meta-llama/Llama-2-70b-chat-hf """ -register_template( - name="llama2", +templates["llama2"] = Llama2Template( prefix="<>\nYou are a helpful, respectful and honest assistant. " "Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, " @@ -111,8 +126,8 @@ register_template( "If a question does not make any sense, or is not factually coherent, " "explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information.\n<>\n\n", - prompt=" [INST] {query} [/INST] ", - sep="", + prompt="[INST]{query}[/INST]", + sep="", use_history=True )