From 87f8f830e20aa839e089559c1d038954742000ef Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 3 Aug 2023 15:53:32 +0800
Subject: [PATCH] support Qwen-7B, fix InternLM-7B inference

---
 README.md                               |  5 +++
 README_zh.md                            |  7 +++-
 requirements.txt                        |  1 +
 src/llmtuner/chat/stream_chat.py        |  9 +++-
 src/llmtuner/extras/misc.py             | 19 +++++++-
 src/llmtuner/extras/template.py         | 60 ++++++++++++++++++-------
 src/llmtuner/hparams/finetuning_args.py | 11 +++--
 src/llmtuner/tuner/core/loader.py       |  2 +-
 8 files changed, 89 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index 9758079c..b546a841 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,8 @@ ## Changelog
 
+[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
+
 [23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
 
 [23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
@@ -46,6 +48,7 @@
 - [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
 - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
 - [InternLM](https://github.com/InternLM/InternLM) (7B)
+- [Qwen](https://github.com/QwenLM/Qwen-7B) (7B)
 
 ## Supported Training Approaches
@@ -111,6 +114,7 @@ huggingface-cli login
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
+- sentencepiece and tiktoken
 - jieba, rouge-chinese and nltk (used at evaluation)
 - gradio and matplotlib (used in web_demo.py)
 - uvicorn, fastapi and sse-starlette (used in api_demo.py)
@@ -378,6 +382,7 @@ Please follow the model licenses to use the corresponding model weights:
 - [Falcon](LICENSE)
 - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
 - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
+- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
 
 ## Citation
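The README entry above pairs `--model_name_or_path Qwen/Qwen-7B-Chat` with `--lora_target c_attn` because Qwen fuses the query/key/value projections into a single `c_attn` module. The snippet below is a minimal stand-alone sketch of that module targeting with PEFT, outside this repo's trainer; the rank, alpha, and dropout values are illustrative placeholders, not defaults taken from this patch.

```python
# Sketch: a LoRA adapter targeting Qwen's fused attention projection ("c_attn"),
# mirroring the `--lora_target c_attn` flag above. Hyperparameters are illustrative
# placeholders, not values prescribed by this repository.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                        # LoRA rank (illustrative)
    lora_alpha=32,              # scaling factor (illustrative)
    lora_dropout=0.1,
    target_modules=["c_attn"],  # Qwen packs Q/K/V into a single c_attn layer
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```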
diff --git a/README_zh.md b/README_zh.md
index ae2ad81b..c8f11293 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -12,6 +12,8 @@ ## 更新日志
 
+[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。请注意使用 Qwen-7B-Chat 模型需要添加 `--template chatml` 参数。
+
 [23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
 
 [23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
@@ -20,7 +22,7 @@
 
 [23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
 
-[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
+[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
 
 [23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
@@ -46,6 +48,7 @@
 - [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
 - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
 - [InternLM](https://github.com/InternLM/InternLM) (7B)
+- [Qwen](https://github.com/QwenLM/Qwen-7B) (7B)
 
 ## 微调方法
@@ -111,6 +114,7 @@ huggingface-cli login
 - Python 3.8+ 和 PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
+- sentencepiece 和 tiktoken
 - jieba, rouge-chinese 和 nltk (用于评估)
 - gradio 和 matplotlib (用于网页端交互)
 - uvicorn, fastapi 和 sse-starlette (用于 API)
@@ -378,6 +382,7 @@ python src/export_model.py \
 - [Falcon](LICENSE)
 - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
 - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
+- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
 
 ## 引用
diff --git a/requirements.txt b/requirements.txt
index fb5820ab..d99ce326 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ accelerate>=0.21.0
 peft>=0.4.0
 trl>=0.4.7
 sentencepiece
+tiktoken
 jieba
 rouge-chinese
 nltk
diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index d5a5f1ad..b1ded67a 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
 
-from llmtuner.extras.misc import dispatch_model, get_logits_processor
+from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
 from llmtuner.extras.template import get_template
 from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
 
@@ -16,6 +16,10 @@ class ChatModel:
         self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
+        self.stop_ids = [
+            self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
+        ]
+        self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
 
     def process_args(
         self,
@@ -47,7 +51,8 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            logits_processor=get_logits_processor()
+            logits_processor=get_logits_processor(),
+            stopping_criteria=get_stopwords_criteria(self.stop_ids)
         ))
 
         if max_length:
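The `stream_chat.py` hunk above converts each template stop word into a token id and passes a stopping-criteria list into generation. Below is a hedged, stand-alone sketch of the same mechanism using only the public `transformers` API; the model name and the `StopOnTokens` helper are illustrative stand-ins, not code from this patch.

```python
# Sketch: wiring stop words into generate() via a StoppingCriteriaList, mirroring the
# mechanism this patch adds to ChatModel. Model name and helper class are illustrative.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        super().__init__()
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Halt as soon as the newest token of any sequence is one of the stop ids.
        return any(stop_id in input_ids[:, -1] for stop_id in self.stop_ids)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

stop_words = ["<|im_end|>"]  # stop words declared by the chatml template
# As in the patch: register the stop words as special tokens, then take the first id of each.
tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))
stop_ids = [tokenizer.encode(word, add_special_tokens=False)[0] for word in stop_words]

prompt = "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    stopping_criteria=StoppingCriteriaList([StopOnTokens(stop_ids)]),
)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```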
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 93b65aa6..766de40d 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -1,8 +1,7 @@
 import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from transformers.generation.utils import LogitsProcessorList
-from transformers.generation.logits_process import LogitsProcessor
+from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 
@@ -46,6 +45,22 @@ def get_logits_processor() -> LogitsProcessorList:
     return logits_processor
 
 
+class StopWordsCriteria(StoppingCriteria):
+
+    def __init__(self, stop_ids: List[int]) -> None:
+        super().__init__()
+        self.stop_ids = stop_ids
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
+
+
+def get_stopwords_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
+    stopwords_criteria = StoppingCriteriaList()
+    stopwords_criteria.append(StopWordsCriteria(stop_ids))
+    return stopwords_criteria
+
+
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index cc3f2b1d..bb550058 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -9,6 +9,7 @@ class Template:
     prompt: str
     sep: str
     use_history: bool
+    stop_words: List[str]
 
     def get_prompt(
         self,
@@ -74,13 +75,16 @@ class Llama2Template(Template):
 templates: Dict[str, Template] = {}
 
 
-def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
+def register_template(
+    name: str, prefix: str, prompt: str, sep: str, use_history: bool, stop_words: List[str]
+) -> None:
     template_class = Llama2Template if name == "llama2" else Template
     templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
         sep=sep,
-        use_history=use_history
+        use_history=use_history,
+        stop_words=stop_words
     )
@@ -98,7 +102,8 @@ register_template(
     prefix="",
     prompt="{query}",
     sep="",
-    use_history=False
+    use_history=False,
+    stop_words=[]
 )
@@ -111,7 +116,8 @@ register_template(
            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
     prompt="Human: {query}\nAssistant: ",
     sep="\n",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -132,7 +138,8 @@ register_template(
            "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
     prompt="[INST] {query} [/INST] ",
     sep="",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -146,7 +153,8 @@ register_template(
            "Write a response that appropriately completes the request.",
     prompt="### Instruction:\n{query}\n\n### Response:\n",
     sep="\n\n",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -160,7 +168,8 @@ register_template(
            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
     prompt="USER: {query} ASSISTANT: ",
     sep="",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -172,7 +181,8 @@ register_template(
     prefix="",
     prompt="Human: {query}\n\nBelle: ",
     sep="\n\n",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -184,7 +194,8 @@ register_template(
     prefix="",
     prompt="User: {query}\nBot: ",
     sep="\n",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -196,7 +207,8 @@ register_template(
     prefix="",
     prompt="Human: {query}\nAssistant: ",
     sep="\n",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -208,7 +220,8 @@ register_template(
     prefix="",
     prompt="<human>:{query}\n<bot>:",
     sep="\n",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -221,7 +234,8 @@ register_template(
            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
     prompt="Human: {query}###Assistant: ",
     sep="###",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -233,7 +247,8 @@ register_template(
     prefix="",
     prompt="<|User|>:{query}\n<|Bot|>:",
     sep="\n",
-    use_history=True
+    use_history=True,
+    stop_words=["<eoa>"]
 )
@@ -245,7 +260,8 @@ register_template(
     prefix="",
     prompt="{query}",
     sep="",
-    use_history=True
+    use_history=True,
+    stop_words=[]
 )
@@ -258,5 +274,19 @@ register_template(
     prefix="<|system|>\n",
     prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
     sep="<|end|>\n",
-    use_history=True
+    use_history=True,
+    stop_words=["<|end|>"]
 )
+
+
+r"""
+Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
+"""
+register_template(
+    name="chatml",
+    prefix="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
+    prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
+    sep="<|im_end|>\n",
+    use_history=True,
+    stop_words=["<|im_end|>"]
+)
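The new `chatml` template above defines the ChatML prefix, per-turn prompt, separator, and stop word used for Qwen-7B-Chat. The snippet below is a simplified, stand-alone illustration of how those four fields could compose into a prompt string; it is not the repo's actual `Template.get_prompt` implementation, which additionally handles source prefixes.

```python
# Simplified illustration of how the chatml template fields compose into a Qwen-7B-Chat
# prompt. This is a stand-alone sketch, not the repo's Template.get_prompt implementation.
prefix = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
prompt = "<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
sep = "<|im_end|>\n"

def render(query: str, history=None) -> str:
    """Compose prefix, turn prompts, and separators into a single prompt string."""
    text = prefix
    for old_query, old_response in (history or []):
        # Close each past assistant turn with the separator before the next user turn.
        text += prompt.format(query=old_query) + old_response + sep
    return text + prompt.format(query=query)

print(render("How do I read a CSV file in Python?"))
# Generation is expected to stop once the model emits the "<|im_end|>" stop word.
```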
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index f43e5786..1a7d5860 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -19,7 +19,8 @@ class FinetuningArguments:
                   LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
                   Falcon choices: [\"32\", \"60\"], \
-                  Baichuan choices: [\"32\", \"40\"]"}
+                  Baichuan choices: [\"32\", \"40\"] \
+                  Qwen choices: [\"32\"]"}
     )
     num_layer_trainable: Optional[int] = field(
         default=3,
@@ -30,7 +31,8 @@
         metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
                   LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
-                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
+                  Baichuan choices: [\"mlp\", \"self_attn\"], \
+                  Qwen choices: [\"attn\", \"mlp\"]"}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -47,9 +49,10 @@
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                  LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  LLaMA & LLaMA-2 & InternLM choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
-                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
+                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  Qwen choices: [\"c_attn\", \"c_proj\", \"w1\", \"w2\"]"}
     )
 
     def __post_init__(self):
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 646a0509..ee33218c 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -67,7 +67,7 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
     if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
-        tokenizer.pad_token_id = 0 # set as the <unk> token
+        tokenizer.pad_token = tokenizer.eos_token
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
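The `loader.py` change above stops forcing `pad_token_id = 0` and instead reuses the end-of-sequence token whenever a tokenizer ships without a pad token (or with the legacy Baichuan id 64000). Below is a small stand-alone sketch of that fallback; `gpt2` is only a convenient placeholder for a tokenizer that defines an EOS token but no pad token.

```python
# Sketch of the pad-token fallback introduced by this patch, shown outside the repo.
# "gpt2" is an illustrative stand-in: it has an EOS token but no pad token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000:  # 64000: older Baichuan tokenizers
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding instead of a hard-coded id 0

print(tokenizer.pad_token, tokenizer.pad_token_id)  # "<|endoftext|>", 50256 for gpt2
```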