Update app.py
Browse files
app.py
CHANGED
@@ -64,8 +64,7 @@ def diarize_audio(temp_file):
|
|
64 |
diarization = pipeline(temp_file)
|
65 |
except Exception as e:
|
66 |
return f"处理音频时出错: {e}"
|
67 |
-
|
68 |
-
print(diarization)
|
69 |
# 返回 diarization 类对象
|
70 |
return diarization
|
71 |
|
@@ -85,31 +84,58 @@ def calculate_overlap(start1, end1, start2, end2):
|
|
85 |
overlap_duration = max(0, overlap_end - overlap_start)
|
86 |
return overlap_duration
|
87 |
|
88 |
-
#
|
89 |
def get_matching_segments(target_time, diarization_output):
|
90 |
target_start_time = target_time['start_time']
|
91 |
target_end_time = target_time['end_time']
|
92 |
|
93 |
-
#
|
94 |
-
|
95 |
-
|
|
|
96 |
start_seconds = speech_turn[0].start
|
97 |
end_seconds = speech_turn[0].end
|
98 |
label = speech_turn[1]
|
99 |
|
100 |
-
#
|
101 |
overlap = calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
|
102 |
|
103 |
-
# 如果存在重叠,排除目标音频时间段
|
104 |
if overlap > 0:
|
105 |
-
if label not in
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
-
|
|
|
|
|
|
|
|
|
113 |
|
114 |
# 处理音频文件并返回输出
|
115 |
def process_audio(target_audio, mixed_audio):
|
@@ -125,12 +151,12 @@ def process_audio(target_audio, mixed_audio):
|
|
125 |
if isinstance(diarization_result, str) and diarization_result.startswith("错误"):
|
126 |
return diarization_result, None # 出错时返回错误信息
|
127 |
else:
|
128 |
-
#
|
129 |
matching_segments = get_matching_segments(time_dict, diarization_result)
|
130 |
|
131 |
if matching_segments:
|
132 |
-
#
|
133 |
-
return matching_segments
|
134 |
else:
|
135 |
return "没有找到匹配的说话人时间段。"
|
136 |
|
@@ -138,7 +164,7 @@ def process_audio(target_audio, mixed_audio):
|
|
138 |
with gr.Blocks() as demo:
|
139 |
gr.Markdown("""
|
140 |
# 🗣️ 音频拼接与说话人分类 🗣️
|
141 |
-
|
142 |
""")
|
143 |
|
144 |
mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
|
|
|
64 |
diarization = pipeline(temp_file)
|
65 |
except Exception as e:
|
66 |
return f"处理音频时出错: {e}"
|
67 |
+
|
|
|
68 |
# 返回 diarization 类对象
|
69 |
return diarization
|
70 |
|
|
|
84 |
overlap_duration = max(0, overlap_end - overlap_start)
|
85 |
return overlap_duration
|
86 |
|
87 |
+
# 获取所有说话人时间段,排除目标音频时间段
|
88 |
def get_matching_segments(target_time, diarization_output):
    """Return the dominant speaker's turns that lie outside the target window.

    The speaker whose turns overlap the target window the most (summed over
    all of that speaker's turns) is selected; every turn of that speaker
    that has zero overlap with the window is then reported.

    Args:
        target_time: Mapping with ``'start_time'`` and ``'end_time'`` entries
            delimiting the target window (presumably seconds, matching the
            diarization timeline -- confirm against the caller).
        diarization_output: Object exposing ``itertracks(yield_label=True)``
            that yields ``(segment, track, label)`` tuples, where ``segment``
            has ``start`` and ``end`` attributes.

    Returns:
        A list of ``"hh:mm:ss.xxx --> hh:mm:ss.xxx"`` strings, one per
        retained turn. The list is empty when no speaker overlaps the target
        window, so callers can rely on plain truthiness and never receive a
        bare message string to ``join``.
    """
    target_start_time = target_time['start_time']
    target_end_time = target_time['end_time']

    # Walk the annotation once, remembering each turn together with its
    # overlap so the selection step below needs no second traversal.
    turns = []
    speaker_overlap = {}
    for segment, _track, label in diarization_output.itertracks(yield_label=True):
        overlap = calculate_overlap(
            target_start_time, target_end_time, segment.start, segment.end
        )
        turns.append((segment.start, segment.end, label, overlap))
        if overlap > 0:
            speaker_overlap[label] = speaker_overlap.get(label, 0) + overlap

    # No speaker intersects the target window: report "nothing found" as an
    # empty list rather than a message string.
    if not speaker_overlap:
        return []

    # Ties resolve to the speaker encountered first (dict insertion order).
    max_overlap_speaker = max(speaker_overlap, key=speaker_overlap.get)

    # Keep only the chosen speaker's turns that do not touch the target
    # window, rendered in a human-readable form.
    return [
        f"{format_time(start)} --> {format_time(end)}"
        for start, end, label, overlap in turns
        if label == max_overlap_speaker and overlap == 0
    ]
|
133 |
|
134 |
+
# 格式化时间(秒 -> hh:mm:ss.xxx)
|
135 |
+
def format_time(seconds):
    """Format a duration in seconds as ``hh:mm:ss.xxx``.

    The value is rounded to whole milliseconds *before* it is split into
    hours, minutes and seconds. Rounding afterwards would let a value such
    as 59.9996 render as ``00:00:60.000`` instead of ``00:01:00.000``.

    Args:
        seconds: Non-negative duration in seconds (int or float).

    Returns:
        The zero-padded timestamp string with millisecond precision.
    """
    total_ms = int(round(float(seconds) * 1000))
    total_secs, millis = divmod(total_ms, 1000)
    total_mins, secs = divmod(total_secs, 60)
    hrs, mins = divmod(total_mins, 60)
    return f"{hrs:02d}:{mins:02d}:{secs:02d}.{millis:03d}"
|
139 |
|
140 |
# 处理音频文件并返回输出
|
141 |
def process_audio(target_audio, mixed_audio):
|
|
|
151 |
if isinstance(diarization_result, str) and diarization_result.startswith("错误"):
|
152 |
return diarization_result, None # 出错时返回错误信息
|
153 |
else:
|
154 |
+
# 获取重叠最多的说话人的所有匹配时间段
|
155 |
matching_segments = get_matching_segments(time_dict, diarization_result)
|
156 |
|
157 |
if matching_segments:
|
158 |
+
# 返回匹配的时间段
|
159 |
+
return "\n".join(matching_segments)
|
160 |
else:
|
161 |
return "没有找到匹配的说话人时间段。"
|
162 |
|
|
|
164 |
with gr.Blocks() as demo:
|
165 |
gr.Markdown("""
|
166 |
# 🗣️ 音频拼接与说话人分类 🗣️
|
167 |
+
上传目标音频和混合音频,拼接并进行说话人分类。结果包括匹配重叠最多的说话人的时间段。
|
168 |
""")
|
169 |
|
170 |
mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
|