QLWD commited on
Commit
ba5925c
·
verified ·
1 Parent(s): b0766f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -65
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- import os
3
  import gradio as gr
 
4
  from pyannote.audio import Pipeline
5
  from pydub import AudioSegment
6
- from spaces import GPU
7
 
8
  # 获取 Hugging Face 认证令牌
9
  HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
@@ -60,7 +60,7 @@ def combine_audio_with_time(target_audio, mixed_audio):
60
  return {"start_time": target_start_time, "end_time": target_end_time}
61
 
62
  # 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离
63
- @GPU(duration=60 * 2) # 使用 GPU 加速,限制执行时间为 120 秒
64
  def diarize_audio(temp_file):
65
  if pipeline is None:
66
  return "错误: 模型未初始化"
@@ -74,31 +74,6 @@ def diarize_audio(temp_file):
74
  except Exception as e:
75
  return f"处理音频时出错: {e}"
76
 
77
- # 查找最匹配的说话人
78
- def find_best_matching_speaker(target_start_time, target_end_time, diarization):
79
- best_match = None
80
- max_overlap = 0
81
-
82
- # 遍历所有说话人时间段,计算与目标音频的重叠部分
83
- for turn, _, speaker in diarization.itertracks(yield_label=True):
84
- start = turn.start
85
- end = turn.end
86
-
87
- # 计算重叠部分的开始和结束时间
88
- overlap_start = max(start, target_start_time)
89
- overlap_end = min(end, target_end_time)
90
-
91
- # 如果有重叠部分,计算重叠的持续时间
92
- if overlap_end > overlap_start:
93
- overlap_duration = overlap_end - overlap_start
94
-
95
- # 如果当前重叠部分更大,则更新最匹配的说话人
96
- if overlap_duration > max_overlap:
97
- max_overlap = overlap_duration
98
- best_match = speaker
99
-
100
- return best_match, max_overlap
101
-
102
  # 获取目标说话人的时间段(排除目标音频时间段)
103
  def get_speaker_segments(diarization, target_start_time, target_end_time, final_audio_length):
104
  speaker_segments = {}
@@ -108,23 +83,34 @@ def get_speaker_segments(diarization, target_start_time, target_end_time, final_
108
  start = turn.start
109
  end = turn.end
110
 
111
- # 如果时间段与目标音频有重叠,需要截断
112
- if start < target_end_time and end > target_start_time:
113
- # 记录被截断的时间段
114
- if start < target_start_time:
115
- # 目标音频开始前的时间段
116
- speaker_segments.setdefault(speaker, []).append((start, min(target_start_time, end)))
117
-
118
- if end > target_end_time:
119
- # 目标音频结束后的时间段
120
- speaker_segments.setdefault(speaker, []).append((max(target_end_time, start), min(end, final_audio_length)))
121
- else:
122
- # 完全不与目标音频重叠的时间段
123
- if end <= target_start_time or start >= target_end_time:
124
- speaker_segments.setdefault(speaker, []).append((start, end))
 
 
125
 
126
  return speaker_segments
127
 
 
 
 
 
 
 
 
 
 
128
  # 处理音频文件并返回输出
129
  def process_audio(target_audio, mixed_audio):
130
  print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")
@@ -145,39 +131,37 @@ def process_audio(target_audio, mixed_audio):
145
  # 获取拼接后的音频长度
146
  final_audio_length = len(AudioSegment.from_wav("final_output.wav")) / 1000 # 秒为单位
147
 
148
- # 查找最匹配的说话人
149
- best_match, overlap_duration = find_best_matching_speaker(
 
150
  time_dict['start_time'],
151
  time_dict['end_time'],
152
- diarization_result
153
  )
154
 
155
- if best_match:
156
- # 获取目标说话人的时间段(排除和截断目标音频时间段)
157
- speaker_segments = get_speaker_segments(
158
- diarization_result,
159
- time_dict['start_time'],
160
- time_dict['end_time'],
161
- final_audio_length
162
- )
163
 
164
- if best_match in speaker_segments:
165
- return {
166
- 'best_matching_speaker': best_match,
167
- 'overlap_duration': overlap_duration,
168
- 'segments': speaker_segments[best_match]
169
- }
170
- else:
171
- return "没有找到匹配的说话人时间段。"
 
172
  else:
173
- return "未找到匹配的说话人。"
174
 
175
  # Gradio 接口
176
  with gr.Blocks() as demo:
177
  gr.Markdown("""
178
  # 🗣️ 音频拼接与说话人分类 🗣️
179
  上传目标音频和混合音频,拼接并进行说话人分类。
180
- 结果包括最匹配的说话人及其时间段,已排除和截断目标录音时间段。
181
  """)
182
 
183
  mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
@@ -186,13 +170,13 @@ with gr.Blocks() as demo:
186
  process_button = gr.Button("处理音频")
187
 
188
  # 输出结果
189
- diarization_output = gr.Textbox(label="说话人时间段")
190
 
191
  # 点击按钮时触发处理音频
192
  process_button.click(
193
  fn=process_audio,
194
  inputs=[target_audio_input, mixed_audio_input],
195
- outputs=[diarization_output]
196
  )
197
 
198
  demo.launch(share=True)
 
1
  import torch
2
+ import spaces
3
  import gradio as gr
4
+ import os
5
  from pyannote.audio import Pipeline
6
  from pydub import AudioSegment
 
7
 
8
  # 获取 Hugging Face 认证令牌
9
  HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
 
60
  return {"start_time": target_start_time, "end_time": target_end_time}
61
 
62
  # 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离
63
+ @spaces.GPU(duration=60 * 2) # 使用 GPU 加速,限制执行时间为 120 秒
64
  def diarize_audio(temp_file):
65
  if pipeline is None:
66
  return "错误: 模型未初始化"
 
74
  except Exception as e:
75
  return f"处理音频时出错: {e}"
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # 获取目标说话人的时间段(排除目标音频时间段)
78
  def get_speaker_segments(diarization, target_start_time, target_end_time, final_audio_length):
79
  speaker_segments = {}
 
83
  start = turn.start
84
  end = turn.end
85
 
86
+ # 如果是目标说话人
87
+ if speaker == 'SPEAKER_00':
88
+ # 如果时间段与目标音频有重叠,需要截断
89
+ if start < target_end_time and end > target_start_time:
90
+ # 记录被截断的时间段
91
+ if start < target_start_time:
92
+ # 目标音频开始前的时间段
93
+ speaker_segments.setdefault(speaker, []).append((start, min(target_start_time, end)))
94
+
95
+ if end > target_end_time:
96
+ # 目标音频结束后的时间段
97
+ speaker_segments.setdefault(speaker, []).append((max(target_end_time, start), min(end, final_audio_length)))
98
+ else:
99
+ # 完全不与目标音频重叠的时间段
100
+ if end <= target_start_time or start >= target_end_time:
101
+ speaker_segments.setdefault(speaker, []).append((start, end))
102
 
103
  return speaker_segments
104
 
105
+ # 剪辑音频函数:根据时间段剪辑音频
106
+ def clip_audio(audio_segment, segments):
107
+ clips = []
108
+ for start, end in segments:
109
+ start_ms = int(start * 1000) # 毫秒
110
+ end_ms = int(end * 1000) # 毫秒
111
+ clips.append(audio_segment[start_ms:end_ms])
112
+ return clips
113
+
114
  # 处理音频文件并返回输出
115
  def process_audio(target_audio, mixed_audio):
116
  print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")
 
131
  # 获取拼接后的音频长度
132
  final_audio_length = len(AudioSegment.from_wav("final_output.wav")) / 1000 # 秒为单位
133
 
134
+ # 获取目标说话人的时间段(排除和截断目标音频时间段)
135
+ speaker_segments = get_speaker_segments(
136
+ diarization_result,
137
  time_dict['start_time'],
138
  time_dict['end_time'],
139
+ final_audio_length
140
  )
141
 
142
+ if speaker_segments and 'SPEAKER_00' in speaker_segments:
143
+ # 剪辑目标说话人的音频片段
144
+ final_audio_segment = AudioSegment.from_wav("final_output.wav")
145
+ clips = clip_audio(final_audio_segment, speaker_segments['SPEAKER_00'])
 
 
 
 
146
 
147
+ # 将剪辑后的音频片段导出为多个文件
148
+ output_files = []
149
+ for i, clip in enumerate(clips):
150
+ clip_path = f"speaker_00_clip_{i + 1}.wav"
151
+ clip.export(clip_path, format="wav")
152
+ output_files.append(clip_path)
153
+
154
+ # 返回剪辑后的音频文件路径
155
+ return output_files
156
  else:
157
+ return "没有找到SPEAKER_00的时间段。"
158
 
159
  # Gradio 接口
160
  with gr.Blocks() as demo:
161
  gr.Markdown("""
162
  # 🗣️ 音频拼接与说话人分类 🗣️
163
  上传目标音频和混合音频,拼接并进行说话人分类。
164
+ 结果包括目标说话人(SPEAKER_00)的时间段,已排除和截断目标录音时间段,并自动剪辑目标音频。
165
  """)
166
 
167
  mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
 
170
  process_button = gr.Button("处理音频")
171
 
172
  # 输出结果
173
+ output_audio = gr.Audio(label="剪辑后的音频")
174
 
175
  # 点击按钮时触发处理音频
176
  process_button.click(
177
  fn=process_audio,
178
  inputs=[target_audio_input, mixed_audio_input],
179
+ outputs=[output_audio]
180
  )
181
 
182
  demo.launch(share=True)