QLWD committed on
Commit
cb823a5
·
verified ·
1 Parent(s): 0bca3ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -19
app.py CHANGED
@@ -64,8 +64,7 @@ def diarize_audio(temp_file):
64
  diarization = pipeline(temp_file)
65
  except Exception as e:
66
  return f"处理音频时出错: {e}"
67
-
68
- print(diarization)
69
  # 返回 diarization 类对象
70
  return diarization
71
 
@@ -85,31 +84,58 @@ def calculate_overlap(start1, end1, start2, end2):
85
  overlap_duration = max(0, overlap_end - overlap_start)
86
  return overlap_duration
87
 
88
- # 获取目标时间段和说话人时间段的重叠比例
89
  def get_matching_segments(target_time, diarization_output):
90
  target_start_time = target_time['start_time']
91
  target_end_time = target_time['end_time']
92
 
93
- # 获取该说话人时间段的信息,排除目标录音时间段
94
- speaker_segments = {}
95
- for speech_turn in diarization_output.itertracks(yield_label=True): # 使用 itertracks 获取每个说话人的信息
 
96
  start_seconds = speech_turn[0].start
97
  end_seconds = speech_turn[0].end
98
  label = speech_turn[1]
99
 
100
- # 计算目标音频时间段和说话人时间段的重叠时间
101
  overlap = calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
102
 
103
- # 如果存在重叠,排除目标音频时间段
104
  if overlap > 0:
105
- if label not in speaker_segments:
106
- speaker_segments[label] = []
107
-
108
- # 如果时间段与目标音频有重叠,跳过该时间段
109
- if start_seconds >= target_end_time or end_seconds <= target_start_time:
110
- speaker_segments[label].append((start_seconds, end_seconds))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- return speaker_segments
 
 
 
 
113
 
114
  # 处理音频文件并返回输出
115
  def process_audio(target_audio, mixed_audio):
@@ -125,12 +151,12 @@ def process_audio(target_audio, mixed_audio):
125
  if isinstance(diarization_result, str) and diarization_result.startswith("错误"):
126
  return diarization_result, None # 出错时返回错误信息
127
  else:
128
- # 获取该说话人的所有匹配时间段(排除目标音频时间段)
129
  matching_segments = get_matching_segments(time_dict, diarization_result)
130
 
131
  if matching_segments:
132
- # 返回匹配的说话人标签和他们的时间段
133
- return matching_segments
134
  else:
135
  return "没有找到匹配的说话人时间段。"
136
 
@@ -138,7 +164,7 @@ def process_audio(target_audio, mixed_audio):
138
  with gr.Blocks() as demo:
139
  gr.Markdown("""
140
  # 🗣️ 音频拼接与说话人分类 🗣️
141
- 上传目标音频和混合音频,拼接并进行说话人分类。结果包括所有匹配说话人的时间段(排除目标录音时间段)。
142
  """)
143
 
144
  mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
 
64
  diarization = pipeline(temp_file)
65
  except Exception as e:
66
  return f"处理音频时出错: {e}"
67
+
 
68
  # 返回 diarization 类对象
69
  return diarization
70
 
 
84
  overlap_duration = max(0, overlap_end - overlap_start)
85
  return overlap_duration
86
 
87
# Find the speaker who overlaps most with the target clip and return that
# speaker's other segments (excluding any that touch the target window).
def get_matching_segments(target_time, diarization_output):
    """Return the non-overlapping segments of the best-matching speaker.

    Parameters
    ----------
    target_time : dict
        Mapping with 'start_time' and 'end_time' (seconds) of the target clip.
    diarization_output : pyannote.core.Annotation
        Diarization result, iterated via ``itertracks(yield_label=True)``.

    Returns
    -------
    list[str]
        "hh:mm:ss.xxx --> hh:mm:ss.xxx" strings for every segment of the
        most-overlapping speaker that does NOT overlap the target window.
        Empty list when no speaker overlaps the target at all.
    """
    target_start_time = target_time['start_time']
    target_end_time = target_time['end_time']

    # Accumulate total overlap duration per speaker label.
    # BUGFIX: with yield_label=True, itertracks yields (segment, track, label)
    # 3-tuples; the previous code used index [1], which is the *track* id,
    # not the speaker label, so overlaps were grouped under the wrong key.
    speaker_overlap = {}
    for segment, _track, label in diarization_output.itertracks(yield_label=True):
        overlap = calculate_overlap(
            target_start_time, target_end_time, segment.start, segment.end
        )
        if overlap > 0:
            speaker_overlap[label] = speaker_overlap.get(label, 0) + overlap

    # Speaker whose speech overlaps the target window the most.
    max_overlap_speaker = max(speaker_overlap, key=speaker_overlap.get, default=None)
    if max_overlap_speaker is None:
        # BUGFIX: return an empty list instead of a string so the return type
        # is consistent; the caller truth-tests the result and then applies
        # "\n".join(), which would join a string character-by-character.
        return []

    # Collect that speaker's segments, keeping only those with zero overlap
    # against the target window (i.e. exclude the target recording itself).
    speaker_segments = []
    for segment, _track, label in diarization_output.itertracks(yield_label=True):
        if label != max_overlap_speaker:
            continue
        overlap = calculate_overlap(
            target_start_time, target_end_time, segment.start, segment.end
        )
        if overlap == 0:
            speaker_segments.append((segment.start, segment.end))

    # Render as "00:00:03.895 --> 00:00:04.367"-style ranges.
    return [
        f"{format_time(start)} --> {format_time(end)}"
        for start, end in speaker_segments
    ]
133
 
134
# Format seconds -> hh:mm:ss.xxx
def format_time(seconds):
    """Format a duration in seconds as a ``hh:mm:ss.xxx`` timestamp string.

    Rounds to whole milliseconds *before* splitting into units so values
    such as 59.9999 render as "00:01:00.000" rather than the invalid
    "00:00:60.000" that plain divmod + ``:06.3f`` formatting produced.
    """
    total_ms = round(seconds * 1000)
    total_secs, ms = divmod(total_ms, 1000)
    mins, secs = divmod(total_secs, 60)
    hrs, mins = divmod(mins, 60)
    return f"{int(hrs):02}:{int(mins):02}:{int(secs):02}.{int(ms):03}"
139
 
140
  # 处理音频文件并返回输出
141
  def process_audio(target_audio, mixed_audio):
 
151
  if isinstance(diarization_result, str) and diarization_result.startswith("错误"):
152
  return diarization_result, None # 出错时返回错误信息
153
  else:
154
+ # 获取重叠最多的说话人的所有匹配时间段
155
  matching_segments = get_matching_segments(time_dict, diarization_result)
156
 
157
  if matching_segments:
158
+ # 返回匹配的时间段
159
+ return "\n".join(matching_segments)
160
  else:
161
  return "没有找到匹配的说话人时间段。"
162
 
 
164
  with gr.Blocks() as demo:
165
  gr.Markdown("""
166
  # 🗣️ 音频拼接与说话人分类 🗣️
167
+ 上传目标音频和混合音频,拼接并进行说话人分类。结果包括匹配重叠最多的说话人的时间段。
168
  """)
169
 
170
  mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")