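"""FastAPI server that streams chat completions from the chatglm-6b-int4 model (silver/chatglm-6b-int4-slim) over Server-Sent Events."""
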
import json
import os
from typing import List

# Keep the Hugging Face cache inside the working directory. This is set before
# transformers is imported so the setting actually takes effect.
os.environ['TRANSFORMERS_CACHE'] = ".cache"

import torch
from fastapi import FastAPI, Request, status, HTTPException
from pydantic import BaseModel
from torch.cuda import get_device_properties
from transformers import AutoModel, AutoTokenizer
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

# Model and quantization settings.
bits = 4
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
cache_dir = './models'
model_name = 'chatglm-6b-int4'
min_memory = 5.5  # minimum total GPU memory in GiB required to run on CUDA
tokenizer = None
model = None

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event('startup')
def init():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)

    # Use the GPU only when it has more than min_memory GiB in total;
    # otherwise fall back to a CPU-only quantized model.
    if torch.cuda.is_available() and get_device_properties(0).total_memory / 1024 ** 3 > min_memory:
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        model = model.float().quantize(bits=bits)
        if torch.cuda.is_available():
            print("Total Memory: ", get_device_properties(0).total_memory / 1024 ** 3)
        else:
            print("No GPU available")
        print("Using CPU")
    model = model.eval()

    # Optionally expose the server through an ngrok tunnel.
    if os.environ.get("ngrok_token") is not None:
        ngrok_connect()


class Message(BaseModel):
    role: str
    content: str


class Body(BaseModel):
    messages: List[Message]
    model: str
    stream: bool
    max_tokens: int


@app.get("/")
def read_root():
    return {"Hello": "World!"}


@app.post("/chat/completions")
async def completions(body: Body, request: Request):
    # Only streaming requests for this exact model are supported.
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")

    # The last message must be the user's question.
    question = body.messages[-1]
    if question.role == 'user':
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")

    # Rebuild the (user question, assistant answer) pairs that ChatGLM expects as history.
    user_question = ''
    history = []
    for message in body.messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role == 'system' or message.role == 'assistant':
            assistant_answer = message.content
            history.append((user_question, assistant_answer))

    async def event_generator():
        # stream_chat yields (response_so_far, history) tuples; forward only the text.
        for response in model.stream_chat(tokenizer, question, history, max_length=max(2048, body.max_tokens)):
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())
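
# Example request (a sketch, assuming the server is reachable on uvicorn's
# default port 8000; the field names come from the Body model above):
#
#   curl -N http://localhost:8000/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "chatglm-6b-int4", "stream": true, "max_tokens": 2048,
#          "messages": [{"role": "user", "content": "Hello"}]}'
#
# The response is an SSE stream of {"response": ...} events followed by "[DONE]".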


def ngrok_connect():
    from pyngrok import ngrok, conf
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(8000)
    print(http_tunnel.public_url)


if __name__ == "__main__":
    uvicorn.run("main:app", reload=True, app_dir=".")
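
# Two ways to launch the server (assuming this file is saved as main.py,
# which the "main:app" import string above expects):
#
#   python main.py
#   uvicorn main:app --reload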