# MiniMax-Text-01 / app.py
# Uploaded by MiniMax-AI
# Commit message: the first version
# Commit: 78c5953
import base64
import gradio as gr
import json
import mimetypes
import os
import requests
import time
# --- Configuration read from the process environment ----------------------
# Required settings: indexing os.environ raises KeyError at import time when
# one of them is missing.
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
# JSON object supplying the slider defaults; the UI below reads the keys
# 'tokens_to_generate', 'temperature' and 'top_p' from it.
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])

# Optional settings: None when the variable is not set.
SYSTEM_PROMPT = os.getenv('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.getenv('MULTIMODAL')

# Optional per-role 'name' values attached to outgoing messages; a None entry
# means no name is added for that role.
NAME_MAP = dict(
    system=os.getenv('SYSTEM_NAME'),
    user=os.getenv('USER_NAME'),
)
def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for ``message`` given the prior ``history``.

    Builds the request messages (optional system prompt, converted history,
    then the new user message), posts them to ``API_URL`` with streaming
    enabled and yields the reply text as it arrives.

    Args:
        message: The new user input; converted with ``convert_content``.
        history: Iterable of dicts with ``'role'`` and ``'content'`` keys.
        max_tokens: Forwarded as ``max_tokens`` in the request payload.
        temperature: Forwarded as ``temperature`` in the request payload.
        top_p: Forwarded as ``top_p`` in the request payload.

    Yields:
        The accumulated reply after each streamed ``delta`` chunk, or the
        content of a chunk that carries a complete ``message``.

    Raises:
        gr.Error: If the server answers with an HTTP error status or a data
            chunk has no ``'choices'`` key.
    """
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({
            'role': 'system',
            'content': SYSTEM_PROMPT,
        })
    for turn in history:
        messages.append({
            'role': turn['role'],
            'content': convert_content(turn['content']),
        })
    messages.append({
        'role': 'user',
        'content': convert_content(message),
    })
    # Use a distinct loop name so the ``message`` parameter is not shadowed.
    for entry in messages:
        add_name_for_message(entry)

    payload = {
        'model': MODEL_VERSION,
        'messages': messages,
        'stream': True,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
    }
    # The context manager releases the streamed connection even when the
    # consumer stops iterating early or an error is raised mid-stream.
    with requests.post(
        API_URL,
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(API_KEY),
        },
        data=json.dumps(payload),
        stream=True,
    ) as response:
        # Without this check an HTTP error body (no 'data:' lines) would
        # silently produce an empty reply.
        if not response.ok:
            raise gr.Error('request failed')

        reply = ''
        for row in response.iter_lines():
            if not row.startswith(b'data:'):
                continue
            body = row[5:].strip()
            # Some OpenAI-style servers end the stream with a non-JSON
            # sentinel; stop instead of failing in json.loads.
            if body == b'[DONE]':
                break
            chunk = json.loads(body)
            if 'choices' not in chunk:
                raise gr.Error('request failed')
            if not chunk['choices']:
                # Nothing to render in a chunk with an empty choices list.
                continue
            choice = chunk['choices'][0]
            if 'delta' in choice:
                # A delta may carry no text (missing or null 'content').
                reply += (choice['delta'] or {}).get('content') or ''
                yield reply
            elif 'message' in choice:
                yield choice['message']['content']
def add_name_for_message(message):
    """Set ``message['name']`` from NAME_MAP when the role has a name.

    The message dict is modified in place; it is left untouched when
    NAME_MAP has no entry for ``message['role']`` or the entry is None.
    """
    role = message['role']
    configured_name = NAME_MAP.get(role)
    if configured_name is None:
        return
    message['name'] = configured_name
def convert_content(content):
    """Convert chat content into the form sent in the request payload.

    * ``str``: returned unchanged.
    * ``tuple``: its first element is encoded with ``encode_base64`` and
      returned as a one-item list holding an ``image_url`` part.
    * anything else: treated as a mapping; the ``'text'`` value becomes a
      text part and each entry of ``'files'`` becomes an ``image_url`` part,
      in the mapping's own key order. Other keys are ignored.
    """
    def image_part(path):
        # One image entry whose URL is the base64 data URL of the file.
        return {
            'type': 'image_url',
            'image_url': {
                'url': encode_base64(path),
            },
        }

    if isinstance(content, str):
        return content
    if isinstance(content, tuple):
        return [image_part(content[0])]

    parts = []
    for field, value in content.items():
        if field == 'text':
            parts.append({'type': 'text', 'text': value})
        elif field == 'files':
            parts.extend(image_part(file_path) for file_path in value)
    return parts
def encode_base64(path):
    """Return the file at ``path`` as a base64 ``data:`` URL.

    The MIME type is guessed from the file name with
    ``mimetypes.guess_type``.

    Args:
        path: Path of the file to encode.

    Returns:
        A string of the form ``data:<mime type>;base64,<encoded bytes>``.

    Raises:
        gr.Error: If the guessed type is unknown or does not start with
            ``image/``.
    """
    guess_type = mimetypes.guess_type(path)[0]
    # guess_type is None for unknown extensions; check that first so the
    # caller gets the intended gr.Error instead of an AttributeError.
    if guess_type is None or not guess_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(guess_type, path))
    with open(path, 'rb') as handle:
        data = handle.read()
    return 'data:{};base64,{}'.format(
        guess_type,
        base64.b64encode(data).decode(),
    )
# Chat UI wired to respond(); the three sliders are passed to it, in order,
# as max_tokens, temperature and top_p. Slider defaults come from
# MODEL_CONTROL_DEFAULTS.
demo = gr.ChatInterface(
    respond,
    type='messages',
    # Multimodal input is enabled only when MULTIMODAL is exactly 'ON'.
    multimodal=(MULTIMODAL_FLAG == 'ON'),
    additional_inputs=[
        gr.Slider(
            label='Tokens to generate',
            value=MODEL_CONTROL_DEFAULTS['tokens_to_generate'],
            minimum=1,
            maximum=1000000,
            step=1,
        ),
        gr.Slider(
            label='Temperature',
            value=MODEL_CONTROL_DEFAULTS['temperature'],
            minimum=0.1,
            maximum=1.0,
            step=0.05,
        ),
        gr.Slider(
            label='Top-p (nucleus sampling)',
            value=MODEL_CONTROL_DEFAULTS['top_p'],
            minimum=0.1,
            maximum=1.0,
            step=0.05,
        ),
    ],
)

if __name__ == '__main__':
    # Configure the queue with a default concurrency limit of 50, then launch.
    queued_demo = demo.queue(default_concurrency_limit=50)
    queued_demo.launch()