import base64
import io
from functools import partial

import gradio as gr
import httpx
from openai import OpenAI
from PIL import Image

from const import CSS, FOOTER, HEADER, MODELS, PLACEHOLDER
from cycloud.auth import load_default_credentials


def get_headers(host: str) -> dict:
    """Build auth headers for the inference endpoint from the default credentials."""
    creds = load_default_credentials()
    return {
        "Authorization": f"Bearer {creds.access_token}",
        "Host": host,
        "Accept": "application/json",
        "Content-Type": "application/json",
    }


def proxy(request: httpx.Request, model_info: dict) -> httpx.Request:
    """httpx request hook: rewrite the path to the model's endpoint and inject auth headers."""
    request.url = request.url.copy_with(path=model_info["endpoint"])
    request.headers.update(
        get_headers(host=model_info["host"].replace("https://", ""))
    )
    return request


def encode_image_with_pillow(image_path: str) -> str:
    """Downscale the image to fit within 384x384 and return it as a base64-encoded JPEG."""
    with Image.open(image_path) as img:
        img.thumbnail((384, 384))
        buffered = io.BytesIO()
        img.convert("RGB").save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")


def call_chat_api(message, history, model_name):
    # Locate the image: prefer one attached to the current message, otherwise
    # fall back to the most recent image found in the chat history.
    image = None
    if message["files"]:
        if isinstance(message["files"], dict):
            image = message["files"]["path"]
        else:
            image = message["files"][-1]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        raise gr.Error("Please upload an image first.")

    img_base64 = encode_image_with_pillow(image)

    # The first user turn carries the image; text turns are appended below.
    history_openai_format = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_base64}",
                    },
                },
            ],
        }
    ]

    if len(history) == 0:
        # First turn: the text belongs in the same message as the image.
        history_openai_format[0]["content"].append(
            {"type": "text", "text": message["text"]}
        )
    else:
        # history[0] is the image-upload turn; replay the remaining text turns.
        for human, assistant in history[1:]:
            if len(history_openai_format) == 1:
                # The first text turn joins the image message.
                history_openai_format[0]["content"].append(
                    {"type": "text", "text": human}
                )
            else:
                history_openai_format.append({"role": "user", "content": human})
            history_openai_format.append(
                {"role": "assistant", "content": assistant}
            )
        history_openai_format.append({"role": "user", "content": message["text"]})

    # Route the OpenAI client through the proxy hook so each outgoing request
    # is rewritten for the selected model's host and endpoint.
    client = OpenAI(
        api_key="",  # auth is handled by the proxy hook, not an API key
        base_url=MODELS[model_name]["host"],
        http_client=httpx.Client(
            event_hooks={
                "request": [partial(proxy, model_info=MODELS[model_name])],
            },
            verify=False,  # internal endpoint; TLS verification disabled
        ),
    )

    stream = client.chat.completions.create(
        model=f"/data/cyberagent/{model_name}",
        messages=history_openai_format,
        temperature=0.2,
        top_p=1.0,
        max_tokens=1024,
        stream=True,
        extra_body={"repetition_penalty": 1.1},
    )

    # Accumulate streamed deltas and yield the running text for Gradio.
    response = ""
    for chunk in stream:
        content = chunk.choices[0].delta.content or ""
        response += content
        yield response


def run():
    chatbot = gr.Chatbot(
        elem_id="chatbot", placeholder=PLACEHOLDER, scale=1, height=700
    )
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
        gr.ChatInterface(
            fn=call_chat_api,
            stop_btn="Stop Generation",
            examples=[
                [
                    {
                        # "Please describe this image in detail."
                        "text": "この画像を詳しく説明してください。",
                        "files": ["./examples/cat.jpg"],
                    },
                ],
                [
                    {
                        # "Please tell me in detail what this dish tastes like."
                        "text": "この料理はどんな味がするか詳しく教えてください。",
                        "files": ["./examples/takoyaki.jpg"],
                    },
                ],
            ],
            multimodal=True,
            textbox=chat_input,
            chatbot=chatbot,
            additional_inputs=[model_selector],
        )
        gr.Markdown(FOOTER)

    demo.queue().launch(share=False)


if __name__ == "__main__":
    run()
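# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the `const.py` this script expects. The
# model name, host, and endpoint values below are hypothetical placeholders,
# not the real deployment settings; the script only relies on each MODELS
# entry having "host" and "endpoint" keys (used by proxy() and
# call_chat_api() above).
#
#     MODELS = {
#         "example-vlm": {
#             "host": "https://inference.example.internal",
#             "endpoint": "/v1/chat/completions",
#         },
#     }
#     CSS = "#chatbot { font-size: 14px; }"
#     HEADER = "# Demo"
#     FOOTER = ""
#     PLACEHOLDER = "Upload an image and ask a question."
# ---------------------------------------------------------------------------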