merge data part to the text stream

BUAADreamer 2024-04-25 19:58:47 +08:00
parent 2ee3046eb0
commit 42c90c8183
6 changed files with 18 additions and 38 deletions

data/dataset_info.json

@@ -414,9 +414,6 @@
     },
     "folder": "python"
   },
-  "llava_instruct": {
-    "hf_hub_url": "HuggingFaceH4/llava-instruct-mix-vsft"
-  },
   "mllm_instruct_example": {
     "file_name": "llava_instruct_example.json",
     "formatting": "llava",

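The registry now points only at the local llava_instruct_example.json file. A minimal sketch of the record layout the updated script appears to expect, inferred from the new chat template and the Image.open call further down rather than confirmed by the diff; the field contents and the image path are hypothetical:

    # Hypothetical record in llava_instruct_example.json: the image placeholder
    # is merged into the plain-text message content (the "text stream"), and
    # image file paths live in a parallel list.
    example = {
        "messages": [
            {"role": "user", "content": "<image>What is in this photo?"},
            {"role": "assistant", "content": "A dog playing on the grass."},
        ],
        "images": ["images/0.jpg"],  # hypothetical local path
    }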
data/mllm_example_dataset/README.md (deleted)

@@ -1,25 +0,0 @@
----
-dataset_info:
-  features:
-  - name: messages
-    list:
-    - name: content
-      list:
-      - name: index
-        dtype: int64
-      - name: text
-        dtype: string
-      - name: type
-        dtype: string
-    - name: role
-      dtype: string
-  - name: images
-    sequence: image
-configs:
-- config_name: default
-  data_files:
-  - split: train
-    path: data/train-*
-  - split: test
-    path: data/test-*
----

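The deleted card described a hub-style dataset directory with parquet splits under data/train-* and data/test-*, which load_dataset resolved by path. After this commit the script reads the raw JSON file with the generic json builder instead; a sketch of the difference, using the paths from the usage string:

    from datasets import load_dataset

    # Before: a dataset directory carrying the card above plus parquet shards.
    # ds = load_dataset("data/mllm_example_dataset")

    # After: the raw JSON file; the json builder puts everything in "train".
    ds = load_dataset("json", data_files="data/llava_instruct_example.json")
    print(ds["train"][0].keys())  # expected fields: messages, images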
scripts/test_mllm.py

@@ -6,22 +6,23 @@ from datasets import load_dataset
 from peft import PeftModel
 from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor
 import shutil
+from PIL import Image
 
 """usage
 python3 scripts/test_mllm.py \
 --base_model_path llava-hf/llava-1.5-7b-hf \
 --lora_model_path saves/llava-1.5-7b/lora/sft \
 --model_path saves/llava-1.5-7b/lora/merged \
---dataset_name data/mllm_example_dataset \
+--dataset_name data/llava_instruct_example.json \
 --do_merge 1
 """
 
 
 def get_processor(model_path):
-    CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+    processor = AutoProcessor.from_pretrained(model_path)
+    CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }} ASSISTANT: {% else %}{{ message['content'] }}{% endif %} {% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
     tokenizer.chat_template = CHAT_TEMPLATE
-    processor = AutoProcessor.from_pretrained(model_path)
     processor.tokenizer = tokenizer
     return processor
 
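As a sanity check, here is roughly what the new template renders for a short conversation; a sketch with made-up messages, and the whitespace in the expected output is approximate:

    from transformers import AutoTokenizer

    CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }} ASSISTANT: {% else %}{{ message['content'] }}{% endif %} {% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""

    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)
    tokenizer.chat_template = CHAT_TEMPLATE
    messages = [
        {"role": "user", "content": "<image>Describe the picture."},
        {"role": "assistant", "content": "A cat sleeping on a sofa."},
    ]
    print(tokenizer.apply_chat_template(messages, tokenize=False))
    # Expected output, roughly:
    # A chat between a curious user and an artificial intelligence assistant.
    # The assistant gives helpful, detailed, and polite answers to the user's
    # questions. USER: <image>Describe the picture. ASSISTANT: A cat sleeping
    # on a sofa.</s>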
@@ -69,7 +70,7 @@ def main(
         device_map="cuda",
     )
     processor = get_processor(model_path)
-    raw_datasets = load_dataset(dataset_name)
+    raw_datasets = load_dataset("json", data_files=dataset_name)
     train_dataset = raw_datasets["train"]
     examples = train_dataset.select(range(3))
     texts = []
@@ -80,11 +81,18 @@
             messages, tokenize=False, add_generation_prompt=False
         )
         texts.append(text)
-        images.append(example["images"][0])
-    batch = processor(texts, images, return_tensors="pt", padding=True).to("cuda")
+        images.append(Image.open(example["images"][0]))
+    batch = processor(text=texts, images=images, return_tensors="pt", padding=True).to(
+        "cuda"
+    )
     output = model.generate(**batch, max_new_tokens=100)
-    res = processor.batch_decode(output, skip_special_tokens=True)
-    print(res)
+    res_list = processor.batch_decode(output, skip_special_tokens=True)
+    for i, prompt in enumerate(texts):
+        res = res_list[i]
+        print(f"#{i}")
+        print(f"prompt:{prompt}")
+        print(f"response:{res[len(prompt):].strip()}")
+        print()
 
 
 if __name__ == "__main__":
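The new print loop relies on generate keeping the prompt tokens at the front of each output sequence, so slicing len(prompt) characters off the decoded string leaves only the model's continuation. A minimal illustration with made-up strings; note that skip_special_tokens can make the decoded prefix differ slightly from the raw prompt, so the slice is approximate:

    prompt = "USER: <image>Describe the picture. ASSISTANT: "
    decoded = "USER: <image>Describe the picture. ASSISTANT: A cat sleeping on a sofa."
    response = decoded[len(prompt):].strip()
    print(response)  # -> A cat sleeping on a sofa.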

src/llmtuner/data/template.py

@@ -1012,8 +1012,8 @@ _register_template(
 _register_template(
     name="llava",
-    format_user=StringFormatter(slots=["USER: {{content}} "]),
-    format_assistant=StringFormatter(slots=["ASSISTANT: {{content}}"]),
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT: "]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
     default_system=(
         "A chat between a curious user and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed, and polite answers to the user's questions."