QLWD committed
Commit 46b30ee
1 Parent(s): a3cd2f6

Update app.py

Files changed (1):
  app.py (+33 -42)
app.py CHANGED
@@ -1,9 +1,8 @@
 import torch
-import spaces
-import gradio as gr
 import os
 from pyannote.audio import Pipeline
 from pydub import AudioSegment
+import gradio as gr

 # Get the Hugging Face auth token
 HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
@@ -60,7 +59,6 @@ def combine_audio_with_time(target_audio, mixed_audio):
     return {"start_time": target_start_time, "end_time": target_end_time}

 # Diarize the concatenated audio with pyannote/speaker-diarization
-@spaces.GPU(duration=60 * 2)  # Use GPU acceleration; cap execution at 120 seconds
 def diarize_audio(temp_file):
     if pipeline is None:
         return "Error: model not initialized"
@@ -74,37 +72,34 @@ def diarize_audio(temp_file):
     except Exception as e:
         return f"Error while processing audio: {e}"

-# Get the specified speaker's segments (excluding the target-audio span)
-def get_speaker_segments(diarization, speaker_name, target_start_time, target_end_time, final_audio_length):
-    speaker_segments = {}
+# Find the best-matching speaker
+def find_best_matching_speaker(target_start_time, target_end_time, diarization):
+    best_match = None
+    max_overlap = 0

-    # Iterate over all speaker turns
+    # Iterate over all speaker turns and compute their overlap with the target audio
     for turn, _, speaker in diarization.itertracks(yield_label=True):
         start = turn.start
         end = turn.end

-        # If this is the target speaker
-        if speaker == speaker_name:
-            # If the turn overlaps the target audio, it must be truncated
-            if start < target_end_time and end > target_start_time:
-                # Record the truncated parts
-                if start < target_start_time:
-                    # The part before the target audio starts
-                    speaker_segments.setdefault(speaker, []).append((start, min(target_start_time, end)))
-
-                if end > target_end_time:
-                    # The part after the target audio ends
-                    speaker_segments.setdefault(speaker, []).append((max(target_end_time, start), min(end, final_audio_length)))
-            else:
-                # Turns that do not overlap the target audio at all
-                if end <= target_start_time or start >= target_end_time:
-                    speaker_segments.setdefault(speaker, []).append((start, end))
+        # Compute the start and end of the overlap
+        overlap_start = max(start, target_start_time)
+        overlap_end = min(end, target_end_time)
+
+        # If they overlap, compute the overlap duration
+        if overlap_end > overlap_start:
+            overlap_duration = overlap_end - overlap_start
+
+            # Keep the speaker with the largest overlap so far
+            if overlap_duration > max_overlap:
+                max_overlap = overlap_duration
+                best_match = speaker

-    return speaker_segments
+    return best_match, max_overlap

-# Process the audio files and return the output
-def process_audio(target_audio, mixed_audio, speaker_name):
-    print(f"Processing audio - target: {target_audio}, mixed: {mixed_audio}, speaker to extract: {speaker_name}")
+# Find the best-matching speaker in the concatenated audio
+def process_audio(target_audio, mixed_audio):
+    print(f"Processing audio - target: {target_audio}, mixed: {mixed_audio}")

     # Concatenate the audio and return the target clip's start and end times (as a dict)
     time_dict = combine_audio_with_time(target_audio, mixed_audio)
@@ -122,45 +117,41 @@ def process_audio(target_audio, mixed_audio, speaker_name):
     # Get the length of the concatenated audio
     final_audio_length = len(AudioSegment.from_wav("final_output.wav")) / 1000  # in seconds

-    # Get the target speaker's segments (excluding the target-audio span)
-    speaker_segments = get_speaker_segments(
-        diarization_result,
-        speaker_name,
+    # Find the best-matching speaker
+    best_match, overlap_duration = find_best_matching_speaker(
         time_dict['start_time'],
         time_dict['end_time'],
-        final_audio_length
+        diarization_result
     )

-    if speaker_segments and speaker_name in speaker_segments:
-        # Return the target speaker's segments (target-audio span excluded and truncated)
+    if best_match:
         return {
-            'segments': speaker_segments[speaker_name],
-            'total_duration': sum(end - start for start, end in speaker_segments[speaker_name])
+            'best_matching_speaker': best_match,
+            'overlap_duration': overlap_duration
         }
     else:
-        return f"No segments found for {speaker_name}."
+        return "No matching speaker found."

 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("""
+    gr.Markdown("""
     # 🗣️ Audio Concatenation and Speaker Diarization 🗣️
     Upload a target clip and a mixed clip; they are concatenated and diarized.
-    The output lists the specified speaker's segments, with the target recording's span excluded or truncated.
+    The output shows the best-matching speaker and the overlap duration.
     """)

     mixed_audio_input = gr.Audio(type="filepath", label="Upload mixed audio")
     target_audio_input = gr.Audio(type="filepath", label="Upload target speaker audio")
-    speaker_name_input = gr.Textbox(label="Speaker name (e.g. 'SPEAKER_01')", value="SPEAKER_00")

     process_button = gr.Button("Process audio")

     # Output
-    diarization_output = gr.Textbox(label="Speaker segments")
+    diarization_output = gr.Textbox(label="Best-matching speaker and overlap duration")

     # Run process_audio when the button is clicked
     process_button.click(
         fn=process_audio,
-        inputs=[target_audio_input, mixed_audio_input, speaker_name_input],
+        inputs=[target_audio_input, mixed_audio_input],
         outputs=[diarization_output]
     )

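Note: the commit drops `import spaces` and the `@spaces.GPU(duration=60 * 2)` decorator, so the Space presumably no longer requests ZeroGPU hardware for `diarize_audio`. The module-level `pipeline` that the function checks is created in a part of the file this diff does not show (between the first and second hunks). A minimal sketch of what that setup plausibly looks like; the checkpoint name is an assumption, since the actual model ID is not visible here:

import os
import torch
from pyannote.audio import Pipeline

HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")

# Hypothetical reconstruction: the diff only shows that a module-level
# `pipeline` exists and may be None; the checkpoint name is an assumption.
try:
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",  # assumed checkpoint
        use_auth_token=HF_TOKEN,
    )
    # With @spaces.GPU removed, run on whatever device is available
    pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
except Exception as e:
    print(f"Failed to load the pipeline: {e}")
    pipeline = None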
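`combine_audio_with_time` is likewise outside the changed hunks; only its return statement and the `final_output.wav` filename appear in the diff. A hedged reconstruction, assuming the target clip is appended after the mixed clip (the real ordering, any resampling, and padding are not shown):

from pydub import AudioSegment

# Hedged reconstruction: only the return value and the "final_output.wav"
# filename are visible in the diff; the concatenation order (mixed first,
# target second) is an assumption.
def combine_audio_with_time(target_audio, mixed_audio):
    target = AudioSegment.from_wav(target_audio)
    mixed = AudioSegment.from_wav(mixed_audio)

    # Append the target clip to the end of the mixed clip
    combined = mixed + target
    combined.export("final_output.wav", format="wav")

    # pydub lengths are in milliseconds; the target occupies the tail
    target_start_time = len(mixed) / 1000
    target_end_time = target_start_time + len(target) / 1000
    return {"start_time": target_start_time, "end_time": target_end_time}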
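A quick sanity check of the new max-overlap matching. `Turn` and `FakeDiarization` are hypothetical stand-ins for pyannote's annotation objects, and all times and labels are made up; the only requirement is that `itertracks(yield_label=True)` yields (segment, track, label) triples like the real API:

from collections import namedtuple

Turn = namedtuple("Turn", ["start", "end"])

class FakeDiarization:
    # Mimics the (turn, track, label) triples that
    # Annotation.itertracks(yield_label=True) yields.
    def __init__(self, tracks):
        self._tracks = tracks

    def itertracks(self, yield_label=False):
        for turn, label in self._tracks:
            yield turn, None, label

diarization = FakeDiarization([
    (Turn(0.0, 4.0), "SPEAKER_00"),   # overlaps [3.0, 8.0] by 1.0 s
    (Turn(4.5, 7.5), "SPEAKER_01"),   # overlaps by 3.0 s -> best match
    (Turn(9.0, 12.0), "SPEAKER_00"),  # no overlap
])

# Assumes find_best_matching_speaker from app.py is in scope
best, overlap = find_best_matching_speaker(3.0, 8.0, diarization)
print(best, overlap)  # SPEAKER_01 3.0

SPEAKER_01 wins with 3.0 s of overlap. A turn that only touches the target boundary contributes nothing, because the `overlap_end > overlap_start` comparison is strict.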