fix bug in web demo

This commit is contained in:
hiyouga 2023-06-05 17:58:29 +08:00
parent 56eb99106a
commit a38d57ddd7
2 changed files with 7 additions and 4 deletions

View File

@ -49,6 +49,8 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called after logging the last logs. Event called after logging the last logs.
""" """
if "loss" not in state.log_history[-1]:
return
cur_time = time.time() cur_time = time.time()
cur_steps = state.log_history[-1].get("step") cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time elapsed_time = cur_time - self.start_time

View File

@ -12,7 +12,7 @@ from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
model_args, data_args, finetuning_args = prepare_infer_args() model_args, data_args, finetuning_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args) model, tokenizer = load_pretrained(model_args, finetuning_args)
@ -93,6 +93,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"] input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device) input_ids = input_ids.to(model.device)
gen_kwargs = { gen_kwargs = {
"input_ids": input_ids,
"do_sample": True, "do_sample": True,
"top_p": top_p, "top_p": top_p,
"temperature": temperature, "temperature": temperature,
@ -107,9 +108,9 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
response = "" response = ""
for new_text in streamer: for new_text in streamer:
response += new_text response += new_text
history = history + [(input, response)] new_history = history + [(input, response)]
chatbot[-1] = (parse_text(input), parse_text(response)) chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history yield chatbot, new_history
def reset_user_input(): def reset_user_input():
@ -145,4 +146,4 @@ with gr.Blocks() as demo:
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True) demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)