This commit is contained in:
BUAADreamer 2024-09-29 20:55:23 +08:00
commit 87c8a7e759
1 changed file with 12 additions and 25 deletions

View File

@ -158,6 +158,7 @@ class BasePlugin:
It holds num_patches == torch.prod(image_grid_thw) It holds num_patches == torch.prod(image_grid_thw)
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
input_dict = {"images": None} # default key input_dict = {"images": None} # default key
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
@ -174,10 +175,17 @@ class BasePlugin:
video_maxlen=getattr(processor, "video_maxlen", 64), video_maxlen=getattr(processor, "video_maxlen", 64),
) )
input_dict["videos"] = videos input_dict["videos"] = videos
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
return image_processor(**input_dict, return_tensors="pt") mm_inputs = {}
else: if image_processor != video_processor:
return {} if input_dict.get("images") is not None:
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
if input_dict.get("videos") is not None:
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
return mm_inputs
def process_messages( def process_messages(
self, self,
@ -365,27 +373,6 @@ class LlavaNextVideoPlugin(BasePlugin):
return messages return messages
@override
def _get_mm_inputs(
    self,
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
    """Build the multimodal inputs for a batch of images and videos.

    Images are delegated to the parent implementation (called with an empty
    video list). Videos are regularized and passed to the processor's
    separate ``video_processor``; its output is merged into the image result.

    Args:
        images: Input images, forwarded unchanged to the parent implementation.
        videos: Input videos. When empty, only the parent's result is returned.
        processor: Object carrying ``video_processor`` and the
            ``image_resolution`` / ``video_fps`` / ``video_maxlen`` settings.

    Returns:
        The parent's result, updated in place with the video processor's
        output when videos are present.

    Raises:
        AttributeError: If videos are given and ``processor`` lacks
            ``video_processor``, ``image_resolution`` or ``video_fps``.
    """
    self._validate_input(images, videos)
    # Images go through the parent path; videos are handled separately below.
    res = super()._get_mm_inputs(images, [], processor)
    if len(videos) == 0:
        return res

    # Resolve the video processor only when there are videos, so an image-only
    # call does not fail on a processor that has no ``video_processor``.
    video_processor = getattr(processor, "video_processor")
    videos = self._regularize_videos(
        videos,
        # NOTE(review): no fallback is visible for these two settings, so a
        # missing attribute still raises AttributeError - confirm the intended
        # defaults against the base plugin before adding any.
        image_resolution=getattr(processor, "image_resolution"),
        video_fps=getattr(processor, "video_fps"),
        # Same fallback the base plugin uses when it reads ``video_maxlen``.
        video_maxlen=getattr(processor, "video_maxlen", 64),
    )
    res.update(video_processor(videos, return_tensors="pt"))
    return res
@override @override
def get_mm_inputs( def get_mm_inputs(
self, self,