fix some
parent 1d09d592d3
commit d5c69400cd

setup.py
@@ -61,7 +61,6 @@ extra_require = {
    "qwen": ["transformers_stream_generator"],
    "modelscope": ["modelscope"],
    "dev": ["ruff", "pytest"],
    "av": ["av>=13.0.0"],
}

@@ -275,15 +275,10 @@ class LlavaNextPlugin(BasePlugin):
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
            for message in messages:
                content = message["content"]
                while self.image_token in content:
                    num_image_tokens += 1
                    content = content.replace(self.image_token, "{{image}}", 1)
        else:
            mm_inputs = self._get_mm_inputs(images, videos, processor)
            if "image_sizes" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"])
            if "pixel_values" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
        for message in messages:
            content = message["content"]
@@ -316,6 +311,7 @@ class LlavaNextPlugin(BasePlugin):
        res = self._get_mm_inputs(images, videos, processor)
        return res


class LlavaNextVideoPlugin(BasePlugin):
    @override
    def process_messages(
@@ -329,16 +325,6 @@ class LlavaNextVideoPlugin(BasePlugin):
        num_image_tokens = 0
        num_video_tokens = 0
        messages = deepcopy(messages)
        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
            for message in messages:
                content = message["content"]
                while self.image_token in content:
                    num_image_tokens += 1
                    content = content.replace(self.image_token, "{{image}}", 1)
                while self.video_token in content:
                    num_video_tokens += 1
                    content = content.replace(self.video_token, "{{video}}", 1)
        else:
            mm_inputs = self._get_mm_inputs(images, videos, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"])
@@ -379,6 +365,27 @@ class LlavaNextVideoPlugin(BasePlugin):

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        video_processor = getattr(processor, "video_processor")
        res = super()._get_mm_inputs(images, [], processor)
        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_resolution=getattr(processor, "image_resolution", 168),
                video_fps=getattr(processor, "video_fps", 1.0),
                video_maxlen=getattr(processor, "video_maxlen", 16),
            )
            video_res = video_processor(videos, return_tensors="pt")
            res.update(video_res)
        return res

    @override
    def get_mm_inputs(
        self,
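Side note on the new _get_mm_inputs override above: it forwards image_resolution, video_fps, and video_maxlen from the processor into _regularize_videos, which caps how many frames a clip contributes. The base plugin's sampling logic is not part of this diff, so the following is only a rough sketch of a frame sampler bounded by those two knobs; the function name and the even-spacing rule are assumptions for illustration, not the plugin's actual code.

# Hypothetical sketch of fps/maxlen-bounded frame sampling (not the plugin's real implementation).
import numpy as np

def sample_frame_indices(total_frames: int, duration_sec: float,
                         video_fps: float = 1.0, video_maxlen: int = 16) -> np.ndarray:
    # aim for `video_fps` frames per second, but never more than `video_maxlen` frames
    num_frames = max(1, min(int(duration_sec * video_fps), video_maxlen))
    # spread the selected frames evenly across the clip
    return np.linspace(0, total_frames - 1, num_frames).astype(np.int32)

print(sample_frame_indices(total_frames=300, duration_sec=10.0))  # 10 evenly spaced indices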
@@ -390,26 +397,7 @@ class LlavaNextVideoPlugin(BasePlugin):
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        video_processor = getattr(processor, "video_processor")
        res = self._get_mm_inputs(images, [], processor)
        if len(videos) != 0:
            videos = self._regularize_videos(videos)
            video_res = video_processor(videos, return_tensors="pt")
            res.update(video_res)
        return res

    @override
    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
        r"""
        Regularizes videos to avoid error. Including reading, resizing and converting.
        """
        videos = super()._regularize_videos(
            videos,
            image_resolution=168,
            video_fps=1.0,
            video_maxlen=16,
        )
        return videos
        return self._get_mm_inputs(images, videos, processor)


class PaliGemmaPlugin(BasePlugin):
@@ -579,7 +567,22 @@ class VideoLlavaPlugin(BasePlugin):
        num_image_tokens = 0
        num_video_tokens = 0
        messages = deepcopy(messages)
        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
            mm_inputs = self._get_mm_inputs(images, videos, processor)
            num_frames = 0
            exist_images = "pixel_values_images" in mm_inputs
            exist_videos = "pixel_values_videos" in mm_inputs
            if exist_videos or exist_images:
                if exist_images:
                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                    num_frames = 1
                if exist_videos:
                    one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                    height, width = get_image_size(one_video[0])
                    num_frames = one_video.shape[0]  # frame dim is always after batch dim
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
                video_seqlen = image_seqlen * num_frames
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1
        for message in messages:
            content = message["content"]
            while self.image_token in content:
@@ -588,39 +591,15 @@ class VideoLlavaPlugin(BasePlugin):
            while self.video_token in content:
                num_video_tokens += 1
                content = content.replace(self.video_token, "{{video}}", 1)
        else:
            mm_inputs = self._get_mm_inputs(images, videos, processor)
            if "pixel_values_images" in mm_inputs.keys():
                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                num_frames = 1

            if "pixel_values_videos" in mm_inputs.keys():
                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(one_video[0])
                num_frames = one_video.shape[0]  # frame dim is always after batch dim

            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
            video_seqlen = num_image_tokens * num_frames
            if processor.vision_feature_select_strategy == "default":
                image_seqlen -= 1

            for message in messages:
                content = message["content"]
                while self.image_token in content:
                    num_image_tokens += 1
                    content = content.replace(self.image_token, "{{image}}", 1)
                while self.video_token in content:
                    num_image_tokens += 1
                    content = content.replace(self.video_token, "{{video}}", 1)

                message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
                content = content.replace("{{image}}", self.image_token * image_seqlen)
                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
            raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))

        if len(videos) != num_video_tokens:
            raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
            raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))

        return messages
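The loop above counts placeholders and expands them in two passes: each raw image/video token is first swapped for {{image}} / {{video}} while being counted, and only afterwards is every placeholder expanded to image_seqlen or video_seqlen copies of the token, so the expanded run can never be matched again by the while loop. A self-contained sketch of that pattern, with "<image>" and a sequence length of 4 as assumed example values:

# Stand-alone illustration of the placeholder expansion used in process_messages.
# "<image>" and image_seqlen=4 are example values, not the model's real settings.
def expand_image_tokens(content: str, image_token: str = "<image>", image_seqlen: int = 4) -> str:
    num_image_tokens = 0
    while image_token in content:
        num_image_tokens += 1
        content = content.replace(image_token, "{{image}}", 1)  # count and neutralize
    content = content.replace("{{image}}", image_token * image_seqlen)  # expand once, safely
    return content

print(expand_image_tokens("describe <image> please"))
# describe <image><image><image><image> please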
@@ -637,19 +616,6 @@ class VideoLlavaPlugin(BasePlugin):
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)

    @override
    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
        r"""
        Regularizes videos to avoid error. Including reading, resizing and converting.
        """
        videos = super()._regularize_videos(
            videos,
            image_resolution=224,
            video_fps=1.0,
            video_maxlen=8,
        )
        return videos


PLUGINS = {
    "base": BasePlugin,
@@ -25,7 +25,7 @@ from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen
from .model_utils.visual import get_image_seqlen, get_patch_size, get_vision_feature_select_strategy
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
@@ -100,9 +100,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
        setattr(processor, "tokenizer", tokenizer)
        setattr(processor, "image_seqlen", get_image_seqlen(config))
        setattr(processor, "image_resolution", model_args.image_resolution)
        setattr(processor, "patch_size", get_patch_size(config))
        setattr(processor, "video_resolution", model_args.video_resolution)
        setattr(processor, "video_fps", model_args.video_fps)
        setattr(processor, "video_maxlen", model_args.video_maxlen)
        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
    except Exception:
        processor = None
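The two processor attributes added here, patch_size and vision_feature_select_strategy, are exactly what the plugin hunks above consume when they compute image_seqlen = (height // patch_size) * (width // patch_size) + 1 and then subtract one for the "default" strategy. A worked example with assumed, typical LLaVA-style values (336x336 input, ViT patch size 14):

# Worked example of the seqlen formula used in the plugin hunks above.
# 336x336 and patch_size=14 are assumed, typical LLaVA-style values.
height, width, patch_size = 336, 336, 14
vision_feature_select_strategy = "default"
image_seqlen = (height // patch_size) * (width // patch_size) + 1  # 24 * 24 patches + 1 CLS = 577
if vision_feature_select_strategy == "default":  # "default" drops the CLS feature
    image_seqlen -= 1
print(image_seqlen)  # 576 image tokens per 336x336 image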
@@ -161,6 +161,22 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
    return image_seqlen


def get_patch_size(config: "PretrainedConfig") -> int:
    r"""
    Computes the patch size of the vit.
    """
    patch_size = getattr(config.vision_config, "patch_size", 14)
    return patch_size


def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
    r"""
    Get the vision_feature_select_strategy.
    """
    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
    return vision_feature_select_strategy


def patch_target_modules(
    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
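For a quick sanity check of the two new helpers, the same getattr lookups can be run against a default LlavaConfig from transformers; the config object below is only an illustrative stand-in for the checkpoint config that load_tokenizer actually passes in.

# Mirrors the lookups inside get_patch_size / get_vision_feature_select_strategy,
# using a default LlavaConfig as an assumed stand-in for a real checkpoint config.
from transformers import LlavaConfig

config = LlavaConfig()  # library-default LLaVA config with a CLIP-style vision tower
print(getattr(config.vision_config, "patch_size", 14))               # patch size of the vision tower
print(getattr(config, "vision_feature_select_strategy", "default"))  # "default" or "full"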