speaker / app.py
QLWD's picture
Update app.py
cb823a5 verified
raw
history blame
7.01 kB
import torch
import spaces
import gradio as gr
import os
from pyannote.audio import Pipeline
from pydub import AudioSegment
# 获取 Hugging Face 认证令牌
HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
pipeline = None
# 尝试加载 pyannote 模型
try:
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline.to(device)
except Exception as e:
print(f"Error initializing pipeline: {e}")
pipeline = None
# 音频拼接函数:拼接目标音频和混合音频,返回目标音频的起始时间和结束时间作为字典
def combine_audio_with_time(target_audio, mixed_audio):
if pipeline is None:
return "错误: 模型未初始化"
# 打印文件路径,确保文件正确传递
print(f"目标音频文件路径: {target_audio}")
print(f"混合音频文件路径: {mixed_audio}")
# 加载目标说话人的样本音频
try:
target_audio_segment = AudioSegment.from_wav(target_audio)
except Exception as e:
return f"加载目标音频时出错: {e}"
# 加载混合音频
try:
mixed_audio_segment = AudioSegment.from_wav(mixed_audio)
except Exception as e:
return f"加载混合音频时出错: {e}"
# 记录目标说话人音频的时间点(精确到0.01秒)
target_start_time = len(mixed_audio_segment) / 1000 # 秒为单位,精确到 0.01 秒
# 目标音频的结束时间(拼接后的音频长度)
target_end_time = target_start_time + len(target_audio_segment) / 1000 # 秒为单位
# 将目标说话人的音频片段添加到混合音频的最后
final_audio = mixed_audio_segment + target_audio_segment
final_audio.export("final_output.wav", format="wav")
# 返回目标音频的起始时间和结束时间
return {"start_time": target_start_time, "end_time": target_end_time}
# 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离
@spaces.GPU(duration=60 * 2) # 使用 GPU 加速,限制执行时间为 120 秒
def diarize_audio(temp_file):
if pipeline is None:
return "错误: 模型未初始化"
try:
diarization = pipeline(temp_file)
except Exception as e:
return f"处理音频时出错: {e}"
# 返回 diarization 类对象
return diarization
# 将时间戳转换为秒
def timestamp_to_seconds(timestamp):
try:
h, m, s = map(float, timestamp.split(':'))
return 3600 * h + 60 * m + s
except ValueError as e:
print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
return None
# 计算时间段的重叠部分(单位:秒)
def calculate_overlap(start1, end1, start2, end2):
overlap_start = max(start1, start2)
overlap_end = min(end1, end2)
overlap_duration = max(0, overlap_end - overlap_start)
return overlap_duration
# 获取所有说话人时间段,排除目标音频时间段
def get_matching_segments(target_time, diarization_output):
target_start_time = target_time['start_time']
target_end_time = target_time['end_time']
# 记录每个说话人与目标音频的重叠时间
speaker_overlap = {}
for speech_turn in diarization_output.itertracks(yield_label=True):
start_seconds = speech_turn[0].start
end_seconds = speech_turn[0].end
label = speech_turn[1]
# 计算目标音频时间段与该说话人时间段的重叠时间
overlap = calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
if overlap > 0:
if label not in speaker_overlap:
speaker_overlap[label] = 0
speaker_overlap[label] += overlap
# 找到与目标音频时间段重叠最多的说话人
max_overlap_speaker = max(speaker_overlap, key=speaker_overlap.get, default=None)
if max_overlap_speaker is None:
return "没有找到匹配的说话人"
# 获取该说话人的所有时间段,排除目标音频的时间段
speaker_segments = []
for speech_turn in diarization_output.itertracks(yield_label=True):
start_seconds = speech_turn[0].start
end_seconds = speech_turn[0].end
label = speech_turn[1]
if label == max_overlap_speaker:
# 计算目标音频时间段与该说话人时间段的重叠时间
overlap = calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
if overlap == 0: # 如果没有重叠,则保留该时间段
speaker_segments.append((start_seconds, end_seconds))
# 转换时间段为更易读的格式(例如:00:00:03.895 --> 00:00:04.367)
formatted_segments = [
f"{format_time(segment[0])} --> {format_time(segment[1])}" for segment in speaker_segments
]
return formatted_segments
# 格式化时间(秒 -> hh:mm:ss.xxx)
def format_time(seconds):
mins, secs = divmod(seconds, 60)
hrs, mins = divmod(mins, 60)
return f"{int(hrs):02}:{int(mins):02}:{secs:06.3f}"
# 处理音频文件并返回输出
def process_audio(target_audio, mixed_audio):
# 打印文件路径,确保传入的文件有效
print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")
# 进行音频拼接并返回目标音频的起始和结束时间(作为字典)
time_dict = combine_audio_with_time(target_audio, mixed_audio)
# 执行说话人分离
diarization_result = diarize_audio("final_output.wav")
if isinstance(diarization_result, str) and diarization_result.startswith("错误"):
return diarization_result, None # 出错时返回错误信息
else:
# 获取重叠最多的说话人的所有匹配时间段
matching_segments = get_matching_segments(time_dict, diarization_result)
if matching_segments:
# 返回匹配的时间段
return "\n".join(matching_segments)
else:
return "没有找到匹配的说话人时间段。"
# Gradio 接口
with gr.Blocks() as demo:
gr.Markdown("""
# 🗣️ 音频拼接与说话人分类 🗣️
上传目标音频和混合音频,拼接并进行说话人分类。结果包括匹配重叠最多的说话人的时间段。
""")
mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
process_button = gr.Button("处理音频")
# 输出结果
diarization_output = gr.Textbox(label="匹配的说话人时间段")
# 点击按钮时触发处理音频
process_button.click(
fn=process_audio,
inputs=[target_audio_input, mixed_audio_input],
outputs=[diarization_output]
)
demo.launch(share=True)