# coding=utf-8
# Implements an API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for the API documentation.


import time
import torch
import uvicorn
from threading import Thread
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from transformers import TextIteratorStreamer
from sse_starlette import EventSourceResponse
from typing import Any, Dict, List, Literal, Optional

from utils import (
    Template,
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)


@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: Optional[str] = "model"
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    owned_by: Optional[str] = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = []


class ModelList(BaseModel):
    object: Optional[str] = "list"
    data: Optional[List[ModelCard]] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: Optional[str] = "chatcmpl-default"
    object: Literal["chat.completion"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: ChatCompletionResponseUsage


class ChatCompletionStreamResponse(BaseModel):
    id: Optional[str] = "chatcmpl-default"
    object: Literal["chat.completion.chunk"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    global model_args
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer, source_prefix, generating_args

    # the latest message must come from the user
    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    # reuse a leading system message as the prompt prefix if one is provided
    prev_messages = request.messages[:-1]
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        prefix = prev_messages.pop(0).content
    else:
        prefix = source_prefix

    # pair up (user, assistant) turns to build the multi-turn history
    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i+1].content])

    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
    inputs = inputs.to(model.device)

    gen_kwargs = generating_args.to_dict()
    gen_kwargs.update({
        "input_ids": inputs["input_ids"],
        "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
        "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
        "logits_processor": get_logits_processor()
    })

    if request.max_tokens:
        gen_kwargs.pop("max_length", None)
        gen_kwargs["max_new_tokens"] = request.max_tokens

    # stream tokens via server-sent events if requested
    if request.stream:
        generate = predict(gen_kwargs, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    generation_output = model.generate(**gen_kwargs)
    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)

    usage = ChatCompletionResponseUsage(
        prompt_tokens=len(inputs["input_ids"][0]),
        completion_tokens=len(outputs),
        total_tokens=len(inputs["input_ids"][0]) + len(outputs)
    )

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")


async def predict(gen_kwargs: Dict[str, Any], model_id: str):
    global model, tokenizer

    # run generation in a background thread and yield chunks as tokens arrive
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True, ensure_ascii=False)

    for new_text in streamer:
        if len(new_text) == 0:
            continue

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield chunk.json(exclude_unset=True, ensure_ascii=False)

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True, ensure_ascii=False)
    yield "[DONE]"


if __name__ == "__main__":
    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    prompt_template = Template(data_args.prompt_template)
    source_prefix = data_args.source_prefix if data_args.source_prefix else ""

    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
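

# A minimal client sketch (not part of the server): it assumes the server above is
# running on localhost:8000 and that the `requests` package is installed. The "model"
# field is required by the request schema but is only echoed back by this demo, so its
# value is arbitrary. The response path mirrors ChatCompletionResponse above.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "gpt-3.5-turbo",
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "stream": False
#       }
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
#
# With "stream": True, the endpoint returns server-sent events whose "data:" payloads
# follow the ChatCompletionStreamResponse schema, terminated by a "[DONE]" event.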