|
import torch |
|
import os |
|
import gradio as gr |
|
from pyannote.audio import Pipeline |
|
from pydub import AudioSegment |
|
from spaces import GPU |
|
|
|
|
|
# Hugging Face access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")

# Module-wide diarization pipeline. It stays None if loading fails; every entry
# point below checks for None before doing any work.
pipeline = None

try:
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN)
    # Prefer the GPU when torch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    pipeline.to(device)
except Exception as err:
    # Keep the app importable even without a model; callers report the missing pipeline.
    print(f"Error initializing pipeline: {err}")
    pipeline = None
|
|
|
|
|
def timestamp_to_seconds(timestamp):
    """Convert an ``"H:M:S"`` timestamp string to a float number of seconds.

    Each of the three colon-separated fields is parsed with ``float``, so
    fractional values (e.g. ``"00:01:02.5"``) are accepted.

    Raises:
        ValueError: if the string does not have exactly three fields or a
            field is not numeric.
    """
    fields = [float(part) for part in timestamp.split(':')]
    hours, minutes, seconds = fields
    return hours * 3600 + minutes * 60 + seconds
|
|
|
def convert_to_wav(audio_file):
    """Decode ``audio_file`` with pydub and re-encode it as an in-memory WAV stream.

    Args:
        audio_file: Anything accepted by ``AudioSegment.from_file`` (the Gradio
            inputs in this app pass a file path).

    Returns:
        On success, a ``BytesIO`` containing WAV data, rewound to position 0.
        On failure, an error string of the form ``"音频转换失败: <reason>"``;
        callers should check ``isinstance(result, str)``.
    """
    # BytesIO was previously used without being imported. Every call hit a
    # NameError, which the broad except below silently turned into an error string.
    from io import BytesIO

    try:
        audio = AudioSegment.from_file(audio_file)
        wav_output = BytesIO()
        audio.export(wav_output, format="wav")
    except Exception as e:
        # Best-effort conversion: report the failure to the caller as text.
        return f"音频转换失败: {e}"

    # Rewind so the next reader starts at the beginning of the WAV data.
    wav_output.seek(0)
    return wav_output
|
|
|
|
|
def combine_audio_with_time(target_audio, mixed_audio):
    """Append the target-speaker recording to the end of the mixed recording.

    The concatenation is written to ``final_output.wav`` in the current working
    directory. That file is what ``diarize_audio`` is later run on.

    Args:
        target_audio: Path of the recording that contains only the target speaker.
        mixed_audio: Path of the recording that contains several speakers.

    Returns:
        ``{"start_time": float, "end_time": float}`` giving, in seconds, where
        the appended target recording lies inside ``final_output.wav``, or an
        error message string if the model is missing or an input cannot be
        loaded.
    """
    if pipeline is None:
        return "错误: 模型未初始化"

    if not target_audio or not mixed_audio:
        return "错误: 请同时上传目标音频和混合音频"

    print(f"目标音频文件路径: {target_audio}")
    print(f"混合音频文件路径: {mixed_audio}")

    # convert_to_wav returns an error *string* on failure. Propagate it instead
    # of handing it to from_wav, which would try to open it as a file name.
    target_wav = convert_to_wav(target_audio)
    if isinstance(target_wav, str):
        return f"加载目标音频时出错: {target_wav}"

    mixed_wav = convert_to_wav(mixed_audio)
    if isinstance(mixed_wav, str):
        return f"加载混合音频时出错: {mixed_wav}"

    try:
        target_audio_segment = AudioSegment.from_wav(target_wav)
    except Exception as e:
        return f"加载目标音频时出错: {e}"

    try:
        mixed_audio_segment = AudioSegment.from_wav(mixed_wav)
    except Exception as e:
        return f"加载混合音频时出错: {e}"

    # len() of an AudioSegment is in milliseconds. The target is appended after
    # the mixed audio, so its window starts exactly where the mixed audio ends.
    target_start_time = len(mixed_audio_segment) / 1000
    target_end_time = target_start_time + len(target_audio_segment) / 1000

    final_audio = mixed_audio_segment + target_audio_segment
    final_audio.export("final_output.wav", format="wav")

    return {"start_time": target_start_time, "end_time": target_end_time}
|
|
|
|
|
@GPU(duration=60 * 2)  # presumably a 2-minute GPU allocation window (seconds) — confirm with `spaces` docs
def diarize_audio(temp_file):
    """Run the module-level diarization pipeline on ``temp_file`` and log each turn.

    Returns:
        Whatever ``pipeline(temp_file)`` returns (used by callers through
        ``itertracks(yield_label=True)``). If the pipeline is not loaded or
        raises, an error string is returned instead.
    """
    if pipeline is None:
        return "错误: 模型未初始化"

    try:
        result = pipeline(temp_file)
        print("说话人分离结果:")
        for segment, _track, label in result.itertracks(yield_label=True):
            print(f"[{segment.start:.3f} --> {segment.end:.3f}] {label}")
    except Exception as e:
        return f"处理音频时出错: {e}"

    return result
|
|
|
|
|
def find_best_matching_speaker(target_start_time, target_end_time, diarization):
    """Return the speaker that talks the most inside ``[target_start_time, target_end_time]``.

    Overlap is summed over all of a speaker's turns. Diarization often splits
    one person's speech into several short turns, so the single longest turn
    is a poor indicator of who owns the window.

    Args:
        target_start_time: Window start, in seconds.
        target_end_time: Window end, in seconds.
        diarization: Object whose ``itertracks(yield_label=True)`` yields
            ``(turn, track, speaker)`` with ``turn.start`` / ``turn.end`` in seconds.

    Returns:
        ``(speaker, total_overlap_seconds)``, or ``(None, 0)`` if no turn
        overlaps the window. Ties go to the speaker encountered first.
    """
    overlap_by_speaker = {}

    for turn, _, speaker in diarization.itertracks(yield_label=True):
        overlap = min(turn.end, target_end_time) - max(turn.start, target_start_time)
        if overlap > 0:
            overlap_by_speaker[speaker] = overlap_by_speaker.get(speaker, 0) + overlap

    if not overlap_by_speaker:
        return None, 0

    # max() keeps the first maximal key in insertion order, preserving first-seen tie-breaking.
    best_match = max(overlap_by_speaker, key=overlap_by_speaker.get)
    return best_match, overlap_by_speaker[best_match]
|
|
|
|
|
def get_speaker_segments(diarization, target_start_time, target_end_time, final_audio_length):
    """Collect, per speaker, the time spans that lie outside the target window.

    Turns that do not touch ``[target_start_time, target_end_time]`` are kept
    whole. Turns that overlap it are trimmed to the part before the window
    and/or the part after it. The trailing part is capped at
    ``final_audio_length``. Turns entirely inside the window contribute nothing.

    Returns:
        ``{speaker: [(start_seconds, end_seconds), ...]}`` in diarization order.
        Speakers with no surviving span are absent.
    """
    segments_by_speaker = {}

    def _add(label, span):
        segments_by_speaker.setdefault(label, []).append(span)

    for turn, _, speaker in diarization.itertracks(yield_label=True):
        begin, finish = turn.start, turn.end

        if finish <= target_start_time or begin >= target_end_time:
            # Entirely outside the target window: keep as-is.
            _add(speaker, (begin, finish))
            continue

        # Overlaps the target window: keep only the portions outside it.
        if begin < target_start_time:
            _add(speaker, (begin, min(target_start_time, finish)))
        if finish > target_end_time:
            _add(speaker, (max(target_end_time, begin), min(finish, final_audio_length)))

    return segments_by_speaker
|
|
|
|
|
def process_audio(target_audio, mixed_audio):
    """Gradio handler: extract, from the mixed audio, the speaker who matches the target recording.

    Steps:
        1. Append the target recording to the mixed recording (``final_output.wav``).
        2. Diarize the combined file.
        3. Pick the speaker that overlaps the appended target window the most.
        4. Concatenate that speaker's spans outside the target window into
           ``final_combined_output.wav``.

    Returns:
        The path ``"final_combined_output.wav"``.

    Raises:
        gr.Error: with the underlying message when any step fails. The output
            component is a ``gr.Audio``, so returning a message string would
            be misinterpreted as a file path rather than shown to the user.

    NOTE(review): the fixed file names in the working directory are shared by
    all requests, so concurrent users may overwrite each other's files.
    """
    print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")

    time_dict = combine_audio_with_time(target_audio, mixed_audio)
    if isinstance(time_dict, str):
        raise gr.Error(time_dict)

    diarization_result = diarize_audio("final_output.wav")
    # diarize_audio signals *every* failure with a string ("错误: ..." or
    # "处理音频时出错: ..."). Checking only the "错误" prefix let the latter
    # reach .itertracks() and crash.
    if isinstance(diarization_result, str):
        raise gr.Error(diarization_result)

    start_time = time_dict["start_time"]
    end_time = time_dict["end_time"]

    # Decode the combined file once and reuse it for both the length and the slicing.
    combined_audio = AudioSegment.from_wav("final_output.wav")
    final_audio_length = len(combined_audio) / 1000

    best_match, _ = find_best_matching_speaker(start_time, end_time, diarization_result)
    if not best_match:
        raise gr.Error("未找到匹配的说话人。")

    speaker_segments = get_speaker_segments(
        diarization_result, start_time, end_time, final_audio_length
    )
    if best_match not in speaker_segments:
        raise gr.Error("没有找到匹配的说话人时间段。")

    final_output = AudioSegment.empty()
    for seg_start, seg_end in speaker_segments[best_match]:
        # AudioSegment slicing is in milliseconds.
        final_output += combined_audio[int(seg_start * 1000):int(seg_end * 1000)]

    final_output.export("final_combined_output.wav", format="wav")
    return "final_combined_output.wav"
|
|
|
|
|
with gr.Blocks() as demo:
    # The description must match what process_audio does: the output speaker is
    # chosen by overlap with the appended target window, not fixed to SPEAKER_00.
    gr.Markdown(
        """
# 🗣️ 音频拼接与说话人分类 🗣️

上传目标音频和混合音频,将目标音频拼接到混合音频末尾后进行说话人分类。

结果为与拼接的目标录音时间段重叠最多的说话人(自动识别,不一定是 SPEAKER_00)的时间段,已排除和截断目标录音时间段,并自动剪辑为一段音频。
"""
    )

    mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
    target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")

    process_button = gr.Button("处理音频")

    output_audio = gr.Audio(label="剪辑后的音频")

    # inputs are listed in process_audio's parameter order (target, mixed),
    # which differs from the on-screen order of the two upload widgets.
    process_button.click(
        fn=process_audio,
        inputs=[target_audio_input, mixed_audio_input],
        outputs=[output_audio],
    )

demo.launch(share=True)
|
|