# -*- encoding: utf-8 -*-
# File: app.py
# Description: Gradio chat demo for Megrez-3B-Omni. Text, image and audio turns
#              are validated locally and streamed to a remote inference server.
import base64
import io
import json
import os
import subprocess
from copy import deepcopy
from typing import Dict, List

import gradio as gr
import librosa
import requests
from PIL import Image

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac", ".m4a")

DEFAULT_SAMPLING_PARAMS = {
    "top_p": 0.8,
    "top_k": 100,
    "temperature": 0.7,
    "do_sample": True,
    "num_beams": 1,
    "repetition_penalty": 1.2,
}
MAX_NEW_TOKENS = 1024

# Longest accepted audio clip, in seconds (used for upload validation and the recorder).
MAX_AUDIO_SECONDS = 30

# Chat endpoint of the inference server; can be overridden via the environment.
API_URL = os.environ.get("MEGREZ_API_URL", "http://8.152.0.142:8000/v1/chat")

# Seconds to wait while establishing the connection to the inference server.
# No read timeout is set because replies are streamed for an unbounded time.
CONNECT_TIMEOUT = 10


def load_image_to_base64(image_path):
    """Load an image file, convert it to RGB and return it as a base64-encoded PNG string."""
    with Image.open(image_path) as img:
        if img.mode != "RGB":
            img = img.convert("RGB")
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def wav_to_bytes_with_ffmpeg(wav_file_path):
    """Decode a media file to WAV with ffmpeg and return the WAV bytes base64-encoded.

    Despite the name, the input need not be WAV: callers pass any accepted audio
    or video file, and ffmpeg transcodes it to a WAV stream on stdout.

    Raises:
        gr.Error: if ffmpeg exits with a non-zero status (previously this silently
            produced an empty payload).
    """
    result = subprocess.run(
        ["ffmpeg", "-i", wav_file_path, "-f", "wav", "-"],
        # Keep ffmpeg from reading the server's stdin for interactive commands.
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        stderr_lines = result.stderr.decode("utf-8", errors="replace").strip().splitlines()
        reason = stderr_lines[-1] if stderr_lines else "unknown error"
        raise gr.Error(f"Failed to decode audio with ffmpeg: {reason}")
    return base64.b64encode(result.stdout).decode("utf-8")


def parse_sse_response(response):
    """Yield the ``text`` field of every JSON ``data:`` event in a streamed SSE response.

    Iteration stops at the ``[DONE]`` sentinel or when the stream ends. Non-``data``
    lines (comments, ``event:``/``id:`` fields, blank separators) are skipped.

    Raises:
        gr.Error: if an event payload is not valid JSON or has no ``text`` field.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        line = raw_line.decode("utf-8")
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):]
        # Per the SSE format, a single space after the colon is optional.
        if data.startswith(" "):
            data = data[1:]
        if data == "[DONE]":
            break
        try:
            json_data = json.loads(data)
        except json.JSONDecodeError as err:
            raise gr.Error(f"Failed to parse JSON: {data}") from err
        if not isinstance(json_data, dict) or "text" not in json_data:
            raise gr.Error(f"Unexpected response chunk: {data}")
        yield json_data["text"]


def history2messages(history: List[Dict]) -> List[Dict]:
    """Transform Gradio "messages"-style chat history into the server's chat format.

    Consecutive user items are merged into one user message whose ``content``
    dict may hold ``text`` (str), ``image`` (base64 PNG) and ``audio`` (base64 WAV).
    A user item's kind comes from ``metadata["title"]``: no title means text,
    ``"image"``/``"audio"`` mean a one-element tuple holding a file path.
    Assistant items are passed through as deep copies. ``history`` is not modified.
    """
    messages = []
    cur_message = {}
    for item in history:
        if item["role"] == "assistant":
            # An assistant turn closes the user turn being accumulated.
            if cur_message:
                messages.append(cur_message)
                cur_message = {}
            messages.append(deepcopy(item))
            continue

        cur_message.setdefault("role", "user")
        content = cur_message.setdefault("content", {})
        # Read metadata without inserting defaults into the caller's history items.
        title = (item.get("metadata") or {}).get("title")
        if title is None:
            content["text"] = item["content"]
        elif title == "image":
            content["image"] = load_image_to_base64(item["content"][0])
        elif title == "audio":
            content["audio"] = wav_to_bytes_with_ffmpeg(item["content"][0])

    if cur_message:
        messages.append(cur_message)
    return messages


def check_messages(history, message, audio):
    """Validate a submitted multimodal message and append it to the chat history.

    Args:
        history: Gradio "messages"-style chat history; mutated in place and returned.
        message: MultimodalTextbox value, a dict with "text" and "files" (file paths).
        audio: file path produced by the gr.Audio component, or None.

    Returns:
        A tuple ``(history, disabled empty MultimodalTextbox, None)``. The last
        element clears the audio input.

    Raises:
        gr.Error: if nothing was submitted, a file type is unsupported, an audio
            clip is empty or longer than MAX_AUDIO_SECONDS, or more than one image
            or audio file is supplied.
    """
    has_text = bool(message["text"] and message["text"].strip())
    has_files = len(message["files"]) > 0
    has_audio = audio is not None
    if not (has_text or has_files or has_audio):
        raise gr.Error("请输入文字或上传音频/图片后再发送。")

    audios = []
    images = []
    for file_msg in message["files"]:
        # Compare extensions case-insensitively so e.g. "photo.JPG" is accepted.
        lowered = file_msg.lower()
        if lowered.endswith(AUDIO_EXTENSIONS + VIDEO_EXTENSIONS):
            duration = librosa.get_duration(filename=file_msg)
            if duration > MAX_AUDIO_SECONDS:
                raise gr.Error(f"音频时长不能超过{MAX_AUDIO_SECONDS}秒。")
            if duration == 0:
                raise gr.Error("音频时长不能为0秒。")
            audios.append(file_msg)
        elif lowered.endswith(IMAGE_EXTENSIONS):
            images.append(file_msg)
        else:
            filename = os.path.basename(file_msg)
            raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")

    if len(audios) > 1:
        raise gr.Error("Please upload only one audio file.")
    if len(images) > 1:
        raise gr.Error("Please upload only one image file.")
    if has_audio:
        if audios:
            raise gr.Error("Please upload only one audio file or record audio.")
        audios.append(audio)

    # Append the message to the history: media first, then the text prompt.
    for image_path in images:
        history.append({"role": "user", "content": (image_path,), "metadata": {"title": "image"}})
    for audio_path in audios:
        history.append({"role": "user", "content": (audio_path,), "metadata": {"title": "audio"}})
    # Whitespace-only text is not sent as a separate turn.
    if has_text:
        history.append({"role": "user", "content": message["text"]})
    return history, gr.MultimodalTextbox(value=None, interactive=False), None


def bot(
    history: list,
    top_p: float,
    top_k: int,
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int = MAX_NEW_TOKENS,
    regenerate: bool = False,
):
    """Stream the model's reply to the current chat history.

    Yields ``history`` extended with the partial assistant reply after each
    streamed chunk. When ``regenerate`` is True, a trailing assistant reply is
    dropped first so the last user turn is answered again. A trailing user turn
    (e.g. after a failed reply) is kept.

    Raises:
        gr.Error: if the server is unreachable, answers with a non-200 status,
            or streams malformed events.
    """
    if regenerate and history and history[-1]["role"] == "assistant":
        history = history[:-1]
    if not history:
        yield history
        return

    msgs = history2messages(history)
    payload = {
        "messages": msgs,
        "sampling_params": {
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            # NOTE(review): differs from DEFAULT_SAMPLING_PARAMS["num_beams"] (1); confirm intent.
            "num_beams": 3,
        },
    }

    response_text = ""
    try:
        # The context manager releases the connection even if the UI stops consuming early.
        with requests.get(
            API_URL,
            json=payload,
            headers={"Accept": "text/event-stream"},
            stream=True,
            timeout=(CONNECT_TIMEOUT, None),
        ) as response:
            if response.status_code != 200:
                raise gr.Error(f"Model server returned HTTP {response.status_code}.")
            for text in parse_sse_response(response):
                response_text += text
                yield history + [{"role": "assistant", "content": response_text}]
    except requests.RequestException as err:
        raise gr.Error(f"Failed to reach the model server: {err}") from err
    return response_text


def change_state(state):
    """Toggle a component's visibility; return the update and the new hidden-state flag."""
    return gr.update(visible=not state), not state


def reset_user_input():
    """Return an update that clears a text input."""
    return gr.update(value="")


if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🪐 Chat with Megrez-3B-Omni")
        chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height="48vh")

        sampling_params_group_hidden_state = gr.State(False)

        with gr.Row(equal_height=True):
            chat_input = gr.MultimodalTextbox(
                file_count="multiple",
                placeholder="Enter your prompt or upload image/audio here, then press ENTER...",
                show_label=False,
                scale=8,
                file_types=["image", "audio"],
                interactive=True,
            )
        with gr.Row(equal_height=True):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                scale=1,
                max_length=MAX_AUDIO_SECONDS,
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=150):
                with gr.Row(equal_height=True):
                    regenerate_btn = gr.Button("Regenerate", variant="primary")
                    clear_btn = gr.ClearButton(
                        [chat_input, audio_input, chatbot],
                    )

        with gr.Row():
            sampling_params_toggle_btn = gr.Button("Sampling Parameters")

        with gr.Group(visible=False) as sampling_params_group:
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0,
                    maximum=1.2,
                    value=DEFAULT_SAMPLING_PARAMS["temperature"],
                    label="Temperature",
                )
                repetition_penalty = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
                    label="Repetition Penalty",
                )

            with gr.Row():
                top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
                top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")

            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=MAX_NEW_TOKENS,
                    value=MAX_NEW_TOKENS,
                    label="Max New Tokens",
                    interactive=True,
                )

        sampling_params_toggle_btn.click(
            change_state,
            sampling_params_group_hidden_state,
            [sampling_params_group, sampling_params_group_hidden_state],
        )

        # Validate and append the user's turn, then stream the reply, then re-enable input.
        chat_msg = chat_input.submit(
            check_messages,
            [chatbot, chat_input, audio_input],
            [chatbot, chat_input, audio_input],
        )

        bot_msg = chat_msg.then(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
            outputs=chatbot,
            api_name="bot_response",
        )

        bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

        regenerate_btn.click(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
            outputs=chatbot,
        )

    demo.launch()