diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 55f045f6..85f386de 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -108,7 +108,9 @@ def configure_visual_model(config: "PretrainedConfig") -> None: Patches VLMs before loading them. """ model_type = getattr(config, "model_type", None) - if model_type in ["llava", "video_llava"] or "llava_next" in model_type: # required for ds zero3 and valuehead models + if ( + model_type in ["llava", "video_llava"] or "llava_next" in model_type + ): # required for ds zero3 and valuehead models setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) if getattr(config, "is_yi_vl_derived_model", None):