QLWD committed on
Commit
388c913
1 Parent(s): dedd569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -40
app.py CHANGED
@@ -25,10 +25,6 @@ def combine_audio_with_time(target_audio, mixed_audio):
25
  if pipeline is None:
26
  return "错误: 模型未初始化"
27
 
28
- # 打印文件路径,确保文件正确传递
29
- print(f"目标音频文件路径: {target_audio}")
30
- print(f"混合音频文件路径: {mixed_audio}")
31
-
32
  # 加载目标说话人的样本音频
33
  try:
34
  target_audio_segment = AudioSegment.from_wav(target_audio)
@@ -68,31 +64,43 @@ def diarize_audio(temp_file):
68
  # 返回 diarization 输出
69
  return str(diarization)
70
 
71
- # 生成标签文件的函数
72
- def generate_labels_from_diarization(diarization_output):
73
- labels_path = 'labels.txt'
74
- successful_lines = 0
75
-
76
- try:
77
- with open(labels_path, 'w') as outfile:
78
- lines = diarization_output.strip().split('\n')
79
- for line in lines:
80
- try:
81
- parts = line.strip()[1:-1].split(' --> ')
82
- start_time = parts[0].strip()
83
- end_time = parts[1].split(']')[0].strip()
84
- label = line.split()[-1].strip()
85
- start_seconds = timestamp_to_seconds(start_time)
86
- end_seconds = timestamp_to_seconds(end_time)
87
- outfile.write(f"{start_seconds}\t{end_seconds}\t{label}\n")
88
- successful_lines += 1
89
- except Exception as e:
90
- print(f"处理行时出错: '{line.strip()}'. 错误: {e}")
91
- print(f"成功处理了 {successful_lines} 行。")
92
- return labels_path if successful_lines > 0 else None
93
- except Exception as e:
94
- print(f"写入文件时出错: {e}")
95
- return None
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # 将时间戳转换为秒
98
  def timestamp_to_seconds(timestamp):
@@ -105,9 +113,6 @@ def timestamp_to_seconds(timestamp):
105
 
106
  # 处理音频文件并返回输出
107
  def process_audio(target_audio, mixed_audio):
108
- # 打印文件路径,确保传入的文件有效
109
- print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")
110
-
111
  # 进行音频拼接并返回目标音频的起始和结束时间(作为字典)
112
  time_dict = combine_audio_with_time(target_audio, mixed_audio)
113
 
@@ -117,32 +122,30 @@ def process_audio(target_audio, mixed_audio):
117
  if diarization_result.startswith("错误"):
118
  return diarization_result, None, None # 出错时返回错误信息
119
  else:
120
- # 生成标签文件
121
- label_file = generate_labels_from_diarization(diarization_result)
122
- return diarization_result, label_file, time_dict # 返回说话人分离结果、标签文件和目标音频的时间段
123
 
124
  # Gradio 接口
125
  with gr.Blocks() as demo:
126
  gr.Markdown("""
127
  # 🗣️ 音频拼接与说话人分类 🗣️
128
- 上传目标说话人音频和混合音频,拼接并进行说话人分类。结果包括说话人分离输出、标签文件和目标音频的时间段。
129
  """)
130
  mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
131
  target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
132
 
133
-
134
  process_button = gr.Button("处理音频")
135
 
136
  # 输出结果
137
  diarization_output = gr.Textbox(label="说话人分离结果")
138
- label_file_link = gr.File(label="下载标签文件")
139
- time_range_output = gr.Textbox(label="目标音频时间段")
140
 
141
  # 点击按钮时触发处理音频
142
  process_button.click(
143
  fn=process_audio,
144
  inputs=[target_audio_input, mixed_audio_input],
145
- outputs=[diarization_output, label_file_link, time_range_output]
146
  )
147
 
148
  demo.launch(share=True)
 
25
  if pipeline is None:
26
  return "错误: 模型未初始化"
27
 
 
 
 
 
28
  # 加载目标说话人的样本音频
29
  try:
30
  target_audio_segment = AudioSegment.from_wav(target_audio)
 
64
  # 返回 diarization 输出
65
  return str(diarization)
66
 
67
+ # 计算时间段的重叠部分(单位:秒)
68
+ def calculate_overlap(start1, end1, start2, end2):
69
+ overlap_start = max(start1, start2)
70
+ overlap_end = min(end1, end2)
71
+ overlap_duration = max(0, overlap_end - overlap_start)
72
+ return overlap_duration
73
+
74
+ # 获取目标时间段和说话人时间段的重叠比例
75
+ def get_best_match(target_time, diarization_output):
76
+ target_start_time = target_time['start_time']
77
+ target_end_time = target_time['end_time']
78
+
79
+ # 假设 diarization_output 是一个列表,包含说话人时间段和标签
80
+ speaker_segments = []
81
+ for line in diarization_output.strip().split('\n'):
82
+ try:
83
+ parts = line.strip()[1:-1].split(' --> ')
84
+ start_time = parts[0].strip()
85
+ end_time = parts[1].split(']')[0].strip()
86
+ label = line.split()[-1].strip()
87
+
88
+ start_seconds = timestamp_to_seconds(start_time)
89
+ end_seconds = timestamp_to_seconds(end_time)
90
+
91
+ # 计算目标音频时间段和说话人时间段的重叠时间
92
+ overlap = calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
93
+ overlap_ratio = overlap / (end_seconds - start_seconds)
94
+
95
+ # 记录说话人标签和重叠比例
96
+ speaker_segments.append((label, overlap_ratio, start_seconds, end_seconds))
97
+
98
+ except Exception as e:
99
+ print(f"处理行时出错: '{line.strip()}'. 错误: {e}")
100
+
101
+ # 按照重叠比例排序,返回重叠比例最大的一段
102
+ best_match = max(speaker_segments, key=lambda x: x[1], default=None)
103
+ return best_match
104
 
105
  # 将时间戳转换为秒
106
  def timestamp_to_seconds(timestamp):
 
113
 
114
  # 处理音频文件并返回输出
115
  def process_audio(target_audio, mixed_audio):
 
 
 
116
  # 进行音频拼接并返回目标音频的起始和结束时间(作为字典)
117
  time_dict = combine_audio_with_time(target_audio, mixed_audio)
118
 
 
122
  if diarization_result.startswith("错误"):
123
  return diarization_result, None, None # 出错时返回错误信息
124
  else:
125
+ # 获取最佳匹配的说话人时间段
126
+ best_match = get_best_match(time_dict, diarization_result)
127
+ return diarization_result, best_match # 返回说话人分离结果和最佳匹配的说话人时间段
128
 
129
  # Gradio 接口
130
  with gr.Blocks() as demo:
131
  gr.Markdown("""
132
  # 🗣️ 音频拼接与说话人分类 🗣️
133
+ 上传目标音频和混合音频,拼接并进行说话人分类。结果包括说话人分离输出、最佳匹配说话人时间段。
134
  """)
135
  mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
136
  target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
137
 
 
138
  process_button = gr.Button("处理音频")
139
 
140
  # 输出结果
141
  diarization_output = gr.Textbox(label="说话人分离结果")
142
+ best_match_output = gr.Textbox(label="最佳匹配说话人时间段")
 
143
 
144
  # 点击按钮时触发处理音频
145
  process_button.click(
146
  fn=process_audio,
147
  inputs=[target_audio_input, mixed_audio_input],
148
+ outputs=[diarization_output, best_match_output]
149
  )
150
 
151
  demo.launch(share=True)