From 2c17d91bb7ae58346c020c46cb7ffabad4deff4f Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 29 Sep 2024 23:58:09 +0800 Subject: [PATCH] Update common.py --- src/llamafactory/webui/common.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 0ad2929e..d4e9be51 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -115,13 +115,6 @@ def get_model_path(model_name: str) -> str: return model_path -def get_prefix(model_name: str) -> str: - r""" - Gets the prefix of the model name to obtain the model family. - """ - return model_name.split("-")[0] - - def get_model_info(model_name: str) -> Tuple[str, str]: r""" Gets the necessary information of this model. @@ -137,21 +130,14 @@ def get_template(model_name: str) -> str: r""" Gets the template name if the model is a chat model. """ - if ( - model_name - and any(suffix in model_name for suffix in ("-Chat", "-Instruct")) - and get_prefix(model_name) in DEFAULT_TEMPLATE - ): - return DEFAULT_TEMPLATE[get_prefix(model_name)] - - return "default" + return DEFAULT_TEMPLATE.get(model_name, "default") def get_visual(model_name: str) -> bool: r""" Judges if the model is a vision language model. """ - return get_prefix(model_name) in VISION_MODELS + return model_name in VISION_MODELS def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":