This commit is contained in:
hoshi-hiyouga 2024-09-07 01:21:14 +08:00 committed by GitHub
parent 76f2e59504
commit 36665f3001
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 9 additions and 6 deletions

View File

@@ -39,9 +39,9 @@ if is_vllm_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -111,7 +111,8 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncIterator["RequestOutput"]: ) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex) request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
@@ -195,11 +196,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
final_output = None final_output = None
generator = await self._generate(messages, system, tools, image, **input_kwargs) generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for request_output in generator: async for request_output in generator:
final_output = request_output final_output = request_output
@@ -221,11 +223,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
generated_text = "" generated_text = ""
generator = await self._generate(messages, system, tools, image, **input_kwargs) generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for result in generator: async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :] delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text generated_text = result.outputs[0].text