File size: 7,009 Bytes
d8b6286 7d6a9ed 68f6bb9 6d3bc8f 7d6a9ed 8a8a249 68f6bb9 3152b48 8a8a249 3152b48 cb823a5 3152b48 cb823a5 0bca3ec 3152b48 cb823a5 3152b48 cb823a5 3152b48 0bca3ec cb823a5 3152b48 cb823a5 3152b48 cb823a5 0bca3ec 21019ce 0bca3ec cb823a5 0bca3ec 68f6bb9 913c46d 68f6bb9 cb823a5 68f6bb9 3152b48 68f6bb9 dedd569 3152b48 68f6bb9 3152b48 6d3bc8f 0bca3ec 68f6bb9 6d3bc8f 68f6bb9 3152b48 6d3bc8f 0bca3ec 68f6bb9 1cdea95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import torch
import spaces
import gradio as gr
import os
from pyannote.audio import Pipeline
from pydub import AudioSegment
# 获取 Hugging Face 认证令牌
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HUGGINGFACE_READ_TOKEN")

# Load the pyannote speaker-diarization pipeline once at import time and move
# it to the GPU when one is available. If anything goes wrong, `pipeline`
# stays None.
pipeline = None
try:
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=HF_TOKEN,
    )
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    pipeline.to(device)
except Exception as e:
    print(f"Error initializing pipeline: {e}")
    pipeline = None
# 音频拼接函数:拼接目标音频和混合音频,返回目标音频的起始时间和结束时间作为字典
def combine_audio_with_time(target_audio, mixed_audio, output_path="final_output.wav"):
    """Append the target speaker's sample to the mixed audio and export it.

    Args:
        target_audio: Path to a WAV file containing the target speaker sample.
        mixed_audio: Path to a WAV file containing the mixed recording.
        output_path: Destination of the concatenated WAV file. Defaults to
            "final_output.wav", the file name that used to be hard-coded.

    Returns:
        On success, a dict ``{"start_time": ..., "end_time": ...}`` giving
        the position (in seconds) of the appended target sample inside the
        exported file. On failure, a ``str`` holding the error message
        (pipeline not initialised, or one of the inputs failed to load).
    """
    if pipeline is None:
        return "错误: 模型未初始化"

    # Log the incoming paths so a bad upload is easy to spot.
    print(f"目标音频文件路径: {target_audio}")
    print(f"混合音频文件路径: {mixed_audio}")

    # Load the target speaker sample.
    try:
        target_audio_segment = AudioSegment.from_wav(target_audio)
    except Exception as e:
        return f"加载目标音频时出错: {e}"

    # Load the mixed recording.
    try:
        mixed_audio_segment = AudioSegment.from_wav(mixed_audio)
    except Exception as e:
        return f"加载混合音频时出错: {e}"

    # The target sample is appended at the end, so it starts where the mixed
    # audio stops. Segment lengths are divided by 1000 to obtain seconds.
    target_start_time = len(mixed_audio_segment) / 1000
    target_end_time = target_start_time + len(target_audio_segment) / 1000

    final_audio = mixed_audio_segment + target_audio_segment
    final_audio.export(output_path, format="wav")

    return {"start_time": target_start_time, "end_time": target_end_time}
# 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离
@spaces.GPU(duration=120)  # run on GPU; execution is capped at 120 seconds
def diarize_audio(temp_file):
    """Run the speaker-diarization pipeline on an audio file.

    Args:
        temp_file: Path of the audio file handed to the pipeline.

    Returns:
        The object returned by the pipeline on success, or a ``str`` with an
        error message when the pipeline is not initialised or the call raises.
    """
    if pipeline is None:
        return "错误: 模型未初始化"

    try:
        return pipeline(temp_file)
    except Exception as e:
        return f"处理音频时出错: {e}"
# 将时间戳转换为秒
def timestamp_to_seconds(timestamp):
    """Convert an ``h:m:s`` timestamp string into seconds.

    Args:
        timestamp: String made of exactly three ':'-separated numeric fields
            (hours, minutes, seconds); each field is parsed with ``float``.

    Returns:
        The total number of seconds as a float, or ``None`` when the value
        cannot be parsed (a diagnostic message is printed in that case).
    """
    try:
        hours, minutes, seconds = (float(part) for part in timestamp.split(":"))
    except (AttributeError, TypeError, ValueError) as e:
        # ValueError: wrong number of fields or a non-numeric field.
        # AttributeError / TypeError: the input is not a str (e.g. None or
        # bytes); treat it like any other unparsable timestamp instead of
        # letting the exception escape.
        print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
        return None
    return 3600 * hours + 60 * minutes + seconds
# 计算时间段的重叠部分(单位:秒)
def calculate_overlap(start1, end1, start2, end2):
    """Return how long two time intervals overlap.

    Args:
        start1, end1: Bounds of the first interval.
        start2, end2: Bounds of the second interval.

    Returns:
        The length of the intersection, or 0 when the intervals are disjoint
        or merely touch.
    """
    # The intersection runs from the later start to the earlier end.
    shared = min(end1, end2) - max(start1, start2)
    return shared if shared > 0 else 0
# 获取所有说话人时间段,排除目标音频时间段
def get_matching_segments(target_time, diarization_output):
    """List the turns of the speaker who best matches the target interval.

    The speaker whose turns overlap the target interval the most is taken to
    be the target speaker. Every turn of that speaker that does not overlap
    the target interval at all is returned.

    Args:
        target_time: Dict with ``'start_time'`` and ``'end_time'`` (seconds)
            delimiting the target interval.
        diarization_output: Object exposing ``itertracks(yield_label=True)``
            that yields ``(segment, track, label)`` triples whose segment has
            ``start`` and ``end`` attributes.

    Returns:
        A list of ``"hh:mm:ss.mmm --> hh:mm:ss.mmm"`` strings, or the string
        ``"没有找到匹配的说话人"`` when no turn overlaps the target interval.
    """
    target_start_time = target_time['start_time']
    target_end_time = target_time['end_time']

    # Walk the annotation once, remembering each turn together with its
    # overlap so the second pass does not have to recompute anything.
    turns = []
    speaker_overlap = {}
    # The speaker is the third item of each triple; the second item is the
    # track identifier, which is not guaranteed to identify a speaker.
    for segment, _track, speaker in diarization_output.itertracks(yield_label=True):
        overlap = calculate_overlap(
            target_start_time, target_end_time, segment.start, segment.end
        )
        turns.append((speaker, segment.start, segment.end, overlap))
        if overlap > 0:
            speaker_overlap[speaker] = speaker_overlap.get(speaker, 0) + overlap

    # Speaker with the largest accumulated overlap with the target interval.
    max_overlap_speaker = max(speaker_overlap, key=speaker_overlap.get, default=None)
    if max_overlap_speaker is None:
        return "没有找到匹配的说话人"

    # Keep that speaker's turns that lie entirely outside the target interval
    # and render them in a readable form (e.g. 00:00:03.895 --> 00:00:04.367).
    return [
        f"{format_time(start)} --> {format_time(end)}"
        for speaker, start, end, overlap in turns
        if speaker == max_overlap_speaker and overlap == 0
    ]
# 格式化时间(秒 -> hh:mm:ss.xxx)
def format_time(seconds):
    """Format a number of seconds as ``hh:mm:ss.mmm``.

    Args:
        seconds: Non-negative duration in seconds (int or float).

    Returns:
        The zero-padded timestamp string, rounded to the millisecond.
    """
    # Round to whole milliseconds *before* splitting into fields so that a
    # value such as 59.9996 carries into the minutes ("00:01:00.000") instead
    # of being rendered as the invalid "00:00:60.000".
    total_ms = int(round(float(seconds) * 1000))
    hrs, remainder = divmod(total_ms, 3600000)
    mins, remainder = divmod(remainder, 60000)
    secs, millis = divmod(remainder, 1000)
    return f"{hrs:02}:{mins:02}:{secs:02}.{millis:03}"
# 处理音频文件并返回输出
def process_audio(target_audio, mixed_audio):
    """Concatenate the inputs, diarize the result and report the matches.

    Args:
        target_audio: Path to the target speaker's audio sample.
        mixed_audio: Path to the mixed recording.

    Returns:
        A single ``str`` for the output textbox: the matching time ranges
        (one per line), a "nothing found" message, or the error message
        produced by whichever step failed.
    """
    # Log the incoming paths so invalid uploads are visible in the logs.
    print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")

    # Concatenate; the helper returns a str (error message) instead of the
    # timing dict when it fails, so stop here rather than diarizing a missing
    # or stale file.
    time_dict = combine_audio_with_time(target_audio, mixed_audio)
    if isinstance(time_dict, str):
        return time_dict

    # Diarize; any str result is an error message, whatever its prefix.
    diarization_result = diarize_audio("final_output.wav")
    if isinstance(diarization_result, str):
        return diarization_result

    # Collect the turns of the speaker overlapping the target sample the most.
    matching_segments = get_matching_segments(time_dict, diarization_result)
    if isinstance(matching_segments, str):
        # Already a message; joining it would split it into characters.
        return matching_segments
    if matching_segments:
        return "\n".join(matching_segments)
    return "没有找到匹配的说话人时间段。"
# Gradio 接口
with gr.Blocks() as demo:
    # Page header and short usage description.
    gr.Markdown("""
# 🗣️ 音频拼接与说话人分类 🗣️
上传目标音频和混合音频,拼接并进行说话人分类。结果包括匹配重叠最多的说话人的时间段。
""")

    # Inputs: both uploads are passed to the callback as file paths.
    mixed_audio_input = gr.Audio(label="上传混合音频", type="filepath")
    target_audio_input = gr.Audio(label="上传目标说话人音频", type="filepath")

    process_button = gr.Button("处理音频")

    # Output: textbox showing the matched speaker's time ranges.
    diarization_output = gr.Textbox(label="匹配的说话人时间段")

    # Run the processing callback when the button is clicked. The inputs are
    # listed in the callback's parameter order (target first, then mixed).
    process_button.click(
        fn=process_audio,
        inputs=[target_audio_input, mixed_audio_input],
        outputs=[diarization_output],
    )

demo.launch(share=True)
|