|
import os
import re

import torch
import spaces
import gradio as gr
from pyannote.audio import Pipeline
from pydub import AudioSegment
|
|
|
|
|
# Hugging Face access token read from the environment and passed to
# Pipeline.from_pretrained for the pyannote checkpoint download.
HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")

# Global diarization pipeline. It stays None when loading fails so that the
# request handlers can return a friendly error instead of crashing.
pipeline = None

try:
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
    )
    if pipeline is None:
        # from_pretrained may return None instead of raising when the
        # checkpoint cannot be fetched (e.g. missing/invalid token for a
        # gated repo); report that explicitly rather than failing on `.to()`.
        print(
            "Error initializing pipeline: could not load "
            "'pyannote/speaker-diarization-3.1' (check HUGGINGFACE_READ_TOKEN)"
        )
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)
except Exception as e:
    # Top-level boundary: any load/device failure leaves the app running
    # with pipeline = None so the UI can report the problem.
    print(f"Error initializing pipeline: {e}")
    pipeline = None
|
|
|
|
|
def combine_audio_with_time(target_audio, mixed_audio, output_path="final_output.wav"):
    """Append the target-speaker clip to the end of the mixed clip and save it.

    The concatenated audio is exported as WAV to ``output_path``; by default
    this is ``final_output.wav``, the file that ``process_audio`` diarizes.

    Args:
        target_audio: Path to the target speaker's audio file.
        mixed_audio: Path to the mixed (multi-speaker) audio file.
        output_path: Destination of the concatenated WAV file.

    Returns:
        ``{"start_time": float, "end_time": float}`` giving the position of
        the target clip inside the combined audio, in seconds, or the error
        string ``"错误: 模型未初始化"`` when the diarization model is not loaded.
    """
    if pipeline is None:
        return "错误: 模型未初始化"

    # from_file reads WAV directly and other formats via ffmpeg, so uploads
    # are not restricted to .wav files.
    target_audio_segment = AudioSegment.from_file(target_audio)
    mixed_audio_segment = AudioSegment.from_file(mixed_audio)

    # len() of a pydub segment is in milliseconds; the target clip starts
    # exactly where the mixed clip ends.
    target_start_time = len(mixed_audio_segment) / 1000
    target_end_time = target_start_time + len(target_audio_segment) / 1000

    # pydub segments are immutable: `+` returns a new segment, which must be
    # kept and written out for the downstream diarization step to see it.
    combined_audio_segment = mixed_audio_segment + target_audio_segment
    combined_audio_segment.export(output_path, format="wav")

    return {"start_time": target_start_time, "end_time": target_end_time}
|
|
|
|
|
@spaces.GPU(duration=120)
def diarize_audio(temp_file):
    """Run the global pyannote pipeline on an audio file.

    Wrapped in ``spaces.GPU`` with a 120-second duration.

    Args:
        temp_file: Path of the audio file to diarize.

    Returns:
        ``str()`` of the pipeline's result on success. Otherwise an error
        message: one starting with "错误" when the model is not loaded, or one
        starting with "处理音频时出错" carrying the exception raised while
        processing.
    """
    if pipeline is not None:
        try:
            annotation = pipeline(temp_file)
        except Exception as err:
            return f"处理音频时出错: {err}"
        return str(annotation)
    return "错误: 模型未初始化"
|
|
|
|
|
# Matches the "[ HH:MM:SS.mmm -->  HH:MM:SS.mmm]" segment prefix of a
# diarization line and captures the start and end timestamps.
_SEGMENT_PATTERN = re.compile(r"\[\s*([-\d:.]+)\s*-->\s*([-\d:.]+)\s*\]")


def generate_labels_from_diarization(diarization_output, labels_path='labels.txt'):
    """Convert textual diarization output into a tab-separated label file.

    Each non-blank line of ``diarization_output`` is expected to look like
    ``[ 00:00:00.497 -->  00:00:01.229] A SPEAKER_00``. The last
    whitespace-separated token after the closing bracket is used as the
    label. Each parsed line is written as ``start<TAB>end<TAB>label`` with
    times in seconds.

    Lines that do not match, or whose timestamps cannot be converted, are
    reported with ``print`` and skipped (they are not written or counted).

    Args:
        diarization_output: Multi-line diarization text.
        labels_path: Destination file; defaults to ``labels.txt`` in the
            current working directory.

    Returns:
        ``labels_path`` if at least one line was written, otherwise ``None``.
        Also ``None`` if an unexpected error occurs while writing.
    """
    successful_lines = 0

    try:
        with open(labels_path, 'w', encoding='utf-8') as outfile:
            for raw_line in diarization_output.strip().split('\n'):
                line = raw_line.strip()
                if not line:
                    continue

                match = _SEGMENT_PATTERN.search(line)
                label_tokens = line[match.end():].split() if match else []
                if not label_tokens:
                    print(f"处理行时出错: '{line}'. 错误: 无法解析该行")
                    continue

                start_seconds = timestamp_to_seconds(match.group(1))
                end_seconds = timestamp_to_seconds(match.group(2))
                if start_seconds is None or end_seconds is None:
                    # timestamp_to_seconds already printed the reason; never
                    # write "None" into the label file.
                    continue

                outfile.write(f"{start_seconds}\t{end_seconds}\t{label_tokens[-1]}\n")
                successful_lines += 1

        print(f"成功处理了 {successful_lines} 行。")
        return labels_path if successful_lines > 0 else None

    except Exception as e:
        print(f"写入文件时出错: {e}")
        return None


def timestamp_to_seconds(timestamp):
    """Convert an ``HH:MM:SS(.fff)`` timestamp string to seconds (float).

    Returns ``None`` (after printing the problem) when ``timestamp`` is not
    exactly three numeric ``:``-separated fields.
    """
    try:
        h, m, s = map(float, timestamp.split(':'))
        return 3600 * h + 60 * m + s
    except ValueError as e:
        print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
        return None
|
|
|
|
|
def process_audio(target_audio, mixed_audio):
    """Gradio callback: concatenate the clips, diarize, and build a label file.

    Args:
        target_audio: Path of the target speaker's audio (None if not uploaded).
        mixed_audio: Path of the mixed audio (None if not uploaded).

    Returns:
        A 3-tuple matching the UI outputs:
        ``(diarization text or error message, label-file path or None,
        target time-range dict or None)``.
    """
    # Gradio passes None for an Audio input the user left empty.
    if target_audio is None or mixed_audio is None:
        return "错误: 请上传目标说话人音频和混合音频", None, None

    try:
        time_dict = combine_audio_with_time(target_audio, mixed_audio)
    except Exception as e:
        # UI boundary: surface decode/export failures as a message instead
        # of an opaque Gradio error.
        return f"错误: 音频拼接失败: {e}", None, None

    # combine_audio_with_time signals failure with a string instead of a dict.
    if isinstance(time_dict, str):
        return time_dict, None, None

    diarization_result = diarize_audio("final_output.wav")

    # diarize_audio reports failures with either of these prefixes; neither
    # should be parsed as diarization output.
    if diarization_result.startswith(("错误", "处理音频时出错")):
        return diarization_result, None, None

    label_file = generate_labels_from_diarization(diarization_result)
    return diarization_result, label_file, time_dict
|
|
|
|
|
# Gradio UI: two audio uploads, one button, and three outputs that match the
# 3-tuple returned by process_audio.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🗣️ 音频拼接与说话人分类 🗣️
上传目标说话人音频和混合音频,拼接并进行说话人分类。结果包括说话人分离输出、标签文件和目标音频的时间段。
""")

    # type="filepath" makes Gradio hand process_audio file paths (or None).
    target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
    mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")

    process_button = gr.Button("处理音频")

    diarization_output = gr.Textbox(label="说话人分离结果")
    label_file_link = gr.File(label="下载标签文件")
    time_range_output = gr.Textbox(label="目标音频时间段")

    process_button.click(
        fn=process_audio,
        inputs=[target_audio_input, mixed_audio_input],
        outputs=[diarization_output, label_file_link, time_range_output],
    )

# Only start the server when run as a script, so importing this module
# (tests, tooling, reload mode) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)
|
|