diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 7d34965a..5ac26623 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -39,9 +39,9 @@ if is_vllm_available(): if TYPE_CHECKING: - from PIL.Image import Image from transformers.image_processing_utils import BaseImageProcessor + from ..data.mm_plugin import ImageInput, VideoInput from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments @@ -111,7 +111,8 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> AsyncIterator["RequestOutput"]: request_id = "chatcmpl-{}".format(uuid.uuid4().hex) @@ -195,11 +196,12 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> List["Response"]: final_output = None - generator = await self._generate(messages, system, tools, image, **input_kwargs) + generator = await self._generate(messages, system, tools, image, video, **input_kwargs) async for request_output in generator: final_output = request_output @@ -221,11 +223,12 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: generated_text = "" - generator = await self._generate(messages, system, tools, image, **input_kwargs) + generator = await self._generate(messages, system, tools, image, video, **input_kwargs) async for result in generator: delta_text = result.outputs[0].text[len(generated_text) :] generated_text = result.outputs[0].text