Spaces:
Running
Running
# -*- encoding: utf-8 -*- | |
# File: app.py | |
# Description: Gradio demo for chatting with Megrez-3B-Omni (text, image and audio input) via a remote streaming chat API.
from copy import deepcopy | |
from typing import Dict, List | |
from PIL import Image | |
import io | |
import subprocess | |
import requests | |
import json | |
import base64 | |
import gradio as gr | |
import librosa | |
# Lower-case file extensions (with leading dot) that check_messages uses to
# classify uploads. Video files are accepted and handled as audio sources.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
# Every entry needs its leading dot: a bare "flac" would also match any file
# name that merely ends in those letters.
AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac", ".m4a")
# Default sampling settings. The UI sliders are initialised from the
# temperature, repetition_penalty, top_p and top_k entries.
DEFAULT_SAMPLING_PARAMS = {
    "top_p": 0.8,
    "top_k": 100,
    "temperature": 0.7,
    "do_sample": True,
    "num_beams": 1,
    "repetition_penalty": 1.2,
}
# Upper bound and default value of the "Max New Tokens" slider, and bot()'s default.
MAX_NEW_TOKENS = 1024
def load_image_to_base64(image_path):
    """Read an image file and return it as a base64-encoded PNG string.

    Images whose mode is not already RGB are converted to RGB first
    (so any alpha channel is discarded).

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        ASCII ``str`` holding the base64 encoding of the PNG bytes.
    """
    buffer = io.BytesIO()
    with Image.open(image_path) as source:
        rgb_image = source if source.mode == 'RGB' else source.convert('RGB')
        rgb_image.save(buffer, format='PNG')
    png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
def wav_to_bytes_with_ffmpeg(wav_file_path):
    """Decode a media file to WAV with ffmpeg and return it base64-encoded.

    Args:
        wav_file_path: Path to any file ffmpeg can decode. Despite the name,
            the input does not need to be WAV; the output always is.

    Returns:
        ASCII ``str`` holding the base64 encoding of ffmpeg's WAV output.

    Raises:
        gr.Error: If ffmpeg exits with a non-zero status or writes no audio.
            This prevents an empty payload from being sent silently.
    """
    result = subprocess.run(
        # -nostdin: never let ffmpeg wait on / consume the server's stdin.
        # -loglevel error: keep stderr short enough to show in the message.
        ['ffmpeg', '-nostdin', '-hide_banner', '-loglevel', 'error',
         '-i', wav_file_path, '-f', 'wav', '-'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0 or not result.stdout:
        detail = result.stderr.decode('utf-8', errors='replace').strip()
        raise gr.Error(f"Failed to decode audio with ffmpeg: {detail or 'no output'}")
    return base64.b64encode(result.stdout).decode('utf-8')
def parse_sse_response(response):
    """Yield the ``text`` field of each JSON event in a server-sent-events stream.

    Args:
        response: Streaming HTTP response exposing ``iter_lines()`` (e.g. from
            ``requests.get(..., stream=True)``). Lines may be ``bytes`` or ``str``.

    Yields:
        The ``"text"`` value of each ``data:`` event, in order. Iteration stops
        at the ``[DONE]`` sentinel or when the stream ends. Blank lines and
        other SSE fields (``event:``, ``id:``, ``:`` comments) are skipped.

    Raises:
        gr.Error: If a ``data:`` payload is not valid JSON or has no ``"text"``.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        line = raw_line.decode('utf-8') if isinstance(raw_line, bytes) else raw_line
        if not line.startswith('data:'):
            continue
        data = line[len('data:'):]
        # The SSE format makes the single space after "data:" optional.
        if data.startswith(' '):
            data = data[1:]
        if data == '[DONE]':
            break
        try:
            json_data = json.loads(data)
        except json.JSONDecodeError as err:
            raise gr.Error(f"Failed to parse JSON: {data}") from err
        if not isinstance(json_data, dict) or 'text' not in json_data:
            raise gr.Error(f"Unexpected event without 'text': {data}")
        yield json_data['text']
def history2messages(history: List[Dict]) -> List[Dict]:
    """
    Transform Gradio chat history into the message list sent to the chat API.

    Consecutive user items (text, image, audio) are merged into one user
    message whose ``content`` is a dict with optional ``"text"``, ``"image"``
    (base64 PNG) and ``"audio"`` (base64 WAV) keys. A later item of the same
    kind overwrites an earlier one. Assistant items are deep-copied through
    unchanged and close the pending user message.

    ``history`` itself is not modified.

    Args:
        history: Gradio "messages"-format history. Image/audio user items carry
            ``metadata["title"]`` of ``"image"``/``"audio"`` and a tuple content
            whose first element is the file path. Items with no metadata, or
            with ``metadata`` None/empty, are treated as text.

    Returns:
        The API messages, in conversation order.
    """
    messages = []
    cur_message = {}
    for item in history:
        if item["role"] == "assistant":
            if cur_message:
                messages.append(cur_message)
            cur_message = {}
            messages.append(deepcopy(item))
            continue
        if not cur_message:
            cur_message = {"role": "user", "content": {}}
        # Read the title without writing a default back into the caller's dict.
        # Missing, None, or title-less metadata all mean a plain text item.
        title = (item.get("metadata") or {}).get("title")
        content = cur_message["content"]
        if title is None:
            content["text"] = item["content"]
        elif title == "image":
            content["image"] = load_image_to_base64(item["content"][0])
        elif title == "audio":
            content["audio"] = wav_to_bytes_with_ffmpeg(item["content"][0])
    if cur_message:
        messages.append(cur_message)
    return messages
def check_messages(history, message, audio):
    """Validate a multimodal user submission and append it to the chat history.

    Args:
        history: Gradio "messages"-format chat history; appended to in place.
        message: ``MultimodalTextbox`` value: a dict with ``"text"`` and
            ``"files"`` (list of uploaded file paths).
        audio: File path from the ``gr.Audio`` component, or ``None``.

    Returns:
        ``(history, textbox_update, None)``: the updated history, an update that
        clears and disables the textbox, and ``None`` to clear the audio input.

    Raises:
        gr.Error: If nothing was submitted, a file type is unsupported, an
            audio/video file lasts 0 s or more than 30 s, or more than one
            image or more than one audio source is given.
    """
    text = message["text"]
    has_text = bool(text and text.strip())
    files = message["files"]
    if not (has_text or len(files) > 0 or audio is not None):
        raise gr.Error("请输入文字或上传音频/图片后再发送。")
    audios = []
    images = []
    for file_path in files:
        # Compare extensions case-insensitively so e.g. "PHOTO.JPG" is accepted.
        lowered = file_path.lower()
        if lowered.endswith(AUDIO_EXTENSIONS) or lowered.endswith(VIDEO_EXTENSIONS):
            # NOTE(review): newer librosa releases deprecate ``filename=`` in
            # favour of ``path=``; confirm against the pinned librosa version.
            duration = librosa.get_duration(filename=file_path)
            if duration > 30:
                raise gr.Error("音频时长不能超过30秒。")
            if duration == 0:
                raise gr.Error("音频时长不能为0秒。")
            audios.append(file_path)
        elif lowered.endswith(IMAGE_EXTENSIONS):
            images.append(file_path)
        else:
            filename = file_path.split("/")[-1]
            raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")
    if len(audios) > 1:
        raise gr.Error("Please upload only one audio file.")
    if len(images) > 1:
        raise gr.Error("Please upload only one image file.")
    if audio is not None:
        if audios:
            raise gr.Error("Please upload only one audio file or record audio.")
        audios.append(audio)
    # Each part becomes its own user item; history2messages later merges
    # consecutive user items into a single API message.
    for image_path in images:
        history.append({"role": "user", "content": (image_path,), "metadata": {"title": "image"}})
    for audio_path in audios:
        history.append({"role": "user", "content": (audio_path,), "metadata": {"title": "audio"}})
    # Same whitespace-aware test as above, so a blank caption next to a file
    # is not sent as an empty text turn.
    if has_text:
        history.append({"role": "user", "content": text})
    return history, gr.MultimodalTextbox(value=None, interactive=False), None
def bot(
    history: list,
    top_p: float,
    top_k: int,
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int = MAX_NEW_TOKENS,
    regenerate: bool = False,
):
    """Stream the assistant's reply to ``history`` from the remote chat API.

    This is a generator: each ``yield`` gives Gradio the chat history plus the
    partial assistant reply received so far.

    Args:
        history: Gradio "messages"-format chat history.
        top_p: Nucleus-sampling probability mass sent to the server.
        top_k: Top-k cutoff sent to the server (coerced to ``int``).
        temperature: Sampling temperature sent to the server.
        repetition_penalty: Repetition penalty sent to the server.
        max_new_tokens: Generation length limit (coerced to ``int``).
        regenerate: If True and the last message is an assistant reply, that
            reply is discarded so the preceding user turn is answered again.

    Yields:
        ``history`` followed by one assistant message holding the text
        streamed so far.

    Returns:
        The complete reply text as the generator's return value (Gradio
        ignores it). Returns without yielding when there is nothing to answer.

    Raises:
        gr.Error: If the request fails or times out, the server answers with
            an HTTP error status, or the event stream is malformed.
    """
    import os

    if regenerate and history and history[-1]["role"] == "assistant":
        # Drop only the previous answer. If the last turn never received one
        # (e.g. the earlier request failed), keep the user's message and retry.
        history = history[:-1]
    if not history:
        return history
    msgs = history2messages(history)
    # Overridable through the environment; the default is the original endpoint.
    API_URL = os.environ.get("MEGREZ_API_URL", "http://8.152.0.142:8000/v1/chat")
    payload = {
        "messages": msgs,
        "sampling_params": {
            "top_p": top_p,
            # Slider values may arrive as floats; send integer params as ints.
            "top_k": int(top_k),
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": int(max_new_tokens),
            # NOTE(review): differs from DEFAULT_SAMPLING_PARAMS["num_beams"] (1);
            # confirm which value the server is meant to receive.
            "num_beams": 3,
        },
    }
    response_text = ""
    try:
        # The request shape (GET with a JSON body) is unchanged. The
        # (connect, read) timeout stops an unreachable or stalled server from
        # hanging the UI forever. The `with` block releases the connection
        # even if the consumer stops iterating early.
        with requests.get(
            API_URL,
            json=payload,
            headers={'Accept': 'text/event-stream'},
            stream=True,
            timeout=(10, 300),
        ) as response:
            response.raise_for_status()
            for text in parse_sse_response(response):
                response_text += text
                yield history + [{"role": "assistant", "content": response_text}]
    except requests.RequestException as err:
        raise gr.Error(f"Request to the model server failed: {err}") from err
    return response_text
def change_state(state):
    """Flip the sampling-parameters panel's visibility.

    Args:
        state: Current value of the toggle flag held in a ``gr.State``.

    Returns:
        ``(update, new_state)``, where ``update`` sets ``visible`` to
        ``not state`` and ``new_state`` is ``not state``.
    """
    toggled = not state
    return gr.update(visible=toggled), toggled
def reset_user_input():
    """Return a Gradio update that empties a text component's value."""
    cleared_value = ""
    return gr.update(value=cleared_value)
if __name__ == "__main__":
    # Build the Gradio UI: chat history, multimodal text/file input, an audio
    # input, action buttons, and a toggleable group of sampling sliders.
    # Then wire the events and launch the app.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            f"""
            # 🪐 Chat with <a href="https://github.com/infinigence/Infini-Megrez-Omni">Megrez-3B-Omni</a>
            """
        )
        chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height='48vh')
        # Flag passed to change_state on every toggle click.
        # NOTE(review): despite its name, this starts False while the group is
        # hidden, and change_state sets visible=not state. So True means "visible".
        sampling_params_group_hidden_state = gr.State(False)
        with gr.Row(equal_height=True):
            chat_input = gr.MultimodalTextbox(
                file_count="multiple",
                placeholder="Enter your prompt or upload image/audio here, then press ENTER...",
                show_label=False,
                scale=8,
                file_types=["image", "audio"],
                interactive=True,
                # stop_btn=True,
            )
        with gr.Row(equal_height=True):
            # Record or upload a clip. max_length=30 mirrors the 30-second limit
            # that check_messages enforces on attached audio/video files.
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                scale=1,
                max_length=30
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=150):
                with gr.Row(equal_height=True):
                    regenerate_btn = gr.Button("Regenerate", variant="primary")
                    # Clears the textbox, the audio input and the chat history.
                    clear_btn = gr.ClearButton(
                        [chat_input, audio_input, chatbot],
                    )
                with gr.Row():
                    sampling_params_toggle_btn = gr.Button("Sampling Parameters")
        # Sampling sliders, hidden until the toggle button is clicked. Their
        # initial values come from DEFAULT_SAMPLING_PARAMS / MAX_NEW_TOKENS.
        with gr.Group(visible=False) as sampling_params_group:
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS["temperature"], label="Temperature"
                )
                repetition_penalty = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
                    label="Repetition Penalty",
                )
            with gr.Row():
                top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
                top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")
            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=MAX_NEW_TOKENS,
                    value=MAX_NEW_TOKENS,
                    label="Max New Tokens",
                    interactive=True,
                )
        # Each click flips the group's visibility and stores the flipped flag.
        sampling_params_toggle_btn.click(
            change_state,
            sampling_params_group_hidden_state,
            [sampling_params_group, sampling_params_group_hidden_state],
        )
        # Step 1 on submit: validate the input and append it to the history.
        # check_messages also clears/disables the textbox and clears the audio.
        chat_msg = chat_input.submit(
            check_messages,
            [chatbot, chat_input, audio_input],
            [chatbot, chat_input, audio_input],
        )
        # Step 2: stream the model's reply into the chatbot.
        bot_msg = chat_msg.then(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
            outputs=chatbot,
            api_name="bot_response",
        )
        # Step 3: make the textbox interactive again.
        bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
        # Regenerate: call bot() on the current history with regenerate=True,
        # supplied through a constant gr.State input.
        regenerate_btn.click(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
            outputs=chatbot,
        )
    demo.launch()