diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 63b4fd10..4c68afd4 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -30,9 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory torch_gc() -def create_app(): - chat_model = ChatModel(*get_infer_args()) - +def create_app(chat_model: ChatModel): app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -124,5 +122,6 @@ def create_app(): if __name__ == "__main__": - app = create_app() + chat_model = ChatModel(*get_infer_args()) + app = create_app(chat_model) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)