From ed0e186a134de816d6a9278f4e47baa6250a52d1 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 21 Jul 2023 13:27:27 +0800
Subject: [PATCH] update web UI, support rm predict #210

---
 src/llmtuner/dsets/preprocess.py          |  6 ++--
 src/llmtuner/tuner/core/parser.py         |  2 +-
 src/llmtuner/tuner/core/trainer.py        | 13 ++++---
 src/llmtuner/tuner/rm/trainer.py          | 30 ++++++++++++++++
 src/llmtuner/tuner/rm/workflow.py         |  7 ++++
 src/llmtuner/webui/components/__init__.py |  5 +--
 src/llmtuner/webui/components/chatbot.py  | 10 ++----
 src/llmtuner/webui/components/export.py   | 34 ++++++++++++++++++
 src/llmtuner/webui/components/sft.py      |  2 +-
 src/llmtuner/webui/interface.py           |  8 +++--
 src/llmtuner/webui/locales.py             | 44 +++++++++++++++++++++++
 src/llmtuner/webui/runner.py              | 14 +++++---
 src/llmtuner/webui/utils.py               | 44 +++++++++++++++++++++--
 13 files changed, 192 insertions(+), 27 deletions(-)
 create mode 100644 src/llmtuner/webui/components/export.py

diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index bf65cc7d..f743e27e 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -143,8 +143,10 @@ def preprocess_dataset(
     if stage == "pt":
         preprocess_function = preprocess_pretrain_dataset
     elif stage == "sft":
-        preprocess_function = preprocess_unsupervised_dataset \
-            if training_args.predict_with_generate else preprocess_supervised_dataset
+        if not training_args.predict_with_generate:
+            preprocess_function = preprocess_supervised_dataset
+        else:
+            preprocess_function = preprocess_unsupervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 186efeea..31e738f3 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -54,7 +54,7 @@ def get_train_args(
     assert not (training_args.do_train and training_args.predict_with_generate), \
         "`predict_with_generate` cannot be set as True while training."

-    assert (not training_args.do_predict) or training_args.predict_with_generate, \
+    assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
         "Please enable `predict_with_generate` to save model predictions."

     assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
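A note on the parser change: the loosened assertion is what unlocks reward-model prediction, since `predict_with_generate` is now only mandatory when predicting in the SFT stage. A minimal sketch of the new guard, with hypothetical flag values standing in for the parsed `general_args`/`training_args` objects:

    def prediction_args_ok(stage: str, do_predict: bool, predict_with_generate: bool) -> bool:
        # mirrors the assertion in get_train_args above
        return stage != "sft" or (not do_predict) or predict_with_generate

    assert prediction_args_ok("rm", True, False)       # RM prediction needs no generation
    assert not prediction_args_ok("sft", True, False)  # SFT prediction still requires predict_with_generate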
diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py
index a54d93cb..2a025180 100644
--- a/src/llmtuner/tuner/core/trainer.py
+++ b/src/llmtuner/tuner/core/trainer.py
@@ -4,7 +4,8 @@ from typing import Dict, Optional

 from transformers import Seq2SeqTrainer
 from transformers.trainer import TRAINING_ARGS_NAME
-from transformers.modeling_utils import unwrap_model
+from transformers.modeling_utils import PreTrainedModel, unwrap_model
+from peft import PeftModel

 from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
 from llmtuner.extras.logging import get_logger
@@ -49,9 +50,9 @@ class PeftTrainer(Seq2SeqTrainer):
         else:
             backbone_model = model

-        if self.finetuning_args.finetuning_type == "lora":
+        if isinstance(backbone_model, PeftModel): # LoRA tuning
             backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
-        else: # freeze/full tuning
+        elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
             backbone_model.config.use_cache = True
             backbone_model.save_pretrained(
                 output_dir,
@@ -61,6 +62,8 @@ class PeftTrainer(Seq2SeqTrainer):
             backbone_model.config.use_cache = False
             if self.tokenizer is not None:
                 self.tokenizer.save_pretrained(output_dir)
+        else:
+            logger.warning("No model to save.")

         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")
@@ -77,8 +80,8 @@ class PeftTrainer(Seq2SeqTrainer):
         model = unwrap_model(self.model)
         backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

-        if self.finetuning_args.finetuning_type == "lora":
-            backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
+        if isinstance(backbone_model, PeftModel):
+            backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
         if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
             model.v_head.load_state_dict({
                 "summary.weight": getattr(model, "reward_head_weight"),
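The `_save`/`_load_best_model` hunks above switch from consulting `finetuning_args.finetuning_type` to checking the unwrapped model's actual type, which also handles value-head models whose `pretrained_model` attribute is only sometimes a `PeftModel`. A minimal sketch of the dispatch, assuming `backbone_model` has already been unwrapped as in `_save`:

    from peft import PeftModel
    from transformers.modeling_utils import PreTrainedModel

    def describe_save(backbone_model) -> str:
        # same branch order as _save: adapter weights first, full weights second
        if isinstance(backbone_model, PeftModel):
            return "write LoRA adapter weights only"
        elif isinstance(backbone_model, PreTrainedModel):
            return "write full model weights"
        else:
            return "nothing to save"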
+ """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + + acc_scores, rej_scores = predict_results.predictions + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for acc_score, rej_score in zip(acc_scores, rej_scores): + res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)})) + writer.write("\n".join(res)) diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py index c2d7104a..b7022c15 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -56,3 +56,10 @@ def run_rm( metrics = trainer.evaluate(metric_key_prefix="eval") trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset, metric_key_prefix="predict") + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(predict_results) diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py index 779cf390..9312f409 100644 --- a/src/llmtuner/webui/components/__init__.py +++ b/src/llmtuner/webui/components/__init__.py @@ -1,4 +1,5 @@ -from llmtuner.webui.components.eval import create_eval_tab -from llmtuner.webui.components.infer import create_infer_tab from llmtuner.webui.components.top import create_top from llmtuner.webui.components.sft import create_sft_tab +from llmtuner.webui.components.eval import create_eval_tab +from llmtuner.webui.components.infer import create_infer_tab +from llmtuner.webui.components.export import create_export_tab diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 7565ba7a..0fa6a3d8 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -22,13 +22,9 @@ def create_chat_box( with gr.Column(scale=1): clear_btn = gr.Button() - max_new_tokens = gr.Slider( - 10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True - ) - top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True) - temperature = gr.Slider( - 0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True - ) + max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1) + top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01) + temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01) history = gr.State([]) diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py new file mode 100644 index 00000000..72b66e71 --- /dev/null +++ b/src/llmtuner/webui/components/export.py @@ -0,0 +1,34 @@ +from typing import Dict +import gradio as gr +from gradio.components import Component + +from llmtuner.webui.utils import export_model + + +def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]: + with gr.Row(): + save_dir = gr.Textbox() + max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) + + export_btn = gr.Button() + info_box = gr.Textbox(show_label=False, interactive=False) + + export_btn.click( + export_model, + [ + top_elems["lang"], + top_elems["model_name"], + top_elems["checkpoints"], 
diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py
index 779cf390..9312f409 100644
--- a/src/llmtuner/webui/components/__init__.py
+++ b/src/llmtuner/webui/components/__init__.py
@@ -1,4 +1,5 @@
-from llmtuner.webui.components.eval import create_eval_tab
-from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.top import create_top
 from llmtuner.webui.components.sft import create_sft_tab
+from llmtuner.webui.components.eval import create_eval_tab
+from llmtuner.webui.components.infer import create_infer_tab
+from llmtuner.webui.components.export import create_export_tab
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index 7565ba7a..0fa6a3d8 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -22,13 +22,9 @@ def create_chat_box(
         with gr.Column(scale=1):
             clear_btn = gr.Button()

-            max_new_tokens = gr.Slider(
-                10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
-            )
-            top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
-            temperature = gr.Slider(
-                0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
-            )
+            max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
+            top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
+            temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)

     history = gr.State([])
diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py
new file mode 100644
index 00000000..72b66e71
--- /dev/null
+++ b/src/llmtuner/webui/components/export.py
@@ -0,0 +1,34 @@
+from typing import Dict
+import gradio as gr
+from gradio.components import Component
+
+from llmtuner.webui.utils import export_model
+
+
+def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
+    with gr.Row():
+        save_dir = gr.Textbox()
+        max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
+
+    export_btn = gr.Button()
+    info_box = gr.Textbox(show_label=False, interactive=False)
+
+    export_btn.click(
+        export_model,
+        [
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            max_shard_size,
+            save_dir
+        ],
+        [info_box]
+    )
+
+    return dict(
+        save_dir=save_dir,
+        max_shard_size=max_shard_size,
+        export_btn=export_btn,
+        info_box=info_box
+    )
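For readers new to Gradio's event wiring: `Button.click(fn, inputs, outputs)` passes the inputs' current values to `fn` positionally, and when `fn` is a generator (as `export_model` below is), each yielded value streams into the bound output component, here the read-only `info_box`. A minimal self-contained sketch of the same pattern, with hypothetical component names (generator callbacks require the queue to be enabled):

    import gradio as gr

    def do_work(name: str):
        # each yield replaces the text shown in the output component
        yield f"processing {name}..."
        yield f"{name} done."

    with gr.Blocks() as demo:
        name_box = gr.Textbox(label="Name")
        status_box = gr.Textbox(interactive=False)
        run_btn = gr.Button("Run")
        run_btn.click(do_work, [name_box], [status_box])

    # demo.queue().launch()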
diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/sft.py
index bb91a69e..aa2b7a1a 100644
--- a/src/llmtuner/webui/components/sft.py
+++ b/src/llmtuner/webui/components/sft.py
@@ -57,7 +57,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,

     with gr.Row():
         with gr.Column(scale=4):
-            output_dir = gr.Textbox(interactive=True)
+            output_dir = gr.Textbox()

             with gr.Box():
                 output_box = gr.Markdown()
diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py
index 002b6fae..11f1138a 100644
--- a/src/llmtuner/webui/interface.py
+++ b/src/llmtuner/webui/interface.py
@@ -5,7 +5,8 @@ from llmtuner.webui.components import (
     create_top,
     create_sft_tab,
     create_eval_tab,
-    create_infer_tab
+    create_infer_tab,
+    create_export_tab
 )
 from llmtuner.webui.css import CSS
 from llmtuner.webui.manager import Manager
@@ -30,7 +31,10 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Chat"):
             infer_elems = create_infer_tab(top_elems)

-        elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
+        with gr.Tab("Export"):
+            export_elems = create_export_tab(top_elems)
+
+        elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
         manager = Manager(elem_list)

         demo.load(
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index 9279fad9..5962d64a 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -452,6 +452,34 @@ LOCALES = {
         "zh": {
             "label": "温度系数"
         }
+    },
+    "save_dir": {
+        "en": {
+            "label": "Export dir",
+            "info": "Directory to save exported model."
+        },
+        "zh": {
+            "label": "导出目录",
+            "info": "保存导出模型的文件夹路径。"
+        }
+    },
+    "max_shard_size": {
+        "en": {
+            "label": "Max shard size (GB)",
+            "info": "The maximum size for a model file."
+        },
+        "zh": {
+            "label": "最大分块大小(GB)",
+            "info": "模型文件的最大大小。"
+        }
+    },
+    "export_btn": {
+        "en": {
+            "value": "Export"
+        },
+        "zh": {
+            "value": "开始导出"
+        }
     }
 }
@@ -477,6 +505,14 @@ ALERTS = {
         "en": "Please choose a dataset.",
         "zh": "请选择数据集。"
     },
+    "err_no_checkpoint": {
+        "en": "Please select a checkpoint.",
+        "zh": "请选择断点。"
+    },
+    "err_no_save_dir": {
+        "en": "Please provide export dir.",
+        "zh": "请填写导出目录。"
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "zh": "训练中断,正在等待线程结束……"
@@ -504,5 +540,13 @@ ALERTS = {
     "info_unloaded": {
         "en": "Model unloaded.",
         "zh": "模型已卸载。"
+    },
+    "info_exporting": {
+        "en": "Exporting model...",
+        "zh": "正在导出模型……"
+    },
+    "info_exported": {
+        "en": "Model exported.",
+        "zh": "模型导出完成。"
     }
 }
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 95b35501..c0ec0787 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -3,7 +3,7 @@ import os
 import threading
 import time
 import transformers
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
@@ -25,7 +25,9 @@ class Runner:
         self.aborted = True
         self.running = False

-    def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
+    def initialize(
+        self, lang: str, model_name: str, dataset: list
+    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
         if self.running:
             return None, ALERTS["err_conflict"][lang], None, None
@@ -50,7 +52,9 @@

         return model_name_or_path, "", logger_handler, trainer_callback

-    def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
+    def finalize(
+        self, lang: str, finish_info: Optional[str] = None
+    ) -> str:
         self.running = False
         torch_gc()
         if self.aborted:
@@ -87,7 +91,7 @@
         lora_dropout: float,
         lora_target: str,
         output_dir: str
-    ):
+    ) -> Generator[str, None, None]:
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
             yield error
@@ -174,7 +178,7 @@
         max_samples: str,
         batch_size: int,
         predict: bool
-    ):
+    ) -> Generator[str, None, None]:
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
             yield error
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index a5b5640f..4921195d 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -3,11 +3,13 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.webui.common import get_save_dir, DATA_CONFIG
+from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
+from llmtuner.webui.locales import ALERTS


 def format_info(log: str, tracker: dict) -> str:
@@ -83,3 +85,41 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
     ax.set_xlabel("step")
     ax.set_ylabel("loss")
     return fig
+
+
+def export_model(
+    lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
+) -> Generator[str, None, None]:
+    if not model_name:
+        yield ALERTS["err_no_model"][lang]
+        return
+
+    model_name_or_path = get_model_path(model_name)
+    if not model_name_or_path:
+        yield ALERTS["err_no_path"][lang]
+        return
+
+    if not checkpoints:
+        yield ALERTS["err_no_checkpoint"][lang]
+        return
+
+    checkpoint_dir = ",".join(
+        [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+    )
+
+    if not save_dir:
+        yield ALERTS["err_no_save_dir"][lang]
+        return
+
+    args = dict(
+        model_name_or_path=model_name_or_path,
+        checkpoint_dir=checkpoint_dir,
+        finetuning_type=finetuning_type
+    )
+
+    yield ALERTS["info_exporting"][lang]
+    model_args, _, finetuning_args, _ = get_infer_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
+    tokenizer.save_pretrained(save_dir)
+    yield ALERTS["info_exported"][lang]
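A closing usage note: because `export_model` validates its inputs and reports progress by yielding alert strings, it can also be driven outside the web UI. A hedged sketch with hypothetical model and checkpoint names (the `model_name` must resolve through `get_model_path`, i.e. be registered in the web UI's model list), matching the signature introduced above:

    from llmtuner.webui.utils import export_model

    # args: lang, model_name, checkpoints, finetuning_type, max_shard_size (GB), save_dir
    for status in export_model("en", "LLaMA-7B", ["checkpoint-1000"], "lora", 10, "exported_model"):
        print(status)

On success, the model weights are written to `exported_model/` in shards of at most 10 GB (`max_shard_size` is formatted as "10GB" for `save_pretrained`), alongside the tokenizer files.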