update scripts

This commit is contained in:
hiyouga 2024-09-08 14:17:41 +08:00
parent b6681d7198
commit f2aa02c070
6 changed files with 24 additions and 11 deletions

View File

@ -27,7 +27,7 @@ from llamafactory.chat import ChatModel
def calculate_flops(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 256,
seq_length: int = 512,
flash_attn: str = "auto",
):
r"""

View File

@ -39,7 +39,7 @@ def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
@ -48,7 +48,8 @@ def calculate_lr(
):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(

View File

@ -102,8 +102,9 @@ def compute_device_flops() -> float:
def calculate_mfu(
model_name_or_path: str,
batch_size: int,
seq_length: int,
batch_size: int = 1,
seq_length: int = 1024,
num_steps: int = 100,
finetuning_type: str = "lora",
flash_attn: str = "auto",
deepspeed_stage: int = 0,
@ -129,7 +130,7 @@ def calculate_mfu(
"output_dir": os.path.join("saves", "test_mfu"),
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": 100,
"max_steps": num_steps,
"bf16": True,
}
if deepspeed_stage in [2, 3]:

View File

@ -60,7 +60,7 @@ def calculate_ppl(
save_name: str,
batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
@ -69,7 +69,7 @@ def calculate_ppl(
):
r"""
Calculates the ppl on the dataset of the pre-trained models.
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(

View File

@ -25,14 +25,14 @@ from llamafactory.model import load_tokenizer
def length_cdf(
model_name_or_path: str,
dataset: str = "alpaca_en",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
interval: int = 1000,
):
r"""
Calculates the distribution of the input lengths in the dataset.
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(

View File

@ -21,13 +21,18 @@ from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
@ -153,6 +158,12 @@ class VllmEngine(BaseEngine):
)
if image is not None: # add image features
if not isinstance(image, (str, ImageObject)):
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
if isinstance(image, str):
image = Image.open(image).convert("RGB")
multi_modal_data = {"image": image}
else:
multi_modal_data = None