This commit is contained in:
BUAADreamer 2024-09-29 20:38:46 +08:00
parent 6ddea0f3d3
commit 7397827aec
1 changed files with 9 additions and 9 deletions

View File

@ -344,9 +344,9 @@ class LlavaNextVideoPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token) message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs: if "pixel_values_videos" in mm_inputs:
one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0]) height, width = get_image_size(pixel_values_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
@ -378,9 +378,9 @@ class LlavaNextVideoPlugin(BasePlugin):
if len(videos) != 0: if len(videos) != 0:
videos = self._regularize_videos( videos = self._regularize_videos(
videos, videos,
image_resolution=getattr(processor, "image_resolution", 168), image_resolution=getattr(processor, "image_resolution"),
video_fps=getattr(processor, "video_fps", 1.0), video_fps=getattr(processor, "video_fps"),
video_maxlen=getattr(processor, "video_maxlen", 16), video_maxlen=getattr(processor, "video_maxlen"),
) )
video_res = video_processor(videos, return_tensors="pt") video_res = video_processor(videos, return_tensors="pt")
res.update(video_res) res.update(video_res)
@ -576,9 +576,9 @@ class VideoLlavaPlugin(BasePlugin):
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0])) height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1 num_frames = 1
if exist_videos: if exist_videos:
one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0]) height, width = get_image_size(pixel_values_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1 image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames video_seqlen = image_seqlen * num_frames
if processor.vision_feature_select_strategy == "default": if processor.vision_feature_select_strategy == "default":