File size: 7,698 Bytes
d8b6286
 
7d6a9ed
68f6bb9
6d3bc8f
081af9c
7d6a9ed
 
8a8a249
68f6bb9
8a8a249
081af9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388c913
081af9c
 
 
 
 
 
 
 
 
388c913
081af9c
 
388c913
081af9c
 
388c913
081af9c
 
 
388c913
081af9c
 
 
 
 
 
 
 
 
 
 
388c913
081af9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838e9df
081af9c
 
 
 
 
 
838e9df
081af9c
 
 
838e9df
081af9c
 
838e9df
081af9c
 
 
838e9df
 
 
081af9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838e9df
081af9c
 
 
 
 
 
 
 
 
 
68f6bb9
913c46d
68f6bb9
 
 
081af9c
68f6bb9
081af9c
68f6bb9
dedd569
081af9c
68f6bb9
081af9c
6d3bc8f
e827069
 
68f6bb9
6d3bc8f
68f6bb9
081af9c
6d3bc8f
e827069
68f6bb9
 
8a8a249
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import torch
import spaces
import gradio as gr
import os
from pyannote.audio import Pipeline
from pyannote.core import Annotation, Segment
from pydub import AudioSegment

# Hugging Face access token, read from the environment; None when the
# HUGGINGFACE_READ_TOKEN variable is not set.
HF_TOKEN = os.getenv("HUGGINGFACE_READ_TOKEN")

class AudioProcessor:
    """Find a target speaker inside a mixed recording using speaker diarization.

    The target speaker's sample is appended to the end of the mixed recording,
    the combined audio is diarized with ``pyannote/speaker-diarization-3.1``,
    and the diarization label that best overlaps the appended sample is taken
    to be the target speaker.

    Error convention: ``combine_audio_with_time`` and ``diarize_audio`` do not
    raise on failure; they return a human-readable ``str`` describing the
    problem. Callers must check for ``str`` before using the result.
    """

    def __init__(self):
        """Load the diarization pipeline; ``self.pipeline`` stays None on failure."""
        self.pipeline = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Try to load the pyannote model. Any failure is reported on stdout and
        # leaves the processor in the "model not initialised" state.
        try:
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
            )
            # NOTE(review): from_pretrained appears to return None (instead of
            # raising) when the gated model cannot be downloaded; check it
            # explicitly rather than relying on an AttributeError from `.to`.
            if pipeline is None:
                raise RuntimeError("Pipeline.from_pretrained returned None")
            pipeline.to(self.device)
            self.pipeline = pipeline
            print("pyannote model loaded successfully.")
        except Exception as e:
            print(f"Error initializing pipeline: {e}")
            self.pipeline = None

    def combine_audio_with_time(self, target_audio, mixed_audio,
                                output_path="final_output.wav"):
        """Append the target sample to the mixed audio and write the result.

        Args:
            target_audio: Path of the WAV file with the target speaker sample.
            mixed_audio: Path of the WAV file with the mixed recording.
            output_path: Where the concatenated WAV is written. Defaults to
                ``"final_output.wav"`` in the current working directory.

        Returns:
            ``{"start_time": float, "end_time": float}`` giving, in seconds,
            where the target sample sits inside the concatenated audio; or an
            error message (``str``) if the model is not initialised or a file
            could not be loaded or written.
        """
        if self.pipeline is None:
            return "错误: 模型未初始化"

        # Load the target speaker's sample.
        try:
            target_audio_segment = AudioSegment.from_wav(target_audio)
        except Exception as e:
            return f"加载目标音频时出错: {e}"

        # Load the mixed recording.
        try:
            mixed_audio_segment = AudioSegment.from_wav(mixed_audio)
        except Exception as e:
            return f"加载混合音频时出错: {e}"

        # len() of an AudioSegment is its duration in milliseconds. The target
        # sample starts where the mixed audio ends.
        target_start_time = len(mixed_audio_segment) / 1000
        target_end_time = target_start_time + len(target_audio_segment) / 1000

        # Append the target sample to the end of the mixed recording.
        final_audio = mixed_audio_segment + target_audio_segment
        try:
            exported = final_audio.export(output_path, format="wav")
            # export() hands back the file object it wrote to; close it so the
            # data is flushed before the diarizer reads the file.
            exported.close()
        except Exception as e:
            return f"导出拼接音频时出错: {e}"

        return {"start_time": target_start_time, "end_time": target_end_time}

    # Requests a GPU for up to 120 seconds per call.
    @spaces.GPU(duration=60 * 2)
    def diarize_audio(self, temp_file):
        """Run speaker diarization on an audio file.

        Args:
            temp_file: Path of the audio file to diarize.

        Returns:
            Whatever the pipeline returns (expected to be a pyannote
            ``Annotation``), or an error message (``str``) if the model is not
            initialised or the pipeline raised.
        """
        if self.pipeline is None:
            return "错误: 模型未初始化"

        try:
            diarization = self.pipeline(temp_file)
        except Exception as e:
            return f"处理音频时出错: {e}"

        return diarization

    def timestamp_to_seconds(self, timestamp):
        """Convert an ``"H:M:S"`` timestamp string to seconds.

        Returns:
            The number of seconds as a float, or None (after printing a
            message) when the string does not have exactly three numeric,
            colon-separated fields.
        """
        try:
            h, m, s = map(float, timestamp.split(':'))
            return 3600 * h + 60 * m + s
        except ValueError as e:
            print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
            return None

    def calculate_overlap(self, start1, end1, start2, end2):
        """Return the length of the overlap of [start1, end1] and [start2, end2].

        The result is 0 when the intervals do not overlap.
        """
        return max(0, min(end1, end2) - max(start1, start2))

    def get_best_match(self, target_time, diarization_output):
        """Pick the diarized segment that best matches the target time window.

        Each segment is scored by the fraction of its own duration that lies
        inside the target window. Segments that do not overlap the window at
        all, or that have no duration, are ignored. Ties on the fraction are
        broken in favour of the longer absolute overlap.

        Args:
            target_time: Dict with ``'start_time'`` and ``'end_time'`` seconds.
            diarization_output: Object exposing
                ``itertracks(yield_label=True)`` yielding ``(segment, label)``
                pairs whose segments have ``start`` and ``end`` attributes.

        Returns:
            ``(label, overlap_ratio, start_seconds, end_seconds)`` for the best
            segment, or None when no segment overlaps the target window.
        """
        target_start_time = target_time['start_time']
        target_end_time = target_time['end_time']

        best_match = None
        best_score = None
        for segment, label in diarization_output.itertracks(yield_label=True):
            start_seconds = segment.start
            end_seconds = segment.end
            duration = end_seconds - start_seconds
            if duration <= 0:
                # A zero-length segment would divide by zero below.
                continue

            overlap = self.calculate_overlap(
                target_start_time, target_end_time, start_seconds, end_seconds
            )
            if overlap <= 0:
                # A segment outside the target window is never a match.
                continue

            overlap_ratio = overlap / duration
            score = (overlap_ratio, overlap)
            # Strict comparison keeps the earliest segment among exact ties.
            if best_score is None or score > best_score:
                best_score = score
                best_match = (label, overlap_ratio, start_seconds, end_seconds)

        return best_match

    def get_speaker_time_segments(self, diarization_output, target_time, speaker_label):
        """List a speaker's time ranges, excluding the target time window.

        Args:
            diarization_output: Object exposing
                ``itertracks(yield_label=True)`` yielding ``(segment, label)``.
            target_time: Dict with ``'start_time'`` and ``'end_time'`` seconds.
            speaker_label: Only segments carrying this label are considered.

        Returns:
            A list of ``(start_seconds, end_seconds)`` tuples. A segment that
            overlaps the target window contributes only the parts outside it.
        """
        window_start = target_time['start_time']
        window_end = target_time['end_time']
        remaining_segments = []

        for segment, label in diarization_output.itertracks(yield_label=True):
            if label != speaker_label:
                continue

            start_seconds = segment.start
            end_seconds = segment.end

            # Portion of this segment that falls inside the target window.
            overlap_start = max(start_seconds, window_start)
            overlap_end = min(end_seconds, window_end)

            if overlap_start >= overlap_end:
                # No overlap: keep the segment unchanged.
                remaining_segments.append((start_seconds, end_seconds))
                continue

            # Keep whatever lies before and/or after the overlapping portion.
            if start_seconds < overlap_start:
                remaining_segments.append((start_seconds, overlap_start))
            if overlap_end < end_seconds:
                remaining_segments.append((overlap_end, end_seconds))

        return remaining_segments

    def process_audio(self, target_audio, mixed_audio):
        """Run the full pipeline: concatenate, diarize, and locate the speaker.

        Args:
            target_audio: Path of the WAV file with the target speaker sample.
            mixed_audio: Path of the WAV file with the mixed recording.

        Returns:
            A 2-tuple. On success ``(speaker_label, remaining_segments)``,
            where ``remaining_segments`` is the list produced by
            ``get_speaker_time_segments``. On failure ``(error_message, None)``.
        """
        import tempfile

        # Each call works on its own temporary file so that simultaneous
        # requests cannot overwrite each other's concatenated audio.
        fd, combined_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            # Concatenate the audio and get the target sample's time window.
            time_dict = self.combine_audio_with_time(
                target_audio, mixed_audio, combined_path
            )
            if isinstance(time_dict, str):
                return time_dict, None  # error message from the combine step

            # Run speaker diarization on the concatenated audio.
            diarization_result = self.diarize_audio(combined_path)
            # Errors are reported as plain strings, and not all of them share
            # a common prefix, so test the type rather than the text.
            if isinstance(diarization_result, str):
                return diarization_result, None

            best_match = self.get_best_match(time_dict, diarization_result)
            if not best_match:
                return "错误: 未找到与目标音频匹配的说话人", None

            speaker_label = best_match[0]
            # All other time ranges where the matched speaker talks.
            remaining_segments = self.get_speaker_time_segments(
                diarization_result, time_dict, speaker_label
            )
            return speaker_label, remaining_segments
        finally:
            try:
                os.remove(combined_path)
            except OSError:
                pass  # best-effort cleanup of the temporary file

# Gradio user interface: two audio uploads, one button, two text outputs.
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🗣️ 音频拼接与说话人分类 🗣️
    上传目标音频和混合音频,拼接并进行说话人分类。结果包括最佳匹配说话人的时间段(排除目标音频时间段)。
    """)

    # Inputs are handed to the callback as file paths.
    mixed_audio_input = gr.Audio(label="上传混合音频", type="filepath")
    target_audio_input = gr.Audio(label="上传目标说话人音频", type="filepath")

    process_button = gr.Button("处理音频")

    # Text boxes that receive the two values returned by process_audio.
    diarization_output = gr.Textbox(label="最佳匹配说话人")
    time_range_output = gr.Textbox(label="最佳匹配时间段")

    # A single processor instance (and therefore a single loaded model)
    # serves every button click.
    processor = AudioProcessor()
    process_button.click(
        processor.process_audio,
        [target_audio_input, mixed_audio_input],
        [diarization_output, time_range_output],
    )

demo.launch(share=True)