import torch import spaces import gradio as gr import os from pyannote.audio import Pipeline from pyannote.core import Annotation, Segment from pydub import AudioSegment # 获取 Hugging Face 认证令牌 HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN") class AudioProcessor: def __init__(self): self.pipeline = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 尝试加载 pyannote 模型 try: self.pipeline = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN ) self.pipeline.to(self.device) print("pyannote model loaded successfully.") except Exception as e: print(f"Error initializing pipeline: {e}") self.pipeline = None # 音频拼接函数:拼接目标音频和混合音频,返回目标音频的起始时间和结束时间作为字典 def combine_audio_with_time(self, target_audio, mixed_audio): if self.pipeline is None: return "错误: 模型未初始化" # 加载目标说话人的样本音频 try: target_audio_segment = AudioSegment.from_wav(target_audio) except Exception as e: return f"加载目标音频时出错: {e}" # 加载混合音频 try: mixed_audio_segment = AudioSegment.from_wav(mixed_audio) except Exception as e: return f"加载混合音频时出错: {e}" # 记录目标说话人音频的时间点(精确到0.01秒) target_start_time = len(mixed_audio_segment) / 1000 # 秒为单位,精确到 0.01 秒 # 目标音频的结束时间(拼接后的音频长度) target_end_time = target_start_time + len(target_audio_segment) / 1000 # 秒为单位 # 将目标说话人的音频片段添加到混合音频的最后 final_audio = mixed_audio_segment + target_audio_segment final_audio.export("final_output.wav", format="wav") # 返回目标音频的起始时间和结束时间 return {"start_time": target_start_time, "end_time": target_end_time} # 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离 @spaces.GPU(duration=60 * 2) # 使用 GPU 加速,限制执行时间为 120 秒 def diarize_audio(self, temp_file): if self.pipeline is None: return "错误: 模型未初始化" try: diarization = self.pipeline(temp_file) # 返回 Annotation 对象 except Exception as e: return f"处理音频时出错: {e}" return diarization # 直接返回 Annotation 对象 # 将时间戳转换为秒 def timestamp_to_seconds(self, timestamp): try: h, m, s = map(float, timestamp.split(':')) return 3600 * h + 60 * m + s except ValueError as e: print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}") return None # 计算时间段的重叠部分(单位:秒) def calculate_overlap(self, start1, end1, start2, end2): overlap_start = max(start1, start2) overlap_end = min(end1, end2) overlap_duration = max(0, overlap_end - overlap_start) return overlap_duration # 获取目标时间段和说话人时间段的重叠比例 def get_best_match(self, target_time, diarization_output): target_start_time = target_time['start_time'] target_end_time = target_time['end_time'] # 用于存储每个说话人时间段的重叠比例 speaker_segments = [] for segment, label in diarization_output.itertracks(yield_label=True): try: start_seconds = segment.start end_seconds = segment.end # 计算目标音频时间段和说话人时间段的重叠时间 overlap = self.calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds) overlap_ratio = overlap / (end_seconds - start_seconds) # 记录说话人标签和重叠比例 speaker_segments.append((label, overlap_ratio, start_seconds, end_seconds)) except Exception as e: print(f"处理行时出错: '{segment}'. 错误: {e}") # 按照重叠比例排序,返回重叠比例最大的一段 best_match = max(speaker_segments, key=lambda x: x[1], default=None) return best_match # 获取该说话人除了目标语音时间段外的所有时间段 def get_speaker_time_segments(self, diarization_output, target_time, speaker_label): remaining_segments = [] # 遍历 diarization 输出,查找该说话人的所有时间段 for segment, label in diarization_output.itertracks(yield_label=True): if label == speaker_label: start_seconds = segment.start end_seconds = segment.end # 计算与目标音频的重叠部分 overlap_start = max(start_seconds, target_time['start_time']) overlap_end = min(end_seconds, target_time['end_time']) # 如果有重叠部分,排除重叠部分 if overlap_start < overlap_end: if start_seconds < overlap_start: remaining_segments.append((start_seconds, overlap_start)) if overlap_end < end_seconds: remaining_segments.append((overlap_end, end_seconds)) else: remaining_segments.append((start_seconds, end_seconds)) return remaining_segments # 处理音频文件并返回输出 def process_audio(self, target_audio, mixed_audio): # 进行音频拼接并返回目标音频的起始和结束时间(作为字典) time_dict = self.combine_audio_with_time(target_audio, mixed_audio) # 执行说话人分离 diarization_result = self.diarize_audio("final_output.wav") if isinstance(diarization_result, str) and diarization_result.startswith("错误"): return diarization_result, None # 出错时返回错误信息 else: # 获取最佳匹配的说话人标签和时间段 best_match = self.get_best_match(time_dict, diarization_result) if best_match: speaker_label = best_match[0] # 取出最佳匹配说话人的标签 # 获取该说话人除了目标语音时间段外的所有时间段 remaining_segments = self.get_speaker_time_segments(diarization_result, time_dict, speaker_label) return speaker_label, remaining_segments # Gradio 接口 with gr.Blocks() as demo: gr.Markdown(""" # 🗣️ 音频拼接与说话人分类 🗣️ 上传目标音频和混合音频,拼接并进行说话人分类。结果包括最佳匹配说话人的时间段(排除目标音频时间段)。 """) mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频") target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频") process_button = gr.Button("处理音频") # 输出结果 diarization_output = gr.Textbox(label="最佳匹配说话人") time_range_output = gr.Textbox(label="最佳匹配时间段") # 点击按钮时触发处理音频 process_button.click( fn=AudioProcessor().process_audio, inputs=[target_audio_input, mixed_audio_input], outputs=[diarization_output, time_range_output] ) demo.launch(share=True)