add e2e tests
This commit is contained in:
parent
1274356263
commit
94d5b1bd8f
|
@ -175,7 +175,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||||
|
|
|
@ -176,7 +176,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: mllm_demo,identity
|
dataset: mllm_demo,identity # video: mllm_video_demo
|
||||||
template: qwen2_vl
|
template: qwen2_vl
|
||||||
cutoff_len: 1024
|
cutoff_len: 1024
|
||||||
max_samples: 1000
|
max_samples: 1000
|
||||||
|
|
|
@ -19,7 +19,6 @@ if is_pyav_available():
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import torch
|
import torch
|
||||||
from numpy.typing import NDArray
|
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.image_processing_utils import BaseImageProcessor
|
from transformers.image_processing_utils import BaseImageProcessor
|
||||||
|
|
||||||
|
@ -31,11 +30,17 @@ if TYPE_CHECKING:
|
||||||
VideoInput = str
|
VideoInput = str
|
||||||
|
|
||||||
|
|
||||||
def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
|
def _regularize_images(
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
processor: "ProcessorMixin",
|
||||||
|
max_resolution: Optional[int] = None,
|
||||||
|
) -> List["ImageObject"]:
|
||||||
r"""
|
r"""
|
||||||
Regularizes images to avoid error. Including reading, resizing and converting.
|
Regularizes images to avoid error. Including reading, resizing and converting.
|
||||||
"""
|
"""
|
||||||
image_resolution: int = getattr(processor, "image_resolution", 512)
|
if max_resolution is None:
|
||||||
|
max_resolution: int = getattr(processor, "image_resolution", 512)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for image in images:
|
for image in images:
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
|
@ -49,9 +54,9 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
|
||||||
if not isinstance(image, ImageObject):
|
if not isinstance(image, ImageObject):
|
||||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
||||||
|
|
||||||
if max(image.width, image.height) > image_resolution:
|
if max(image.width, image.height) > max_resolution:
|
||||||
factor = image_resolution / max(image.width, image.height)
|
factor = max_resolution / max(image.width, image.height)
|
||||||
image = image.resize((int(image.width * factor), int(image.height * factor)))
|
image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
|
||||||
|
|
||||||
if image.mode != "RGB":
|
if image.mode != "RGB":
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
@ -61,11 +66,16 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]:
|
def _regularize_videos(
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: "ProcessorMixin",
|
||||||
|
) -> List[List["ImageObject"]]:
|
||||||
r"""
|
r"""
|
||||||
Regularizes videos to avoid error. Including reading, resizing and converting.
|
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||||
"""
|
"""
|
||||||
|
video_resolution: int = getattr(processor, "video_resolution", 128)
|
||||||
video_fps: float = getattr(processor, "video_fps", 1.0)
|
video_fps: float = getattr(processor, "video_fps", 1.0)
|
||||||
|
video_maxlen: int = getattr(processor, "video_maxlen", 64)
|
||||||
video_factor: int = getattr(processor, "video_factor", 1)
|
video_factor: int = getattr(processor, "video_factor", 1)
|
||||||
results = []
|
results = []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
|
@ -73,6 +83,7 @@ def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixi
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
total_frames = video_stream.frames
|
total_frames = video_stream.frames
|
||||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||||
|
sample_frames = min(video_maxlen, sample_frames) # reduce length <= maxlen
|
||||||
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
|
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
|
||||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||||
frames: List["ImageObject"] = []
|
frames: List["ImageObject"] = []
|
||||||
|
@ -81,7 +92,7 @@ def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixi
|
||||||
if frame_idx in sample_indices:
|
if frame_idx in sample_indices:
|
||||||
frames.append(frame.to_image())
|
frames.append(frame.to_image())
|
||||||
|
|
||||||
frames = _regularize_images(frames, processor)
|
frames = _regularize_images(frames, processor, video_resolution)
|
||||||
results.append(frames)
|
results.append(frames)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
@ -562,8 +562,8 @@ _register_template(
|
||||||
_register_template(
|
_register_template(
|
||||||
name="cpm3",
|
name="cpm3",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
|
||||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -23,12 +23,133 @@ from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class QuantizationArguments:
|
||||||
|
r"""
|
||||||
|
Arguments pertaining to the quantization method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
|
||||||
|
default="bitsandbytes",
|
||||||
|
metadata={"help": "Quantization method to use for on-the-fly quantization."},
|
||||||
|
)
|
||||||
|
quantization_bit: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
|
||||||
|
)
|
||||||
|
quantization_type: Literal["fp4", "nf4"] = field(
|
||||||
|
default="nf4",
|
||||||
|
metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
|
||||||
|
)
|
||||||
|
double_quantization: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
|
||||||
|
)
|
||||||
|
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessorArguments:
|
||||||
|
r"""
|
||||||
|
Arguments pertaining to the image processor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_resolution: int = field(
|
||||||
|
default=512,
|
||||||
|
metadata={"help": "Keeps the height or width of image below this resolution."},
|
||||||
|
)
|
||||||
|
video_resolution: int = field(
|
||||||
|
default=128,
|
||||||
|
metadata={"help": "Keeps the height or width of video below this resolution."},
|
||||||
|
)
|
||||||
|
video_fps: float = field(
|
||||||
|
default=2.0,
|
||||||
|
metadata={"help": "The frames to sample per second for video inputs."},
|
||||||
|
)
|
||||||
|
video_maxlen: int = field(
|
||||||
|
default=64,
|
||||||
|
metadata={"help": "The maximum number of sampled frames for video inputs."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExportArguments:
|
||||||
|
r"""
|
||||||
|
Arguments pertaining to the model export.
|
||||||
|
"""
|
||||||
|
|
||||||
|
export_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the directory to save the exported model."},
|
||||||
|
)
|
||||||
|
export_size: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||||
|
)
|
||||||
|
export_device: Literal["cpu", "auto"] = field(
|
||||||
|
default="cpu",
|
||||||
|
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
||||||
|
)
|
||||||
|
export_quantization_bit: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The number of bits to quantize the exported model."},
|
||||||
|
)
|
||||||
|
export_quantization_dataset: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||||
|
)
|
||||||
|
export_quantization_nsamples: int = field(
|
||||||
|
default=128,
|
||||||
|
metadata={"help": "The number of samples used for quantization."},
|
||||||
|
)
|
||||||
|
export_quantization_maxlen: int = field(
|
||||||
|
default=1024,
|
||||||
|
metadata={"help": "The maximum length of the model inputs used for quantization."},
|
||||||
|
)
|
||||||
|
export_legacy_format: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||||
|
)
|
||||||
|
export_hub_model_id: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VllmArguments:
|
||||||
|
r"""
|
||||||
|
Arguments pertaining to the vLLM worker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vllm_maxlen: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
|
||||||
|
)
|
||||||
|
vllm_gpu_util: float = field(
|
||||||
|
default=0.9,
|
||||||
|
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
|
||||||
|
)
|
||||||
|
vllm_enforce_eager: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
||||||
|
)
|
||||||
|
vllm_max_lora_rank: int = field(
|
||||||
|
default=32,
|
||||||
|
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
|
||||||
r"""
|
r"""
|
||||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_name_or_path: str = field(
|
model_name_or_path: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
||||||
},
|
},
|
||||||
|
@ -74,26 +195,6 @@ class ModelArguments:
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether or not to use memory-efficient model loading."},
|
metadata={"help": "Whether or not to use memory-efficient model loading."},
|
||||||
)
|
)
|
||||||
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
|
|
||||||
default="bitsandbytes",
|
|
||||||
metadata={"help": "Quantization method to use for on-the-fly quantization."},
|
|
||||||
)
|
|
||||||
quantization_bit: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
|
|
||||||
)
|
|
||||||
quantization_type: Literal["fp4", "nf4"] = field(
|
|
||||||
default="nf4",
|
|
||||||
metadata={"help": "Quantization data type to use in int4 training."},
|
|
||||||
)
|
|
||||||
double_quantization: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Whether or not to use double quantization in int4 training."},
|
|
||||||
)
|
|
||||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
|
||||||
)
|
|
||||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||||
|
@ -138,34 +239,10 @@ class ModelArguments:
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to randomly initialize the model weights."},
|
metadata={"help": "Whether or not to randomly initialize the model weights."},
|
||||||
)
|
)
|
||||||
image_resolution: int = field(
|
|
||||||
default=512,
|
|
||||||
metadata={"help": "Keeps the height or width of image below this resolution."},
|
|
||||||
)
|
|
||||||
video_fps: float = field(
|
|
||||||
default=2.0,
|
|
||||||
metadata={"help": "The frames to sample per second for video training."},
|
|
||||||
)
|
|
||||||
infer_backend: Literal["huggingface", "vllm"] = field(
|
infer_backend: Literal["huggingface", "vllm"] = field(
|
||||||
default="huggingface",
|
default="huggingface",
|
||||||
metadata={"help": "Backend engine used at inference."},
|
metadata={"help": "Backend engine used at inference."},
|
||||||
)
|
)
|
||||||
vllm_maxlen: int = field(
|
|
||||||
default=2048,
|
|
||||||
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
|
|
||||||
)
|
|
||||||
vllm_gpu_util: float = field(
|
|
||||||
default=0.9,
|
|
||||||
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
|
|
||||||
)
|
|
||||||
vllm_enforce_eager: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
|
||||||
)
|
|
||||||
vllm_max_lora_rank: int = field(
|
|
||||||
default=32,
|
|
||||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
|
||||||
)
|
|
||||||
offload_folder: str = field(
|
offload_folder: str = field(
|
||||||
default="offload",
|
default="offload",
|
||||||
metadata={"help": "Path to offload model weights."},
|
metadata={"help": "Path to offload model weights."},
|
||||||
|
@ -186,42 +263,6 @@ class ModelArguments:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||||
)
|
)
|
||||||
export_dir: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Path to the directory to save the exported model."},
|
|
||||||
)
|
|
||||||
export_size: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
|
||||||
)
|
|
||||||
export_device: Literal["cpu", "auto"] = field(
|
|
||||||
default="cpu",
|
|
||||||
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
|
||||||
)
|
|
||||||
export_quantization_bit: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The number of bits to quantize the exported model."},
|
|
||||||
)
|
|
||||||
export_quantization_dataset: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
|
||||||
)
|
|
||||||
export_quantization_nsamples: int = field(
|
|
||||||
default=128,
|
|
||||||
metadata={"help": "The number of samples used for quantization."},
|
|
||||||
)
|
|
||||||
export_quantization_maxlen: int = field(
|
|
||||||
default=1024,
|
|
||||||
metadata={"help": "The maximum length of the model inputs used for quantization."},
|
|
||||||
)
|
|
||||||
export_legacy_format: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
|
||||||
)
|
|
||||||
export_hub_model_id: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
|
||||||
)
|
|
||||||
print_param_status: bool = field(
|
print_param_status: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
||||||
|
@ -248,6 +289,9 @@ class ModelArguments:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
if self.model_name_or_path is None:
|
||||||
|
raise ValueError("Please provide `model_name_or_path`.")
|
||||||
|
|
||||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||||
setattr(processor, "tokenizer", tokenizer)
|
setattr(processor, "tokenizer", tokenizer)
|
||||||
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||||
setattr(processor, "image_resolution", model_args.image_resolution)
|
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||||
|
setattr(processor, "video_resolution", model_args.video_resolution)
|
||||||
setattr(processor, "video_fps", model_args.video_fps)
|
setattr(processor, "video_fps", model_args.video_fps)
|
||||||
|
setattr(processor, "video_maxlen", model_args.video_maxlen)
|
||||||
if getattr(config, "model_type", None) == "qwen2_vl":
|
if getattr(config, "model_type", None) == "qwen2_vl":
|
||||||
setattr(processor, "video_factor", 2)
|
setattr(processor, "video_factor", 2)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from llamafactory.chat import ChatModel
|
||||||
|
|
||||||
|
|
||||||
|
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||||
|
|
||||||
|
INFER_ARGS = {
|
||||||
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"template": "llama3",
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
"do_sample": False,
|
||||||
|
"max_new_tokens": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
MESSAGES = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
]
|
||||||
|
|
||||||
|
EXPECTED_RESPONSE = "_rho"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat():
|
||||||
|
chat_model = ChatModel(INFER_ARGS)
|
||||||
|
assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_chat():
|
||||||
|
chat_model = ChatModel(INFER_ARGS)
|
||||||
|
response = ""
|
||||||
|
for token in chat_model.stream_chat(MESSAGES):
|
||||||
|
response += token
|
||||||
|
|
||||||
|
assert response == EXPECTED_RESPONSE
|
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llamafactory.train.tuner import export_model, run_exp
|
||||||
|
|
||||||
|
|
||||||
|
DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
|
||||||
|
|
||||||
|
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||||
|
|
||||||
|
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
|
||||||
|
|
||||||
|
TRAIN_ARGS = {
|
||||||
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
"do_train": True,
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"dataset_dir": "REMOTE:" + DEMO_DATA,
|
||||||
|
"template": "llama3",
|
||||||
|
"cutoff_len": 1,
|
||||||
|
"overwrite_cache": True,
|
||||||
|
"overwrite_output_dir": True,
|
||||||
|
"per_device_train_batch_size": 1,
|
||||||
|
"max_steps": 1,
|
||||||
|
"fp16": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
INFER_ARGS = {
|
||||||
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
"adapter_name_or_path": TINY_LLAMA_ADAPTER,
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"template": "llama3",
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
"export_dir": "llama3_export",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"stage,dataset",
|
||||||
|
[
|
||||||
|
("pt", "c4_demo"),
|
||||||
|
("sft", "alpaca_en_demo"),
|
||||||
|
("rm", "dpo_en_demo"),
|
||||||
|
("dpo", "dpo_en_demo"),
|
||||||
|
("kto", "kto_en_demo"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_train(stage: str, dataset: str):
|
||||||
|
output_dir = "train_{}".format(stage)
|
||||||
|
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
||||||
|
assert os.path.exists(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
export_model(INFER_ARGS)
|
||||||
|
assert os.path.exists("llama3_export")
|
Loading…
Reference in New Issue