|
import torch |
|
import spaces |
|
import gradio as gr |
|
import os |
|
from pyannote.audio import Pipeline |
|
from pyannote.core import Annotation, Segment |
|
from pydub import AudioSegment |
|
|
|
|
|
# Hugging Face access token handed to Pipeline.from_pretrained below;
# None when the HUGGINGFACE_READ_TOKEN environment variable is not set.
HF_TOKEN = os.getenv("HUGGINGFACE_READ_TOKEN")
|
|
|
class AudioProcessor:
    """Append a target-speaker clip to a mixed recording, diarize the
    concatenation with pyannote, and report which diarized speaker matches
    the target clip.

    Methods that can fail return a user-facing (Chinese) error string instead
    of raising, because their results are shown directly in the Gradio UI.
    """

    # Concatenated audio written by combine_audio_with_time() and read back
    # by process_audio(); relative to the current working directory.
    OUTPUT_PATH = "final_output.wav"

    def __init__(self):
        """Load the pyannote diarization pipeline onto CUDA if available, else CPU.

        Any failure is printed and leaves ``self.pipeline`` as ``None``; the
        processing methods check for that and return an error string.
        """
        self.pipeline = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # HF_TOKEN is the module-level token read from the environment.
            self.pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
            )
            # If from_pretrained returned None, .to() raises AttributeError,
            # which is handled below.
            self.pipeline.to(self.device)
            print("pyannote model loaded successfully.")
        except Exception as e:
            # Best-effort: keep the app up so the UI can report the problem.
            print(f"Error initializing pipeline: {e}")
            self.pipeline = None

    def combine_audio_with_time(self, target_audio, mixed_audio, output_path=None):
        """Append ``target_audio`` after ``mixed_audio`` and export the result as WAV.

        Args:
            target_audio: path to the target-speaker audio file.
            mixed_audio: path to the mixed (multi-speaker) audio file.
            output_path: destination of the concatenated WAV; defaults to
                ``self.OUTPUT_PATH``.

        Returns:
            ``{"start_time": s, "end_time": e}``, the position of the target
            clip inside the concatenated file in seconds, or an error string
            if the model is not loaded or either file cannot be decoded.
        """
        if self.pipeline is None:
            return "错误: 模型未初始化"

        # from_file handles .wav exactly like from_wav and also accepts other
        # formats pydub can decode.
        try:
            target_audio_segment = AudioSegment.from_file(target_audio)
        except Exception as e:
            return f"加载目标音频时出错: {e}"

        try:
            mixed_audio_segment = AudioSegment.from_file(mixed_audio)
        except Exception as e:
            return f"加载混合音频时出错: {e}"

        # len() of an AudioSegment is its duration in milliseconds; the target
        # clip starts exactly where the mixed recording ends.
        target_start_time = len(mixed_audio_segment) / 1000
        target_end_time = target_start_time + len(target_audio_segment) / 1000

        final_audio = mixed_audio_segment + target_audio_segment
        final_audio.export(output_path or self.OUTPUT_PATH, format="wav")

        return {"start_time": target_start_time, "end_time": target_end_time}

    @spaces.GPU(duration=60 * 2)
    def diarize_audio(self, temp_file):
        """Run the diarization pipeline on the audio file ``temp_file``.

        The ``spaces.GPU`` decorator presumably requests a Hugging Face ZeroGPU
        slot with a 120 s budget per call (confirm against the ``spaces`` docs).

        Returns:
            The pipeline output (expected to be a pyannote ``Annotation``) or an
            error string if the model is not loaded or the pipeline raised.
        """
        if self.pipeline is None:
            return "错误: 模型未初始化"

        try:
            diarization = self.pipeline(temp_file)
        except Exception as e:
            return f"处理音频时出错: {e}"

        return diarization

    def timestamp_to_seconds(self, timestamp):
        """Convert an ``"H:M:S"`` string (components may be fractional) to seconds.

        Returns ``None`` and prints a message if the string does not have exactly
        three numeric colon-separated parts.
        """
        try:
            h, m, s = map(float, timestamp.split(':'))
            return 3600 * h + 60 * m + s
        except ValueError as e:
            print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
            return None

    def calculate_overlap(self, start1, end1, start2, end2):
        """Return the length of the intersection of [start1, end1] and [start2, end2] (0 if disjoint)."""
        overlap_start = max(start1, start2)
        overlap_end = min(end1, end2)
        return max(0, overlap_end - overlap_start)

    def get_best_match(self, target_time, diarization_output):
        """Find the diarized segment that best matches the target-clip window.

        Args:
            target_time: dict with ``"start_time"`` / ``"end_time"`` in seconds,
                as returned by combine_audio_with_time().
            diarization_output: a pyannote ``Annotation``.

        Returns:
            ``(label, overlap_ratio, start, end)`` for the segment with the
            largest fraction of its own duration inside the window. Ties are
            broken by the larger absolute overlap. Returns ``None`` if
            ``diarization_output`` is not an Annotation or no segment overlaps
            the window.
        """
        target_start_time = target_time['start_time']
        target_end_time = target_time['end_time']

        if not isinstance(diarization_output, Annotation):
            print(f"Error: Expected an Annotation object, but got {type(diarization_output)}")
            return None

        speaker_segments = []
        # With yield_label=True, itertracks yields (segment, track, label).
        for segment, _track, label in diarization_output.itertracks(yield_label=True):
            start_seconds = segment.start
            end_seconds = segment.end
            duration = end_seconds - start_seconds
            if duration <= 0:
                # Degenerate segment: the ratio would divide by zero.
                continue

            overlap = self.calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
            if overlap <= 0:
                # A segment outside the window says nothing about its speaker.
                continue

            speaker_segments.append((label, overlap / duration, start_seconds, end_seconds))

        # Every segment fully inside the window has ratio 1.0, so rank equal
        # ratios by absolute overlap (ratio * duration) instead of by order.
        return max(
            speaker_segments,
            key=lambda seg: (seg[1], seg[1] * (seg[3] - seg[2])),
            default=None,
        )

    def get_speaker_time_segments(self, diarization_output, target_time, speaker_label):
        """Return the segments of ``speaker_label`` with the target window cut out.

        A segment that does not intersect the window is kept whole. One that
        does is reduced to its parts before and/or after the window.

        Returns:
            List of ``(start, end)`` tuples in seconds, in the order
            ``itertracks`` yields them.
        """
        window_start = target_time['start_time']
        window_end = target_time['end_time']
        remaining_segments = []

        # With yield_label=True, itertracks yields (segment, track, label).
        for segment, _track, label in diarization_output.itertracks(yield_label=True):
            if label != speaker_label:
                continue

            start_seconds = segment.start
            end_seconds = segment.end
            overlap_start = max(start_seconds, window_start)
            overlap_end = min(end_seconds, window_end)

            if overlap_start < overlap_end:
                if start_seconds < overlap_start:
                    remaining_segments.append((start_seconds, overlap_start))
                if overlap_end < end_seconds:
                    remaining_segments.append((overlap_end, end_seconds))
            else:
                remaining_segments.append((start_seconds, end_seconds))

        return remaining_segments

    def process_audio(self, target_audio, mixed_audio):
        """Gradio handler: concatenate, diarize, and match the target speaker.

        Returns:
            ``(speaker_label, remaining_segments)`` on success, or
            ``(error_message, None)`` on any failure. It always returns two
            values because the UI binds two outputs.
        """
        time_dict = self.combine_audio_with_time(target_audio, mixed_audio)
        if isinstance(time_dict, str):
            return time_dict, None

        diarization_result = self.diarize_audio(self.OUTPUT_PATH)
        if isinstance(diarization_result, str):
            return diarization_result, None

        best_match = self.get_best_match(time_dict, diarization_result)
        if best_match is None:
            return "错误: 未找到匹配的说话人", None

        speaker_label = best_match[0]
        remaining_segments = self.get_speaker_time_segments(diarization_result, time_dict, speaker_label)
        return speaker_label, remaining_segments
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🗣️ 音频拼接与说话人分类 🗣️
    上传目标音频和混合音频,拼接并进行说话人分类。结果包括最佳匹配说话人的时间段(排除目标音频时间段)。
    """)

    # Uploads: the multi-speaker recording is rendered above the target clip.
    mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
    target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
    process_button = gr.Button("处理音频")

    # Results: matched speaker label and that speaker's remaining segments.
    diarization_output = gr.Textbox(label="最佳匹配说话人")
    time_range_output = gr.Textbox(label="最佳匹配时间段")

    # process_audio(target_audio, mixed_audio) expects the target first,
    # so the handler inputs are listed in that order, not in render order.
    click_inputs = [target_audio_input, mixed_audio_input]
    click_outputs = [diarization_output, time_range_output]
    audio_processor = AudioProcessor()
    process_button.click(
        inputs=click_inputs,
        outputs=click_outputs,
        fn=audio_processor.process_audio,
    )

demo.launch(share=True)
|
|