diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 69cccd93..e1015821 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict +import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer from trl import AutoModelForCausalLMWithValueHead @@ -175,6 +176,10 @@ def load_model( if not is_trainable: model.requires_grad_(False) + for param in model.parameters(): + if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32: + param.data = param.data.to(model_args.compute_dtype) + model.eval() else: model.train()