# Megrez-3B-Omni / app.py
# Last commit: "Init commit" (dfe75a1) by yuantao-infini-ai
# -*- encoding: utf-8 -*-
# File: app.py
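# Description: Gradio chat demo for Megrez-3B-Omni; sends text/image/audio prompts to a remote inference server and streams the reply.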
from copy import deepcopy
from typing import Dict, List
from PIL import Image
import io
import subprocess
import requests
import json
import base64
import gradio as gr
import librosa
# File extensions accepted in the multimodal textbox. They are matched with
# ``str.endswith``, so every entry must include the leading dot; otherwise
# unrelated names that merely end in the same letters would match.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
# Video files are treated like audio: they go through the same duration check
# and are converted to WAV audio with ffmpeg before being sent to the model.
VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac", ".m4a")

# Initial values for the sampling-parameter sliders. ``do_sample`` and
# ``num_beams`` have no slider and are not read anywhere in this file.
DEFAULT_SAMPLING_PARAMS = {
    "top_p": 0.8,
    "top_k": 100,
    "temperature": 0.7,
    "do_sample": True,
    "num_beams": 1,
    "repetition_penalty": 1.2,
}
# Upper bound and initial value of the "Max New Tokens" slider, and the
# default ``max_new_tokens`` for ``bot``.
MAX_NEW_TOKENS = 1024
def load_image_to_base64(image_path):
    """Read an image file and return it as a base64-encoded PNG string.

    Images in any mode other than RGB (e.g. RGBA, palette, greyscale) are
    converted to RGB first, so the encoded PNG always has three channels.
    """
    buffer = io.BytesIO()
    with Image.open(image_path) as source:
        rgb = source if source.mode == 'RGB' else source.convert('RGB')
        rgb.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
def wav_to_bytes_with_ffmpeg(wav_file_path):
    """Transcode a media file to WAV with ffmpeg and return it base64-encoded.

    Despite the name, the input does not have to be WAV: ffmpeg decodes
    whatever container it is given and writes the audio as WAV to stdout.

    Raises:
        gr.Error: if ffmpeg is not installed, exits with a non-zero status,
            or produces no output.
    """
    command = [
        'ffmpeg',
        '-nostdin',           # never read the server's stdin
        '-loglevel', 'error',  # keep stderr short; it is shown on failure
        '-i', wav_file_path,
        '-f', 'wav',
        '-',                  # write the WAV stream to stdout
    ]
    try:
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
    except FileNotFoundError as err:
        raise gr.Error("ffmpeg is not installed or not on PATH.") from err
    if result.returncode != 0 or not result.stdout:
        # Without this check a failed conversion would silently send an
        # empty audio payload to the model.
        detail = result.stderr.decode('utf-8', errors='replace').strip()
        raise gr.Error(f"Failed to convert audio with ffmpeg: {detail}")
    return base64.b64encode(result.stdout).decode('utf-8')
def parse_sse_response(response):
    """Yield the ``text`` field of each Server-Sent-Events ``data:`` line.

    Args:
        response: a streaming ``requests.Response``, or anything whose
            ``iter_lines()`` yields ``bytes`` or ``str`` lines.

    Yields:
        str: the ``text`` value of each JSON payload, in stream order.
        Iteration stops at the ``[DONE]`` sentinel or at end of stream.

    Raises:
        gr.Error: if a ``data:`` payload is not valid JSON.
        KeyError: if a payload has no ``text`` key.
    """
    prefix = 'data:'
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # blank lines separate SSE events
        line = raw_line.decode('utf-8') if isinstance(raw_line, bytes) else raw_line
        if not line.startswith(prefix):
            continue  # ignore comments (":...") and other SSE fields
        data = line[len(prefix):]
        # The SSE spec allows (but does not require) one space after "data:".
        if data.startswith(' '):
            data = data[1:]
        if data == '[DONE]':
            break
        try:
            json_data = json.loads(data)
        except json.JSONDecodeError as err:
            raise gr.Error(f"Failed to parse JSON: {data}") from err
        yield json_data['text']
def history2messages(history: List[Dict]) -> List[Dict]:
    """
    Transform gradio history to chat messages.

    Gradio's ``type="messages"`` history stores every user input (text,
    image, audio) as its own entry. Consecutive user entries are merged
    into one ``{"role": "user", "content": {...}}`` message whose
    ``content`` dict may hold ``"text"``, ``"image"`` (base64 PNG) and
    ``"audio"`` (base64 WAV) keys. Each assistant entry is deep-copied
    through unchanged and closes the current user message.

    The kind of a user entry comes from ``metadata["title"]``:
    - a missing ``metadata``, or a missing or ``None`` title, means text;
    - ``"image"`` / ``"audio"`` mean a file whose path is ``content[0]``;
    - any other title is skipped.

    The input ``history`` is not modified.
    """
    messages = []
    cur_message = dict()
    for item in history:
        if item["role"] == "assistant":
            if cur_message:
                messages.append(cur_message)
                cur_message = dict()
            messages.append(deepcopy(item))
            continue
        cur_message.setdefault("role", "user")
        content = cur_message.setdefault("content", dict())
        # Read the metadata without writing a default back into the caller's
        # history entry; tolerate both a missing dict and a missing "title".
        title = (item.get("metadata") or {}).get("title")
        if title is None:
            # A later text entry in the same user turn overwrites an earlier one.
            content["text"] = item["content"]
        elif title == "image":
            content["image"] = load_image_to_base64(item["content"][0])
        elif title == "audio":
            content["audio"] = wav_to_bytes_with_ffmpeg(item["content"][0])
    if cur_message:
        messages.append(cur_message)
    return messages
def check_messages(history, message, audio):
    """Validate a user submission and append it to the chat history.

    Args:
        history: the chatbot's ``type="messages"`` history (list of dicts).
            It is extended in place and also returned.
        message: the ``MultimodalTextbox`` value, a dict with ``"text"`` and
            ``"files"`` (a list of uploaded file paths).
        audio: file path from the ``gr.Audio`` component, or ``None``.

    Returns:
        ``(history, textbox_update, None)``:
        - ``textbox_update`` clears the textbox and disables it;
        - ``None`` clears the audio component.

    Raises:
        gr.Error: on any of the following:
        - empty input;
        - an unsupported file type;
        - more than one image, or more than one audio;
        - an audio/video file lasting 0 s or more than 30 s.
    """
    text = message["text"]
    has_text = bool(text and text.strip())
    has_files = len(message["files"]) > 0
    has_audio = audio is not None
    if not (has_text or has_files or has_audio):
        # "Please enter text or upload audio/an image before sending."
        raise gr.Error("请输入文字或上传音频/图片后再发送。")

    audios = []
    images = []
    for file_path in message["files"]:
        # Compare extensions case-insensitively so "photo.JPG" or "clip.WAV"
        # are not rejected as unsupported.
        lowered = file_path.lower()
        if lowered.endswith(AUDIO_EXTENSIONS) or lowered.endswith(VIDEO_EXTENSIONS):
            # NOTE(review): librosa >= 0.10 renamed ``filename=`` to ``path=``;
            # switch once the installed librosa version is confirmed.
            duration = librosa.get_duration(filename=file_path)
            if duration > 30:
                # "Audio duration must not exceed 30 seconds."
                raise gr.Error("音频时长不能超过30秒。")
            if duration == 0:
                # "Audio duration must not be 0 seconds."
                raise gr.Error("音频时长不能为0秒。")
            audios.append(file_path)
        elif lowered.endswith(IMAGE_EXTENSIONS):
            images.append(file_path)
        else:
            filename = file_path.split("/")[-1]
            raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")

    if len(audios) > 1:
        raise gr.Error("Please upload only one audio file.")
    if len(images) > 1:
        raise gr.Error("Please upload only one image file.")
    if audio is not None:
        if len(audios) > 0:
            raise gr.Error("Please upload only one audio file or record audio.")
        audios.append(audio)

    # Append the inputs to the history: images, then audio, then text.
    for image_path in images:
        history.append({"role": "user", "content": (image_path,), "metadata": {"title": "image"}})
    for audio_path in audios:
        history.append({"role": "user", "content": (audio_path,), "metadata": {"title": "audio"}})
    # Whitespace-only text is not added as its own chat entry.
    if has_text:
        history.append({"role": "user", "content": text})
    return history, gr.MultimodalTextbox(value=None, interactive=False), None
def bot(
    history: list,
    top_p: float,
    top_k: int,
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int = MAX_NEW_TOKENS,
    regenerate: bool = False,
):
    """Stream the assistant's reply for ``history`` from the Megrez API server.

    Args:
        history: chatbot history in Gradio ``type="messages"`` format.
        top_p, top_k, temperature, repetition_penalty, max_new_tokens:
            sampling settings forwarded to the server.
        regenerate: if True, trailing assistant replies are discarded and a
            new answer is generated for the last user turn.

    Yields:
        ``history`` plus one assistant message holding the text received so
        far, updated as each chunk arrives.

    Raises:
        gr.Error: if the server cannot be reached, returns a non-2xx status,
            or the stream is interrupted.
    """
    if regenerate:
        # Drop only trailing assistant replies. If the last turn has no reply
        # (e.g. the previous request failed), the user's input is kept and
        # answered instead of being thrown away.
        while history and history[-1]["role"] == "assistant":
            history = history[:-1]
    if not history:
        return history

    msgs = history2messages(history)
    API_URL = "http://8.152.0.142:8000/v1/chat"
    payload = {
        "messages": msgs,
        "sampling_params": {
            "top_p": top_p,
            # Slider values may arrive as floats; these parameters are integers.
            "top_k": int(top_k),
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": int(max_new_tokens),
            # NOTE(review): hard-coded, and differs from
            # DEFAULT_SAMPLING_PARAMS["num_beams"] (1); confirm which is intended.
            "num_beams": 3,
        },
    }

    try:
        # NOTE(review): GET with a JSON body is unusual; kept because that is
        # how the server is called here. The reply is an SSE stream.
        response = requests.get(
            API_URL,
            json=payload,
            headers={'Accept': 'text/event-stream'},
            stream=True,
            # (connect, read) timeouts. The read timeout bounds the gap
            # between received bytes, not the total generation time.
            timeout=(10, 300),
        )
    except requests.RequestException as err:
        raise gr.Error(f"Request to the model server failed: {err}") from err

    # ``with`` closes the streamed connection even if the consumer stops early.
    with response:
        if not response.ok:
            raise gr.Error(f"Model server returned HTTP {response.status_code}.")
        response_text = ""
        try:
            for text in parse_sse_response(response):
                response_text += text
                yield history + [{"role": "assistant", "content": response_text}]
        except requests.RequestException as err:
            raise gr.Error(f"Connection to the model server was interrupted: {err}") from err
    return response_text
def change_state(state):
    """Flip a boolean visibility flag.

    Returns a Gradio update that sets ``visible`` to the negation of
    ``state``, followed by that negated flag (to be stored back).
    """
    flipped = not state
    return gr.update(visible=flipped), flipped
def reset_user_input():
    """Build a Gradio update that empties the bound input component.

    Not referenced by any event handler in this file.
    """
    empty_update = gr.update(value="")
    return empty_update
if __name__ == "__main__":
    # Build the chat UI: a chatbot, a multimodal textbox, an audio input,
    # action buttons and a collapsible group of sampling-parameter sliders.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
# 🪐 Chat with <a href="https://github.com/infinigence/Infini-Megrez-Omni">Megrez-3B-Omni</a>
"""
        )
        chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height='48vh')
        # Holds True while the sampling-parameter group is shown. It starts
        # hidden; each toggle flips both this flag and the group's visibility.
        sampling_params_group_hidden_state = gr.State(False)

        with gr.Row(equal_height=True):
            chat_input = gr.MultimodalTextbox(
                file_count="multiple",
                placeholder="Enter your prompt or upload image/audio here, then press ENTER...",
                show_label=False,
                scale=8,
                file_types=["image", "audio"],
                interactive=True,
            )
        with gr.Row(equal_height=True):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                scale=1,
                max_length=30,
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=150):
                with gr.Row(equal_height=True):
                    regenerate_btn = gr.Button("Regenerate", variant="primary")
                    clear_btn = gr.ClearButton(
                        [chat_input, audio_input, chatbot],
                    )
                with gr.Row():
                    sampling_params_toggle_btn = gr.Button("Sampling Parameters")

        with gr.Group(visible=False) as sampling_params_group:
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS["temperature"], label="Temperature"
                )
                repetition_penalty = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
                    label="Repetition Penalty",
                )
            with gr.Row():
                top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
                top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")
            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=MAX_NEW_TOKENS,
                    value=MAX_NEW_TOKENS,
                    label="Max New Tokens",
                    interactive=True,
                )

        sampling_params_toggle_btn.click(
            change_state,
            sampling_params_group_hidden_state,
            [sampling_params_group, sampling_params_group_hidden_state],
        )

        # Validate the submission and append it to the history first.
        chat_msg = chat_input.submit(
            check_messages,
            [chatbot, chat_input, audio_input],
            [chatbot, chat_input, audio_input],
        )
        # ``.success`` (not ``.then``): if check_messages raised gr.Error, the
        # model must not be queried with the unchanged history.
        bot_msg = chat_msg.success(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
            outputs=chatbot,
            api_name="bot_response",
        )
        # ``.then`` runs even if bot failed, so the textbox that check_messages
        # disabled is always re-enabled.
        bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

        # The trailing gr.State(True) is passed as ``regenerate=True``.
        regenerate_btn.click(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
            outputs=chatbot,
        )

    demo.launch()