Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
This commit is contained in:
commit
87c8a7e759
|
@ -158,6 +158,7 @@ class BasePlugin:
|
||||||
It holds num_patches == torch.prod(image_grid_thw)
|
It holds num_patches == torch.prod(image_grid_thw)
|
||||||
"""
|
"""
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
|
||||||
input_dict = {"images": None} # default key
|
input_dict = {"images": None} # default key
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
|
@ -174,10 +175,17 @@ class BasePlugin:
|
||||||
video_maxlen=getattr(processor, "video_maxlen", 64),
|
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||||
)
|
)
|
||||||
input_dict["videos"] = videos
|
input_dict["videos"] = videos
|
||||||
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
|
|
||||||
return image_processor(**input_dict, return_tensors="pt")
|
mm_inputs = {}
|
||||||
else:
|
if image_processor != video_processor:
|
||||||
return {}
|
if input_dict.get("images") is not None:
|
||||||
|
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
|
||||||
|
if input_dict.get("videos") is not None:
|
||||||
|
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
|
||||||
|
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
|
||||||
|
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
def process_messages(
|
def process_messages(
|
||||||
self,
|
self,
|
||||||
|
@ -365,27 +373,6 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
|
||||||
def _get_mm_inputs(
|
|
||||||
self,
|
|
||||||
images: Sequence["ImageInput"],
|
|
||||||
videos: Sequence["VideoInput"],
|
|
||||||
processor: Optional["ProcessorMixin"],
|
|
||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
|
||||||
self._validate_input(images, videos)
|
|
||||||
video_processor = getattr(processor, "video_processor")
|
|
||||||
res = super()._get_mm_inputs(images, [], processor)
|
|
||||||
if len(videos) != 0:
|
|
||||||
videos = self._regularize_videos(
|
|
||||||
videos,
|
|
||||||
image_resolution=getattr(processor, "image_resolution"),
|
|
||||||
video_fps=getattr(processor, "video_fps"),
|
|
||||||
video_maxlen=getattr(processor, "video_maxlen"),
|
|
||||||
)
|
|
||||||
video_res = video_processor(videos, return_tensors="pt")
|
|
||||||
res.update(video_res)
|
|
||||||
return res
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_mm_inputs(
|
def get_mm_inputs(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue