speaker / app.py
QLWD's picture
Update app.py
9e4c424 verified
raw
history blame
4.62 kB
import gradio as gr
import os
from pydub import AudioSegment
from pyannote.audio.pipelines import SpeakerDiarization
import torch
# 初始化 pyannote/speaker-diarization 模型
HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
pipeline = None
try:
pipeline = SpeakerDiarization.from_pretrained(
"pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline.to(device)
except Exception as e:
print(f"Error initializing pipeline: {e}")
pipeline = None
# 音频处理函数:拼接目标音频和混合音频
def combine_audio_with_time(target_audio, mixed_audio):
if pipeline is None:
return "错误: 模型未初始化"
# 加载目标说话人的样本音频
target_audio_segment = AudioSegment.from_wav(target_audio.name)
# 加载混合音频
mixed_audio_segment = AudioSegment.from_wav(mixed_audio.name)
# 记录目标说话人音频的时间点(精确到0.01秒)
target_start_time = len(mixed_audio_segment) / 1000 # 秒为单位,精确到 0.01 秒
# 将目标说话人的音频片段添加到混合音频的最后
final_audio = mixed_audio_segment + target_audio_segment
# 保存拼接后的音频并返回时间点
final_audio.export("final_output.wav", format="wav")
return "final_output.wav", target_start_time
# 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离
@spaces.GPU(duration=60 * 2)
def diarize_audio(temp_file):
if pipeline is None:
return "错误: 模型未初始化"
try:
diarization = pipeline(temp_file)
except Exception as e:
return f"处理音频时出错: {e}"
# 返回 diarization 输出
return str(diarization)
# 处理并生成标签文件
def generate_labels_from_diarization(diarization_output):
labels_path = 'labels.txt'
successful_lines = 0
try:
with open(labels_path, 'w') as outfile:
lines = diarization_output.strip().split('\n')
for line in lines:
try:
parts = line.strip()[1:-1].split(' --> ')
start_time = parts[0].strip()
end_time = parts[1].split(']')[0].strip()
label = line.split()[-1].strip()
start_seconds = timestamp_to_seconds(start_time)
end_seconds = timestamp_to_seconds(end_time)
outfile.write(f"{start_seconds}\t{end_seconds}\t{label}\n")
successful_lines += 1
except Exception as e:
print(f"处理行时出错: '{line.strip()}'. 错误: {e}")
print(f"成功处理了 {successful_lines} 行。")
return labels_path if successful_lines > 0 else None
except Exception as e:
print(f"写入文件时出错: {e}")
return None
# 将时间戳转换为秒
def timestamp_to_seconds(timestamp):
try:
h, m, s = map(float, timestamp.split(':'))
return 3600 * h + 60 * m + s
except ValueError as e:
print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
return None
@spaces.GPU(duration=60 * 2)
# 处理音频文件
def process_audio(audio):
diarization_result = diarize_audio(save_audio(audio))
if diarization_result.startswith("错误"):
return diarization_result, None # 如果出错,返回错误信息和空的标签文件
else:
label_file = generate_labels_from_diarization(diarization_result)
return diarization_result, label_file
# 保存上传的音频
def save_audio(audio):
with open(audio.name, "rb") as f:
audio_data = f.read()
# 保存上传的音频文件到临时位置
with open("temp.wav", "wb") as f:
f.write(audio_data)
return "temp.wav"
# Gradio 接口
with gr.Blocks() as demo:
gr.Markdown("""
# 🗣️ 音频拼接与说话人分类 🗣️
上传目标说话人音频和混合音频,拼接并进行说话人分类。
""")
audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
process_button = gr.Button("处理音频")
diarization_output = gr.Textbox(label="说话人分离结果")
label_file_link = gr.File(label="下载标签文件")
# 处理音频
process_button.click(
fn=process_audio,
inputs=[audio_input],
outputs=[diarization_output, label_file_link]
)
demo.launch(share=False)