import copy
import os
import re
import secrets
import tempfile
from pathlib import Path

import gradio as gr
import torch
from pydub import AudioSegment
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the model and tokenizer
torch.manual_seed(420)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-Audio-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-Audio-Chat", device_map="cuda", trust_remote_code=True
).eval()

# Directory for the temporary audio clips cut out of the input file;
# honors GRADIO_TEMP_DIR when it is set
uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
    Path(tempfile.gettempdir()) / "gradio"
)


def _parse_text(text):
    """Escape model output into HTML that renders safely in the Gradio chatbot."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # Opening fence: start an HTML code block, keeping the language tag
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape characters that markdown
                    # or the browser would otherwise interpret
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def predict(_chatbot, task_history):
    """Run the model on the latest query and append its answer to the chat."""
    if not task_history:
        return _chatbot
    query = task_history[-1][0]
    history_cp = copy.deepcopy(task_history)
    history_filter = []
    audio_idx = 1
    pre = ""
    last_audio = None
    # Rewrite audio turns into the "Audio N: <audio>path</audio>" format
    # that Qwen-Audio-Chat expects in its prompt
    for i, (q, a) in enumerate(history_cp):
        if isinstance(q, (tuple, list)):
            last_audio = q[0]
            q = f'Audio {audio_idx}: <audio>{q[0]}</audio>'
            pre += q + '\n'
            audio_idx += 1
        else:
            pre += q
            history_filter.append((pre, a))
            pre = ""
    history, message = history_filter[:-1], history_filter[-1][0]
    response, history = model.chat(tokenizer, message, history=history)
    # The model marks grounded audio spans with timestamp tokens like <|12.3|>;
    # a matched (even) number of them delimits (start, end) pairs in seconds
    ts_pattern = r"<\|\d{1,2}\.\d+\|>"
    all_time_stamps = re.findall(ts_pattern, response)
    if (len(all_time_stamps) > 0) and (len(all_time_stamps) % 2 == 0) and last_audio:
        ts_float = [float(t.replace("<|", "").replace("|>", "")) for t in all_time_stamps]
        ts_float_pair = [ts_float[i:i + 2] for i in range(0, len(all_time_stamps), 2)]
        # Read the audio file
        format = os.path.splitext(last_audio)[-1].replace(".", "")
        audio_file = AudioSegment.from_file(last_audio, format=format)
        chat_response = response.replace("<|", "").replace("|>", "")
        temp_dir = secrets.token_hex(20)
        temp_dir = Path(uploaded_file_dir) / temp_dir
        temp_dir.mkdir(exist_ok=True, parents=True)
        # Cut each referenced span out of the audio and attach it to the chat
        for pair in ts_float_pair:
            audio_clip = audio_file[pair[0] * 1000: pair[1] * 1000]
            # Save the audio clip
            name = f"tmp{secrets.token_hex(5)}.{format}"
            filename = temp_dir / name
            audio_clip.export(filename, format=format)
            _chatbot[-1] = (_parse_text(query), chat_response)
            _chatbot.append((None, (str(filename),)))
    else:
        _chatbot[-1] = (_parse_text(query), response)
    full_response = _parse_text(response)
    task_history[-1] = (query, full_response)
    print("Qwen-Audio-Chat: " + _parse_text(full_response))
    return _chatbot


def regenerate(_chatbot, task_history):
    """Discard the last answer and generate a new one for the same query."""
    if not task_history:
        return _chatbot
    item = task_history[-1]
    if item[1] is None:
        return _chatbot
    task_history[-1] = (item[0], None)
    chatbot_item = _chatbot.pop(-1)
    if chatbot_item[0] is None:
        _chatbot[-1] = (_chatbot[-1][0], None)
    else:
        _chatbot.append((chatbot_item[0], None))
    return predict(_chatbot, task_history)


def add_text(history, task_history, text):
    history = history + [(_parse_text(text), None)]
    task_history = task_history + [(text, None)]
    return history, task_history, ""


def add_file(history, task_history, file):
    history = history + [((file.name,), None)]
    task_history = task_history + [((file.name,), None)]
    return history, task_history


def add_mic(history, task_history, file):
    if file is None:
        return history, task_history
    # Gradio records to a temp file without an extension; give it one so
    # downstream decoding can pick the right format
    os.rename(file, file + '.wav')
    history = history + [((file + '.wav',), None)]
    task_history = task_history + [((file + '.wav',), None)]
    return history, task_history


def reset_user_input():
    return gr.update(value="")


def reset_state(task_history):
    task_history.clear()
    return []


# predict() works on (chatbot, task_history) state, which the single-shot
# gr.Interface API cannot provide, so the handlers are wired explicitly
# through Blocks events.
with gr.Blocks(title="Audio-Text Interaction Model") as demo:
    gr.Markdown(
        "This model can process an audio input along with a text query "
        "and provide a response."
    )
    chatbot = gr.Chatbot(label="Qwen-Audio-Chat")
    query = gr.Textbox(lines=2, label="Text Query")
    task_history = gr.State([])
    mic = gr.Audio(source="microphone", type="filepath", label="Audio Input")

    with gr.Row():
        submit_btn = gr.Button("Submit")
        regen_btn = gr.Button("Regenerate")
        addfile_btn = gr.UploadButton("Upload Audio", file_types=["audio"])
        empty_btn = gr.Button("Clear History")

    mic.change(add_mic, [chatbot, task_history, mic], [chatbot, task_history])
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history])
    # add_text's third return value clears the textbox after submission
    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history, query]).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    )
    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
    empty_btn.click(reset_state, [task_history], [chatbot], show_progress=True)
    empty_btn.click(reset_user_input, [], [query])

demo.launch()