QLWD committed
Commit 307d4dd
Parent: d579519

Update app.py

Files changed (1)
  1. app.py (+43, -48)
app.py CHANGED
@@ -20,6 +20,11 @@ except Exception as e:
     print(f"Error initializing pipeline: {e}")
     pipeline = None
 
+# Convert an H:M:S timestamp string to seconds
+def timestamp_to_seconds(timestamp):
+    h, m, s = map(float, timestamp.split(':'))
+    return 3600 * h + 60 * m + s
+
 # Concatenate the target audio and the mixed audio, and return the target audio's start and end times as a dict
 def combine_audio_with_time(target_audio, mixed_audio):
     if pipeline is None:
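Note: the new `timestamp_to_seconds` helper assumes a colon-separated `H:M:S` string, where the seconds field may be fractional. A minimal usage sketch, with a made-up timestamp:

```python
# Illustrative only; "00:01:23.5" is a made-up timestamp.
print(timestamp_to_seconds("00:01:23.5"))  # 0*3600 + 1*60 + 23.5 = 83.5
```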
@@ -62,57 +67,43 @@ def diarize_audio(temp_file):
 
     try:
         diarization = pipeline(temp_file)
+        print("Diarization result:")
+        for turn, _, speaker in diarization.itertracks(yield_label=True):
+            print(f"[{turn.start:.3f} --> {turn.end:.3f}] {speaker}")
+        return diarization
     except Exception as e:
         return f"Error while processing the audio: {e}"
 
-    print(diarization)
-
-    # Return the diarization object
-    return diarization
-
-# Find the speaker whose segments overlap the target recording's time range the most, along with those segments
-def get_most_matched_speaker_segments(diarization_output, target_start_time, target_end_time, final_audio_length):
-    # Maps each speaker to their total overlap with the target audio
-    speaker_overlaps = {}
+# Get the target speaker's segments (excluding the target audio's time range)
+def get_speaker_segments(diarization, target_start_time, target_end_time, final_audio_length):
+    speaker_segments = {}
 
-    # Use itertracks to iterate over each speaker's turns
-    for speech_turn in diarization_output.itertracks(yield_label=True):
-        start_seconds = speech_turn[0].start
-        end_seconds = speech_turn[0].end
-        label = speech_turn[1]
+    # Iterate over all speaker turns
+    for turn, _, speaker in diarization.itertracks(yield_label=True):
+        start = turn.start
+        end = turn.end
 
-        # Compute the overlap between the target audio and the current turn
-        overlap_start = max(start_seconds, target_start_time)
-        overlap_end = min(end_seconds, target_end_time)
-        overlap_duration = max(0, overlap_end - overlap_start)
-
-        # If there is an overlap, record it
-        if overlap_duration > 0:
-            if label not in speaker_overlaps:
-                speaker_overlaps[label] = {
-                    'total_overlap': overlap_duration,
-                    'segments': []
-                }
+        # Only the target speaker is of interest
+        if speaker == 'SPEAKER_00':
+            # Turns that overlap the target audio must be truncated
+            if start < target_end_time and end > target_start_time:
+                # Keep the truncated parts that fall outside the target window
+                if start < target_start_time:
+                    # Part before the target audio starts
+                    speaker_segments.setdefault(speaker, []).append((start, min(target_start_time, end)))
+
+                if end > target_end_time:
+                    # Part after the target audio ends
+                    speaker_segments.setdefault(speaker, []).append((max(target_end_time, start), min(end, final_audio_length)))
             else:
-                speaker_overlaps[label]['total_overlap'] += overlap_duration
-
-            # Record the speaker's original segments (excluding the target audio's time range)
-            if start_seconds < target_start_time:
-                speaker_overlaps[label]['segments'].append((start_seconds, min(end_seconds, target_start_time)))
-
-            if end_seconds > target_end_time:
-                speaker_overlaps[label]['segments'].append((max(start_seconds, target_end_time), end_seconds))
-
-    # Pick the speaker with the longest total overlap
-    if speaker_overlaps:
-        most_matched_speaker = max(speaker_overlaps, key=lambda k: speaker_overlaps[k]['total_overlap'])
-        return {most_matched_speaker: speaker_overlaps[most_matched_speaker]['segments']}
+                # Turn does not overlap the target audio at all
+                if end <= target_start_time or start >= target_end_time:
+                    speaker_segments.setdefault(speaker, []).append((start, end))
 
-    return {}
+    return speaker_segments
 
 # Process the audio files and return the output
 def process_audio(target_audio, mixed_audio):
-    # Print the file paths to confirm the inputs are valid
     print(f"Processing audio - target audio: {target_audio}, mixed audio: {mixed_audio}")
 
     # Concatenate the audio and return the target audio's start and end times (as a dict)
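Note: the truncation logic in the new `get_speaker_segments` is easiest to see on plain numbers. Below is a self-contained sketch of the same interval arithmetic; the turns, window, and total length are made up, and no pyannote objects are involved:

```python
# Hypothetical SPEAKER_00 turns as (start, end) pairs, in seconds.
turns = [(0.0, 4.0), (3.0, 9.0), (12.0, 15.0)]
target_start_time, target_end_time = 5.0, 10.0  # assumed target-audio window
final_audio_length = 20.0                       # assumed total length

segments = []
for start, end in turns:
    if start < target_end_time and end > target_start_time:
        # Overlapping turn: keep only the parts outside the target window.
        if start < target_start_time:
            segments.append((start, min(target_start_time, end)))
        if end > target_end_time:
            segments.append((max(target_end_time, start), min(end, final_audio_length)))
    elif end <= target_start_time or start >= target_end_time:
        # Non-overlapping turn: keep as-is.
        segments.append((start, end))

print(segments)  # [(0.0, 4.0), (3.0, 5.0), (12.0, 15.0)]
```

The `elif` here is equivalent to the `else` plus inner `if` in the diff, since the negation of the overlap test already implies one of the two non-overlap conditions.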
@@ -131,25 +122,29 @@ def process_audio(target_audio, mixed_audio):
     # Get the length of the concatenated audio
     final_audio_length = len(AudioSegment.from_wav("final_output.wav")) / 1000  # in seconds
 
-    # Get the segments of the speaker who overlaps the target recording's time range the most
-    most_matched_speaker_segments = get_most_matched_speaker_segments(
+    # Get the target speaker's segments (excluding the target audio's time range)
+    speaker_segments = get_speaker_segments(
         diarization_result,
         time_dict['start_time'],
         time_dict['end_time'],
         final_audio_length
     )
 
-    if most_matched_speaker_segments:
-        # Return the most-overlapping speaker's segments (excluding the target audio's time range)
-        return most_matched_speaker_segments
+    if speaker_segments and 'SPEAKER_00' in speaker_segments:
+        # Return the target speaker's segments (target audio's time range excluded or truncated)
+        return {
+            'segments': speaker_segments['SPEAKER_00'],
+            'total_duration': sum(end - start for start, end in speaker_segments['SPEAKER_00'])
+        }
     else:
-        return "No speaker segments overlapping the target recording were found."
+        return "No segments found for SPEAKER_00."
 
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("""
     # 🗣️ Audio Concatenation & Speaker Diarization 🗣️
-    Upload a target audio clip and a mixed audio clip to concatenate them and run speaker diarization. The result contains the segments of the speaker who overlaps the target recording the longest (excluding the target recording's time range).
+    Upload a target audio clip and a mixed audio clip to concatenate them and run speaker diarization.
+    The result contains the target speaker's (SPEAKER_00) segments, with the target recording's time range excluded or truncated.
     """)
 
     mixed_audio_input = gr.Audio(type="filepath", label="Upload the mixed audio")
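Note: after this change, `process_audio` returns either a result dict or an error string, so a caller has to branch on the type. A minimal sketch of consuming it; the file paths are placeholders:

```python
# Hypothetical caller; "target.wav" and "mixed.wav" are placeholder paths.
result = process_audio("target.wav", "mixed.wav")
if isinstance(result, dict):
    for start, end in result['segments']:
        print(f"SPEAKER_00: {start:.2f}s -> {end:.2f}s")
    print(f"Total speaking time: {result['total_duration']:.2f}s")
else:
    print(result)  # error message string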
 
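Note: the `pipeline` object guarded by the try/except at the top of app.py is initialized outside this diff. A typical pyannote.audio setup looks roughly like the following; the model id and token handling are assumptions, not taken from this commit:

```python
# Assumed initialization; this commit does not show it.
from pyannote.audio import Pipeline

try:
    # Model id is an assumption; gated pyannote models also require a HF auth token.
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
except Exception as e:
    print(f"Error initializing pipeline: {e}")
    pipeline = None
```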