fix tests
This commit is contained in:
parent
d5c69400cd
commit
97d1536ee1
|
@ -139,7 +139,14 @@ def test_llava_next_plugin():
|
||||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
|
||||||
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
|
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
|
||||||
check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
|
check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||||
check_inputs["expected_mm_messages"] = MM_MESSAGES
|
image_seqlen = 1176
|
||||||
|
check_inputs["expected_mm_messages"] = [
|
||||||
|
{
|
||||||
|
key: value.replace("<image>", "<image>" * image_seqlen)
|
||||||
|
for key, value in message.items()
|
||||||
|
}
|
||||||
|
for message in MM_MESSAGES
|
||||||
|
]
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
@ -148,7 +155,14 @@ def test_llava_next_video_plugin():
|
||||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
|
||||||
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
|
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
|
||||||
check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
|
check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||||
check_inputs["expected_mm_messages"] = MM_MESSAGES
|
image_seqlen = 1176
|
||||||
|
check_inputs["expected_mm_messages"] = [
|
||||||
|
{
|
||||||
|
key: value.replace("<image>", "<image>" * image_seqlen)
|
||||||
|
for key, value in message.items()
|
||||||
|
}
|
||||||
|
for message in MM_MESSAGES
|
||||||
|
]
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
@ -190,6 +204,13 @@ def test_video_llava_plugin():
|
||||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
|
||||||
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
|
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
|
||||||
check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
|
check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||||
check_inputs["expected_mm_messages"] = MM_MESSAGES
|
image_seqlen = 256
|
||||||
|
check_inputs["expected_mm_messages"] = [
|
||||||
|
{
|
||||||
|
key: value.replace("<image>", "<image>" * image_seqlen)
|
||||||
|
for key, value in message.items()
|
||||||
|
}
|
||||||
|
for message in MM_MESSAGES
|
||||||
|
]
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
Loading…
Reference in New Issue