fix bug in web demo
This commit is contained in:
parent
56eb99106a
commit
a38d57ddd7
|
@ -49,6 +49,8 @@ class LogCallback(TrainerCallback):
|
||||||
r"""
|
r"""
|
||||||
Event called after logging the last logs.
|
Event called after logging the last logs.
|
||||||
"""
|
"""
|
||||||
|
if "loss" not in state.log_history[-1]:
|
||||||
|
return
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
cur_steps = state.log_history[-1].get("step")
|
cur_steps = state.log_history[-1].get("step")
|
||||||
elapsed_time = cur_time - self.start_time
|
elapsed_time = cur_time - self.start_time
|
||||||
|
|
|
@ -12,7 +12,7 @@ from transformers import TextIteratorStreamer
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
|
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
|
||||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||||
|
|
||||||
|
@ -93,6 +93,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
|
||||||
input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
|
input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
|
||||||
input_ids = input_ids.to(model.device)
|
input_ids = input_ids.to(model.device)
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
|
@ -107,9 +108,9 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
|
||||||
response = ""
|
response = ""
|
||||||
for new_text in streamer:
|
for new_text in streamer:
|
||||||
response += new_text
|
response += new_text
|
||||||
history = history + [(input, response)]
|
new_history = history + [(input, response)]
|
||||||
chatbot[-1] = (parse_text(input), parse_text(response))
|
chatbot[-1] = (parse_text(input), parse_text(response))
|
||||||
yield chatbot, history
|
yield chatbot, new_history
|
||||||
|
|
||||||
|
|
||||||
def reset_user_input():
|
def reset_user_input():
|
||||||
|
@ -145,4 +146,4 @@ with gr.Blocks() as demo:
|
||||||
|
|
||||||
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
||||||
|
|
||||||
demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)
|
||||||
|
|
Loading…
Reference in New Issue