"""Gradio chat UI that streams replies from a chat-completions HTTP endpoint.

Configuration is read from environment variables at import time:

* ``MODEL_VERSION``, ``API_URL``, ``API_KEY`` (required).
* ``MODEL_CONTROL_DEFAULTS`` (required): JSON object with the keys
  ``tokens_to_generate``, ``temperature`` and ``top_p`` used as slider defaults.
* ``SYSTEM_PROMPT`` (optional): prepended as a ``system`` message when set.
* ``MULTIMODAL`` (optional): ``'ON'`` enables the multimodal textbox.
* ``SYSTEM_NAME`` / ``USER_NAME`` (optional): added as the ``name`` field of
  system / user messages when set.
"""
import base64
import json
import mimetypes
import os
import time

import gradio as gr
import requests

MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.environ.get('MULTIMODAL')
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])
# Role -> optional display name attached to messages of that role.
NAME_MAP = {
    'system': os.environ.get('SYSTEM_NAME'),
    'user': os.environ.get('USER_NAME'),
}

# Server-sent-events line prefix and end-of-stream sentinel used by
# chat-completions-style streaming APIs.
_SSE_DATA_PREFIX = b'data:'
_SSE_DONE = b'[DONE]'


def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a model reply to ``message`` given the prior chat ``history``.

    Args:
        message: The new user message: a plain string, or (in multimodal
            mode) a dict with ``'text'`` and/or ``'files'`` keys.
        history: Prior turns as a list of ``{'role', 'content'}`` dicts
            (Gradio ``type='messages'`` format).
        max_tokens: Value sent as ``max_tokens``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply text accumulated so far for streamed ``delta`` chunks, or
        the full message content for non-delta ``message`` chunks.

    Raises:
        gr.Error: If the HTTP request fails, a chunk is not valid JSON or
            lacks ``'choices'``, or the stream contains no data at all.
    """
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({
            'role': 'system',
            'content': SYSTEM_PROMPT,
        })
    for val in history:
        messages.append({
            'role': val['role'],
            'content': convert_content(val['content']),
        })
    messages.append({
        'role': 'user',
        'content': convert_content(message),
    })
    # Use a distinct loop name so the ``message`` parameter is not shadowed.
    for msg in messages:
        add_name_for_message(msg)

    payload = {
        'model': MODEL_VERSION,
        'messages': messages,
        'stream': True,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
    }

    # Context manager guarantees the streamed connection is released even if
    # the consumer stops iterating early or an error is raised mid-stream.
    with requests.post(
        API_URL,
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(API_KEY),
        },
        data=json.dumps(payload),
        stream=True,
    ) as r:
        if not r.ok:
            raise gr.Error('request failed (HTTP {})'.format(r.status_code))

        reply = ''
        received_data = False
        for row in r.iter_lines():
            if not row.startswith(_SSE_DATA_PREFIX):
                continue  # blank keep-alive lines, SSE comments, etc.
            body = row[len(_SSE_DATA_PREFIX):].strip()
            if body == _SSE_DONE:
                break  # end-of-stream sentinel; not JSON
            try:
                chunk = json.loads(body)
            except ValueError as err:
                raise gr.Error('request failed') from err
            received_data = True
            if 'choices' not in chunk:
                raise gr.Error('request failed')
            choices = chunk['choices']
            if not choices:
                continue  # e.g. a trailing usage-only chunk
            choice = choices[0]
            if 'delta' in choice:
                # The first/last deltas may omit ``content`` or set it to null.
                reply += (choice['delta'] or {}).get('content') or ''
                yield reply
            elif 'message' in choice:
                yield choice['message'].get('content') or ''

        if not received_data:
            # A non-SSE body (e.g. a JSON error object) would otherwise
            # produce a silently empty reply.
            raise gr.Error('request failed')


def add_name_for_message(message):
    """Set ``message['name']`` from NAME_MAP when a name is configured for its role.

    Mutates ``message`` in place; messages whose role has no configured name
    are left unchanged.
    """
    name = NAME_MAP.get(message['role'])
    if name is not None:
        message['name'] = name


def convert_content(content):
    """Convert Gradio message content into the API's message-content form.

    * ``str`` is returned unchanged.
    * ``tuple`` (a Gradio file entry; its first element is the file path) is
      converted to a single ``image_url`` part.
    * ``dict`` (multimodal textbox value) is converted to a list of parts:
      a ``text`` part for the ``'text'`` key and one ``image_url`` part per
      path in ``'files'``. Other keys are ignored.

    Raises:
        gr.Error: Propagated from ``encode_base64`` for non-image files.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, tuple):
        return [{
            'type': 'image_url',
            'image_url': {
                'url': encode_base64(content[0]),
            },
        }]
    content_list = []
    for key, val in content.items():
        if key == 'text':
            content_list.append({
                'type': 'text',
                'text': val,
            })
        elif key == 'files':
            for f in val:
                content_list.append({
                    'type': 'image_url',
                    'image_url': {
                        'url': encode_base64(f),
                    },
                })
    return content_list


def encode_base64(path):
    """Return the image file at ``path`` as a base64 ``data:`` URL.

    The MIME type is guessed from the file name via ``mimetypes``.

    Raises:
        gr.Error: If the guessed type is unknown or not ``image/*``.
    """
    guess_type = mimetypes.guess_type(path)[0]
    # guess_type is None for unrecognised extensions; report it as
    # "not an image" rather than crashing with AttributeError.
    if guess_type is None or not guess_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(guess_type, path))
    with open(path, 'rb') as handle:
        data = handle.read()
    return 'data:{};base64,{}'.format(
        guess_type,
        base64.b64encode(data).decode(),
    )


demo = gr.ChatInterface(
    respond,
    multimodal=MULTIMODAL_FLAG == 'ON',
    type='messages',
    additional_inputs=[
        gr.Slider(minimum=1, maximum=1000000, value=MODEL_CONTROL_DEFAULTS['tokens_to_generate'], step=1, label='Tokens to generate'),
        gr.Slider(minimum=0.1, maximum=1.0, value=MODEL_CONTROL_DEFAULTS['temperature'], step=0.05, label='Temperature'),
        gr.Slider(minimum=0.1, maximum=1.0, value=MODEL_CONTROL_DEFAULTS['top_p'], step=0.05, label='Top-p (nucleus sampling)'),
    ],
)


if __name__ == '__main__':
    demo.queue(default_concurrency_limit=50).launch()