import gc
import json
import time
import requests
import base64
import uvicorn
import argparse
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedModel, PreTrainedTokenizer, \
    TextIteratorStreamer, CodeGenTokenizerFast as Tokenizer
from contextlib import asynccontextmanager
from loguru import logger
from typing import List, Literal, Union, Tuple, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from PIL import Image
from io import BytesIO
import os
import re
from threading import Thread
from moondream import Moondream, detect_device
import omnichat


# Request models
class TextContent(BaseModel):
    type: Literal["text"]
    text: str


class ImageUrl(BaseModel):
    url: str


class ImageUrlContent(BaseModel):
    type: Literal["image_url"]
    image_url: ImageUrl


ContentItem = Union[TextContent, ImageUrlContent]


class ChatMessageInput(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: Union[str, List[ContentItem]]
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessageInput]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # Additional parameters
    repetition_penalty: Optional[float] = 1.0


# Response models
class ChatMessageResponse(BaseModel):
    role: Literal["assistant"]
    content: Optional[str] = None
    name: Optional[str] = None


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessageResponse


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


# Image input handling: accepts a URL, a base64 data URI, a local path, or a PIL image
def process_img(input_data):
    if isinstance(input_data, str):
        # Remote URL
        if input_data.startswith("http://") or input_data.startswith("https://"):
            response = requests.get(input_data)
            image_data = response.content
            pil_image = Image.open(BytesIO(image_data)).convert('RGB')
        # base64 data URI
        elif input_data.startswith("data:image/"):
            base64_data = input_data.split(",")[1]
            image_data = base64.b64decode(base64_data)
            pil_image = Image.open(BytesIO(image_data)).convert('RGB')
        # Local image path
        else:
            pil_image = Image.open(input_data)
    # Already a PIL image
    elif isinstance(input_data, Image.Image):
        pil_image = input_data
    else:
        raise ValueError("unsupported image input type")
    return pil_image
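
# Example inputs accepted by process_img (illustrative values only):
#   process_img("https://example.com/cat.jpg")            # remote URL
#   process_img("data:image/png;base64,iVBORw0KGgo...")   # base64 data URI
#   process_img("./images/cat.jpg")                       # local file path
#   process_img(Image.open("./images/cat.jpg"))           # PIL.Image passthrough
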
# History handling: split messages into the last user query, a list of
# (query, answer) pairs, and the list of images found along the way
def process_history_and_images(messages: List[ChatMessageInput]) -> Tuple[
        Optional[str], Optional[List[Tuple[str, str]]], Optional[List[Image.Image]]]:
    formatted_history = []
    image_list = []
    last_user_query = ''

    for i, message in enumerate(messages):
        role = message.role
        content = message.content

        # Text parts
        if isinstance(content, list):
            text_content = ' '.join(item.text for item in content if isinstance(item, TextContent))
        else:
            text_content = content

        # Image parts
        if isinstance(content, list):
            for item in content:
                if isinstance(item, ImageUrlContent):
                    image_url = item.image_url.url
                    image = process_img(image_url)
                    image_list.append(image)

        if role == 'user':
            if i == len(messages) - 1:  # last message
                last_user_query = text_content
            else:
                formatted_history.append((text_content, ''))
        elif role == 'assistant':
            if formatted_history:
                if formatted_history[-1][1] != '':
                    assert False, f"the last query was already answered: {formatted_history[-1][0]}, {formatted_history[-1][1]}, {text_content}"
                formatted_history[-1] = (formatted_history[-1][0], text_content)
            else:
                assert False, "assistant reply before any user message"
        else:
            assert False, f"unrecognized role: {role}"

    return last_user_query, formatted_history, image_list


# Moondream streaming inference
@torch.inference_mode()
def generate_stream_moondream(params: dict):
    global model, tokenizer

    # Flatten the (query, answer) history into Moondream's prompt format
    def chat_history_to_prompt(history):
        prompt = ""
        for i, (old_query, response) in enumerate(history):
            prompt += f"Question: {old_query}\n\nAnswer: {response}\n\n"
        return prompt

    messages = params["messages"]
    prompt, formatted_history, image_list = process_history_and_images(messages)
    history = chat_history_to_prompt(formatted_history)
    # Only the last image is used
    img = image_list[-1]

    # Build inputs
    '''
    answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    )
    '''
    image_embeds = model.encode_image(img)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    gen_kwargs = {
        "image_embeds": image_embeds,
        "question": prompt,
        "tokenizer": tokenizer,
        "chat_history": history,
        "result_queue": None,
        "streamer": streamer,
    }
    thread = Thread(
        target=model.answer_question,
        kwargs=gen_kwargs,
    )
    input_echo_len = 0
    total_len = 0
    # Start inference
    thread.start()
    buffer = ""
    for new_text in streamer:
        clean_text = re.sub("<$|END$", "", new_text)
        buffer += clean_text
        yield {
            "text": buffer.strip("<END"),
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": total_len - input_echo_len,
                "total_tokens": total_len,
            },
        }


# Moondream single-shot response: drain the stream and return the last chunk
def generate_moondream(params: dict):
    for response in generate_stream_moondream(params):
        pass
    return response


# CogVLM streaming inference
# NOTE: the source is truncated in this region; the input construction below
# is reconstructed following the standard CogVLM OpenAI-style demo this server
# adapts, and may differ in detail from the original.
@torch.inference_mode()
def generate_stream_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    messages = params["messages"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))

    query, history, image_list = process_history_and_images(messages)
    logger.debug(f"==== request ====\n{query}")

    input_by_model = model.build_conversation_input_ids(
        tokenizer, query=query, history=history, images=[image_list[-1]]
    )
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]

    input_echo_len = len(inputs["input_ids"][0])
    streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p if temperature > 1e-5 else 0,
        "streamer": streamer,
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    generated_text = ""
    with torch.no_grad():
        model.generate(**inputs, **gen_kwargs)
        for next_text in streamer:
            generated_text += next_text
            yield {
                "text": generated_text,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
            }
    ret = {
        "text": generated_text,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
    }
    yield ret


# CogVLM single-shot response
def generate_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    for response in generate_stream_cogvlm(model, tokenizer, params):
        pass
    return response


def generate_minicpm(model, params):
    messages = params["messages"]
    query, history, image_list = process_history_and_images(messages)
    # Rebuild the (query, answer) pairs as role/content dicts for MiniCPM
    # (the original appended a dict onto the raw tuple list, which MiniCPM
    # cannot parse)
    msgs = []
    for old_query, answer in history:
        msgs.append({'role': 'user', 'content': old_query})
        msgs.append({'role': 'assistant', 'content': answer})
    msgs.append({'role': 'user', 'content': query})

    # Only the last image is used; re-encode the PIL image as base64 JPEG
    image = image_list[-1]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")  # adjust the format as needed
    buffer.seek(0)
    image_base64_str = base64.b64encode(buffer.read()).decode("utf-8")

    chat_input = {'image': image_base64_str, 'question': json.dumps(msgs)}
    generation = model.chat(chat_input)
    response = {"text": generation, "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}
    print(response)
    return response


# Streaming response (not implemented: always returns a placeholder)
async def predict(model_id: str, params: dict):
    return "no stream"


torch.set_grad_enabled(False)


# Lifespan manager: free GPU memory on shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
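
# Example request body for POST /v1/chat/completions, as accepted by the
# route below (illustrative sketch; the image URL is a placeholder):
#
# {
#   "model": "moondream",
#   "messages": [
#     {"role": "user", "content": [
#       {"type": "text", "text": "Describe this image."},
#       {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}
#     ]}
#   ],
#   "stream": false
# }
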
# Chat completion route
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    # Validate the request
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
    )

    # Streaming response (placeholder: predict() does not stream yet)
    if request.stream:
        return await predict(request.model, gen_params)

    # Single-shot response, dispatched on the currently loaded model
    if STATE_MOD == "cog":
        response = generate_cogvlm(model, tokenizer, gen_params)
    elif STATE_MOD == "moon":
        response = generate_moondream(gen_params)
    elif STATE_MOD == "mini":
        response = generate_minicpm(model, gen_params)

    usage = UsageInfo()
    message = ChatMessageResponse(
        role="assistant",
        content=response["text"],
    )
    logger.debug(f"==== message ====\n{message}")
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)


# Model switching state
STATE_MOD = "moon"
MODEL_PATH = ""

# Model loading
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_mod(model_input, mod_type):
    global model, tokenizer, language_processor_version
    if mod_type == "cog":
        tokenizer_path = os.environ.get("TOKENIZER_PATH", 'lmsys/vicuna-7b-v1.5')
        tokenizer = LlamaTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            signal_type=language_processor_version
        )
        if 'cuda' in DEVICE:
            model = AutoModelForCausalLM.from_pretrained(
                model_input,
                trust_remote_code=True,
                load_in_4bit=True,
                torch_dtype=torch_type,
                low_cpu_mem_usage=True
            ).eval()
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_input,
                trust_remote_code=True
            ).float().to(DEVICE).eval()
    elif mod_type == "moon":
        device, dtype = detect_device()
        model = Moondream.from_pretrained(model_input).to(device=device, dtype=dtype).eval()
        tokenizer = Tokenizer.from_pretrained(model_input)
    elif mod_type == "mini":
        model, tokenizer = omnichat.OmniLMMChat(model_input), None


@app.post("/v1/Cog-vqa")
async def switch_vqa():
    global model, STATE_MOD, mod_vqa, language_processor_version
    STATE_MOD = "cog"
    del model
    model = None
    language_processor_version = "chat_old"
    load_mod(mod_vqa, STATE_MOD)


@app.post("/v1/Cog-chat")
async def switch_chat():
    global model, STATE_MOD, mod_chat, language_processor_version
    STATE_MOD = "cog"
    del model
    model = None
    language_processor_version = "chat"
    load_mod(mod_chat, STATE_MOD)


@app.post("/v1/moondream")
async def switch_moon():
    global model, STATE_MOD, mod_moon
    STATE_MOD = "moon"
    del model
    model = None
    load_mod(mod_moon, STATE_MOD)


@app.post("/v1/MiniCPM")
async def switch_mini():
    global model, STATE_MOD, mod_mini
    STATE_MOD = "mini"
    del model
    model = None
    load_mod(mod_mini, STATE_MOD)


# Shutdown: release the current model
@app.post("/v1/close")
async def close():
    global model
    del model
    model = None
    gc.collect()
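
# Example: switching the loaded model at runtime via the routes above
# (hypothetical client sketch; assumes the server runs on localhost:8000):
#
#   import requests
#   requests.post("http://localhost:8000/v1/moondream")  # load Moondream
#   requests.post("http://localhost:8000/v1/Cog-chat")   # load CogAgent-chat
#   requests.post("http://localhost:8000/v1/close")      # release the current model
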
parser = argparse.ArgumentParser()
parser.add_argument("--mod", type=str, default="moondream")
args = parser.parse_args()
mod = args.mod

mod_vqa = './models/cogagent-vqa-hf'
mod_chat = './models/cogagent-chat-hf'
mod_moon = './models/moondream'
mod_mini = './models/MiniCPM-Llama3-V-2_5'

'''
mod_list = [
    "moondream",
    "Cog-vqa",
    "Cog-chat",
    "MiniCPM"
]
'''

if mod == "Cog-vqa":
    STATE_MOD = "cog"
    MODEL_PATH = mod_vqa
    language_processor_version = "chat_old"
elif mod == "Cog-chat":
    STATE_MOD = "cog"
    MODEL_PATH = mod_chat
    language_processor_version = "chat"
elif mod == "moondream":
    STATE_MOD = "moon"
    MODEL_PATH = mod_moon
elif mod == "MiniCPM":
    STATE_MOD = "mini"
    MODEL_PATH = mod_mini

if __name__ == "__main__":
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch_type = torch.bfloat16
    else:
        torch_type = torch.float16
    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))
    load_mod(MODEL_PATH, STATE_MOD)
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
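
# Minimal end-to-end client sketch (assumes the server is running on
# localhost:8000; "demo.jpg" is a hypothetical local image):
#
#   import base64, requests
#   with open("demo.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   payload = {
#       "model": "moondream",
#       "messages": [{"role": "user", "content": [
#           {"type": "text", "text": "What is in this picture?"},
#           {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
#       ]}],
#   }
#   r = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
#   print(r.json()["choices"][0]["message"]["content"])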