BUAADreamer 2024-09-29 17:55:40 +08:00
parent 1d09d592d3
commit d5c69400cd
4 changed files with 96 additions and 113 deletions


@@ -61,7 +61,6 @@ extra_require = {
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
     "dev": ["ruff", "pytest"],
-    "av": ["av>=13.0.0"],
 }


@@ -275,28 +275,23 @@ class LlavaNextPlugin(BasePlugin):
         self._validate_input(images, videos)
         num_image_tokens = 0
         messages = deepcopy(messages)
-        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
-            for message in messages:
-                content = message["content"]
-                while self.image_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-        else:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
-            image_sizes = iter(mm_inputs["image_sizes"])
-            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
-            for message in messages:
-                content = message["content"]
-                while self.image_token in content:
-                    image_size = next(image_sizes)
-                    orig_height, orig_width = image_size
-                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if processor.vision_feature_select_strategy == "default":
-                        image_seqlen -= 1
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
-                message["content"] = content.replace("{{image}}", self.image_token)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "image_sizes" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+
+        if "pixel_values" in mm_inputs:
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+
+        for message in messages:
+            content = message["content"]
+            while self.image_token in content:
+                image_size = next(image_sizes)
+                orig_height, orig_width = image_size
+                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                if processor.vision_feature_select_strategy == "default":
+                    image_seqlen -= 1
+                num_image_tokens += 1
+                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+            message["content"] = content.replace("{{image}}", self.image_token)
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
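
The rewritten LlavaNextPlugin.process_messages always expands each image placeholder to the exact number of visual features; the old fallback branch, which emitted a single token when patch_size or vision_feature_select_strategy was unset, is dropped because the loader hunks below now always populate those processor attributes. The expansion uses a two-pass trick: each image token is first replaced by image_seqlen copies of a temporary {{image}} marker so the while-loop never re-matches the freshly inserted tokens, then the markers are mapped back. A minimal, self-contained sketch of that trick (the token string and sequence length are made-up values, not taken from the diff):

    IMAGE_TOKEN = "<image>"
    image_seqlen = 3  # pretend the processor computed 3 features for this image

    content = "Describe this: " + IMAGE_TOKEN
    num_image_tokens = 0
    while IMAGE_TOKEN in content:
        num_image_tokens += 1
        # expand into a temporary marker so the loop cannot re-match the inserted copies
        content = content.replace(IMAGE_TOKEN, "{{image}}" * image_seqlen, 1)

    content = content.replace("{{image}}", IMAGE_TOKEN)
    print(content)           # Describe this: <image><image><image>
    print(num_image_tokens)  # 1
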
@@ -316,6 +311,7 @@ class LlavaNextPlugin(BasePlugin):
         res = self._get_mm_inputs(images, videos, processor)
         return res
 
+
 class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
@@ -329,47 +325,37 @@ class LlavaNextVideoPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
-            for message in messages:
-                content = message["content"]
-                while self.image_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-                while self.video_token in content:
-                    num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
-        else:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
-            if "pixel_values" in mm_inputs:
-                image_sizes = iter(mm_inputs["image_sizes"])
-                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
-                for message in messages:
-                    content = message["content"]
-                    while self.image_token in content:
-                        image_size = next(image_sizes)
-                        orig_height, orig_width = image_size
-                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                        if processor.vision_feature_select_strategy == "default":
-                            image_seqlen -= 1
-                        num_image_tokens += 1
-                        content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
-                    message["content"] = content.replace("{{image}}", self.image_token)
-            if "pixel_values_videos" in mm_inputs:
-                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(one_video[0])
-                num_frames = one_video.shape[0]  # frame dim is always after batch dim
-                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
-                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
-                for message in messages:
-                    content = message["content"]
-                    while self.video_token in content:
-                        num_video_tokens += 1
-                        content = content.replace(self.video_token, "{{video}}", 1)
-                    message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    image_size = next(image_sizes)
+                    orig_height, orig_width = image_size
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if processor.vision_feature_select_strategy == "default":
+                        image_seqlen -= 1
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                message["content"] = content.replace("{{image}}", self.image_token)
+
+        if "pixel_values_videos" in mm_inputs:
+            one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+            height, width = get_image_size(one_video[0])
+            num_frames = one_video.shape[0]  # frame dim is always after batch dim
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            for message in messages:
+                content = message["content"]
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
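
In the video branch above, one token count is derived for the whole clip: each frame yields (height // patch_size) * (width // patch_size) patches, the // 4 keeps a quarter of them (the diff's comment attributes this to the model's average-pooling layer), and the per-frame count is multiplied by the number of frames. A worked example with hypothetical shapes:

    height, width, patch_size, num_frames = 336, 336, 14, 16

    image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576 patches per frame
    video_seqlen = image_seqlen // 4 * num_frames                  # 144 tokens per frame -> 2304 for the clip
    print(image_seqlen, video_seqlen)  # 576 2304
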
@@ -380,36 +366,38 @@ class LlavaNextVideoPlugin(BasePlugin):
         return messages
 
     @override
-    def get_mm_inputs(
+    def _get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         video_processor = getattr(processor, "video_processor")
-        res = self._get_mm_inputs(images, [], processor)
+        res = super()._get_mm_inputs(images, [], processor)
         if len(videos) != 0:
-            videos = self._regularize_videos(videos)
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "image_resolution", 168),
+                video_fps=getattr(processor, "video_fps", 1.0),
+                video_maxlen=getattr(processor, "video_maxlen", 16),
+            )
             video_res = video_processor(videos, return_tensors="pt")
             res.update(video_res)
         return res
 
     @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        r"""
-        Regularizes videos to avoid error. Including reading, resizing and converting.
-        """
-        videos = super()._regularize_videos(
-            videos,
-            image_resolution=168,
-            video_fps=1.0,
-            video_maxlen=16,
-        )
-        return videos
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)
 
 
 class PaliGemmaPlugin(BasePlugin):
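
This hunk splits the old get_mm_inputs into a public get_mm_inputs, which keeps the imglens/vidlens/seqlens signature of the base API, and a private _get_mm_inputs that does the work. The hard-coded video-sampling constants (168 / 1.0 / 16) from the deleted _regularize_videos override become getattr fallbacks, so values injected onto the processor by load_tokenizer win when present. A sketch of that lookup pattern (the stand-in processor object is illustrative, not the real ProcessorMixin):

    from types import SimpleNamespace

    processor = SimpleNamespace()
    processor.video_fps = 2.0  # would be set from ModelArguments in load_tokenizer

    video_fps = getattr(processor, "video_fps", 1.0)       # 2.0: the configured value wins
    video_maxlen = getattr(processor, "video_maxlen", 16)  # 16: attribute missing, fallback used
    print(video_fps, video_maxlen)
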
@@ -579,7 +567,22 @@ class VideoLlavaPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        num_frames = 0
+        exist_images = "pixel_values_images" in mm_inputs
+        exist_videos = "pixel_values_videos" in mm_inputs
+        if exist_videos or exist_images:
+            if exist_images:
+                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                num_frames = 1
+            if exist_videos:
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+            video_seqlen = image_seqlen * num_frames
+            if processor.vision_feature_select_strategy == "default":
+                image_seqlen -= 1
             for message in messages:
                 content = message["content"]
                 while self.image_token in content:
@@ -588,39 +591,15 @@ class VideoLlavaPlugin(BasePlugin):
                 while self.video_token in content:
                     num_video_tokens += 1
                     content = content.replace(self.video_token, "{{video}}", 1)
-        else:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
-            if "pixel_values_images" in mm_inputs.keys():
-                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
-                num_frames = 1
-            if "pixel_values_videos" in mm_inputs.keys():
-                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(one_video[0])
-                num_frames = one_video.shape[0]  # frame dim is always after batch dim
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
-            video_seqlen = num_image_tokens * num_frames
-            if processor.vision_feature_select_strategy == "default":
-                image_seqlen -= 1
-            for message in messages:
-                content = message["content"]
-                while self.image_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-                while self.video_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
-                message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+
+                content = content.replace("{{image}}", self.image_token * image_seqlen)
                 message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
 
         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
 
         if len(videos) != num_video_tokens:
-            raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
 
         return messages
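
Besides unifying the two code paths, these two hunks fix visible defects in the deleted branch: the old code computed video_seqlen from num_image_tokens before any tokens had been counted, and incremented num_image_tokens inside the video-token loop. VideoLlava counts (height // patch_size) * (width // patch_size) + 1 features per frame, the + 1 being the extra CLS feature; note that video_seqlen is computed before the strategy check, so videos keep the per-frame CLS feature while single images drop it under the "default" strategy, exactly as in the new code above. Hypothetical numbers:

    height, width, patch_size, num_frames = 224, 224, 14, 8
    vision_feature_select_strategy = "default"

    image_seqlen = (height // patch_size) * (width // patch_size) + 1  # 16 * 16 + 1 = 257
    video_seqlen = image_seqlen * num_frames                           # 257 * 8 = 2056
    if vision_feature_select_strategy == "default":
        image_seqlen -= 1                                              # 256 tokens for a single image
    print(image_seqlen, video_seqlen)  # 256 2056
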
@@ -637,19 +616,6 @@ class VideoLlavaPlugin(BasePlugin):
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)
 
-    @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        r"""
-        Regularizes videos to avoid error. Including reading, resizing and converting.
-        """
-        videos = super()._regularize_videos(
-            videos,
-            image_resolution=224,
-            video_fps=1.0,
-            video_maxlen=8,
-        )
-        return videos
-
 
 PLUGINS = {
     "base": BasePlugin,


@@ -25,7 +25,7 @@ from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
 from .model_utils.valuehead import load_valuehead_params
-from .model_utils.visual import get_image_seqlen
+from .model_utils.visual import get_image_seqlen, get_patch_size, get_vision_feature_select_strategy
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
@@ -100,9 +100,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         setattr(processor, "tokenizer", tokenizer)
         setattr(processor, "image_seqlen", get_image_seqlen(config))
         setattr(processor, "image_resolution", model_args.image_resolution)
+        setattr(processor, "patch_size", get_patch_size(config))
         setattr(processor, "video_resolution", model_args.video_resolution)
         setattr(processor, "video_fps", model_args.video_fps)
         setattr(processor, "video_maxlen", model_args.video_maxlen)
+        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
     except Exception:
         processor = None
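
load_tokenizer now decorates the processor with patch_size and vision_feature_select_strategy pulled from the model config, so the mm_plugin code above can compute token counts without having the config threaded through. A round-trip sketch with stand-in objects (illustrative, not the real PretrainedConfig or ProcessorMixin):

    from types import SimpleNamespace

    config = SimpleNamespace(
        vision_config=SimpleNamespace(patch_size=14),
        vision_feature_select_strategy="default",
    )
    processor = SimpleNamespace()

    # loader side (mirrors load_tokenizer)
    setattr(processor, "patch_size", getattr(config.vision_config, "patch_size", 14))
    setattr(processor, "vision_feature_select_strategy",
            getattr(config, "vision_feature_select_strategy", "default"))

    # plugin side (mirrors mm_plugin reading the attributes back)
    print(processor.patch_size, processor.vision_feature_select_strategy)  # 14 default
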


@@ -161,6 +161,22 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
     return image_seqlen
 
 
+def get_patch_size(config: "PretrainedConfig") -> int:
+    r"""
+    Computes the patch size of the ViT.
+    """
+    patch_size = getattr(config.vision_config, "patch_size", 14)
+    return patch_size
+
+
+def get_vision_feature_select_strategy(config: "PretrainedConfig") -> str:
+    r"""
+    Gets the vision_feature_select_strategy.
+    """
+    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    return vision_feature_select_strategy
+
+
 def patch_target_modules(
     config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
 ) -> Union[str, List[str]]:
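
Both helpers fall back when a model config omits the field: patch_size defaults to 14 (the common ViT patch size) and the strategy to "default", which is precisely the value that triggers the image_seqlen -= 1 branch in the plugins, so the old None-guards become unnecessary. A quick check with a bare config stand-in (hypothetical object, not a real PretrainedConfig):

    from types import SimpleNamespace

    bare_config = SimpleNamespace(vision_config=SimpleNamespace())  # neither field present
    print(getattr(bare_config.vision_config, "patch_size", 14))               # 14
    print(getattr(bare_config, "vision_feature_select_strategy", "default"))  # default
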