fix #5384
This commit is contained in:
parent
76f2e59504
commit
36665f3001
|
@ -39,9 +39,9 @@ if is_vllm_available():
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from PIL.Image import Image
|
|
||||||
from transformers.image_processing_utils import BaseImageProcessor
|
from transformers.image_processing_utils import BaseImageProcessor
|
||||||
|
|
||||||
|
from ..data.mm_plugin import ImageInput, VideoInput
|
||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,7 +111,8 @@ class VllmEngine(BaseEngine):
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["Image"] = None,
|
image: Optional["ImageInput"] = None,
|
||||||
|
video: Optional["VideoInput"] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncIterator["RequestOutput"]:
|
) -> AsyncIterator["RequestOutput"]:
|
||||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||||
|
@ -195,11 +196,12 @@ class VllmEngine(BaseEngine):
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["Image"] = None,
|
image: Optional["ImageInput"] = None,
|
||||||
|
video: Optional["VideoInput"] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
final_output = None
|
final_output = None
|
||||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
||||||
async for request_output in generator:
|
async for request_output in generator:
|
||||||
final_output = request_output
|
final_output = request_output
|
||||||
|
|
||||||
|
@ -221,11 +223,12 @@ class VllmEngine(BaseEngine):
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["Image"] = None,
|
image: Optional["ImageInput"] = None,
|
||||||
|
video: Optional["VideoInput"] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
generator = await self._generate(messages, system, tools, image, **input_kwargs)
|
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
||||||
async for result in generator:
|
async for result in generator:
|
||||||
delta_text = result.outputs[0].text[len(generated_text) :]
|
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||||
generated_text = result.outputs[0].text
|
generated_text = result.outputs[0].text
|
||||||
|
|
Loading…
Reference in New Issue