QLWD committed on
Commit
68f6bb9
1 Parent(s): 913c46d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -15
app.py CHANGED
@@ -1,8 +1,27 @@
1
  import gradio as gr
 
2
  from pydub import AudioSegment
 
 
3
 
4
- # 处理音频函数
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def combine_audio_with_time(target_audio, mixed_audio):
 
 
 
6
  # 加载目标说话人的样本音频
7
  target_audio_segment = AudioSegment.from_wav(target_audio.name)
8
 
@@ -20,18 +39,94 @@ def combine_audio_with_time(target_audio, mixed_audio):
20
 
21
  return "final_output.wav", target_start_time
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Gradio 接口
24
- interface = gr.Interface(
25
- fn=combine_audio_with_time,
26
- inputs=[
27
- gr.File(label="目标说话人音频"), # 上传目标说话人音频
28
- gr.File(label="混合音频") # 上传混合音频
29
- ],
30
- outputs=[
31
- gr.Audio(label="输出音频"), # 返回拼接后的音频文件
32
- gr.Textbox(label="目标音频起始时间") # 显示目标音频的起始时间
33
- ],
34
- live=False
35
- )
36
-
37
- interface.launch()
 
 
 
 
 
 
 
 
1
import gradio as gr
import os
from pydub import AudioSegment
from pyannote.audio.pipelines import SpeakerDiarization
import torch

# Initialize the pyannote/speaker-diarization pipeline once at import time.
# NOTE(review): requires a Hugging Face read token in the
# HUGGINGFACE_READ_TOKEN environment variable — confirm it is set in the Space.
HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
pipeline = None
try:
    pipeline = SpeakerDiarization.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
    )
    # Run on GPU when one is available, otherwise fall back to CPU.
    pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
except Exception as e:
    # Startup is best-effort: downstream functions check `pipeline is None`
    # and report the failure instead of crashing the app.
    print(f"Error initializing pipeline: {e}")
    pipeline = None
19
+
20
+ # 音频处理函数:拼接目标音频和混合音频
21
  def combine_audio_with_time(target_audio, mixed_audio):
22
+ if pipeline is None:
23
+ return "错误: 模型未初始化"
24
+
25
  # 加载目标说话人的样本音频
26
  target_audio_segment = AudioSegment.from_wav(target_audio.name)
27
 
 
39
 
40
  return "final_output.wav", target_start_time
41
 
42
# Run pyannote speaker diarization over an audio file on disk.
def diarize_audio(temp_file):
    """Diarize *temp_file* and return the result as display text.

    Returns a Chinese error string (prefixed "错误"/"处理音频时出错") when the
    model is unavailable or inference fails — callers check the prefix.
    """
    # Guard: the module-level pipeline may have failed to initialize.
    if pipeline is None:
        return "错误: 模型未初始化"

    try:
        result = pipeline(temp_file)
    except Exception as e:
        return f"处理音频时出错: {e}"

    # Plain-text rendering is what the label parser downstream expects.
    return str(result)
54
+
55
# Parse diarization text and write an Audacity-style tab-separated label file.
def generate_labels_from_diarization(diarization_output):
    """Convert pyannote diarization text into "labels.txt".

    Each input line is expected to look like:
        [ 00:00:01.500 -->  00:00:03.250] A SPEAKER_00
    Each output line is "start_seconds<TAB>end_seconds<TAB>speaker".

    Returns the label file path, or None when no line parsed successfully
    or the file could not be written.
    """
    labels_path = 'labels.txt'
    successful_lines = 0

    try:
        # Explicit utf-8 so speaker labels survive regardless of locale.
        with open(labels_path, 'w', encoding='utf-8') as outfile:
            for line in diarization_output.strip().split('\n'):
                try:
                    parts = line.strip()[1:-1].split(' --> ')
                    start_time = parts[0].strip()
                    end_time = parts[1].split(']')[0].strip()
                    label = line.split()[-1].strip()
                    start_seconds = timestamp_to_seconds(start_time)
                    end_seconds = timestamp_to_seconds(end_time)
                    # Fix: timestamp_to_seconds returns None on bad input;
                    # the original wrote the literal string "None" into the
                    # label file. Treat it as a per-line parse failure instead.
                    if start_seconds is None or end_seconds is None:
                        raise ValueError("unparseable timestamp")
                    outfile.write(f"{start_seconds}\t{end_seconds}\t{label}\n")
                    successful_lines += 1
                except Exception as e:
                    # Per-line best effort: log and keep going.
                    print(f"处理行时出错: '{line.strip()}'. 错误: {e}")
        print(f"成功处理了 {successful_lines} 行。")
        return labels_path if successful_lines > 0 else None
    except Exception as e:
        print(f"写入文件时出错: {e}")
        return None

# Convert an "HH:MM:SS.mmm" timestamp to seconds (float); None on bad input.
def timestamp_to_seconds(timestamp):
    try:
        h, m, s = map(float, timestamp.split(':'))
        return 3600 * h + 60 * m + s
    except ValueError as e:
        print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
        return None
89
+
90
# Full pipeline: persist the upload, diarize it, and build a label file.
# FIX(review): the original decorated this with `@spaces.GPU(duration=60 * 2)`,
# but `spaces` is never imported anywhere in the file, so the module raised
# NameError at import time. Re-add the decorator together with `import spaces`
# when deploying on Hugging Face ZeroGPU hardware.
def process_audio(audio):
    """Diarize the uploaded audio.

    Returns a (diarization_text, label_file_path) pair; on failure the first
    element is the error message and the second is None.
    """
    diarization_result = diarize_audio(save_audio(audio))
    if diarization_result.startswith("错误"):
        # Propagate the error text; no label file can be produced.
        return diarization_result, None
    label_file = generate_labels_from_diarization(diarization_result)
    return diarization_result, label_file
99
+
100
# Copy the uploaded audio to a fixed temporary path for the pipeline.
def save_audio(audio):
    """Save the uploaded audio to "temp.wav" and return that path.

    FIX: gr.Audio(type="filepath") hands the callback a plain path string,
    so the original `audio.name` raised AttributeError. Accept either a
    path string or a file-like/NamedString object exposing `.name`.
    """
    source_path = audio if isinstance(audio, str) else audio.name
    with open(source_path, "rb") as f:
        audio_data = f.read()

    # Write the bytes to a well-known temporary location.
    with open("temp.wav", "wb") as f:
        f.write(audio_data)

    return "temp.wav"
110
+
111
# Gradio UI: upload audio, run diarization, offer the label file for download.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🗣️ 音频拼接与说话人分类 🗣️
上传目标说话人音频和混合音频,拼接并进行说话人分类。
""")

    audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
    # NOTE(review): this component is displayed but never wired into the
    # click handler below — confirm whether it should also be an input.
    mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")

    process_button = gr.Button("处理音频")
    diarization_output = gr.Textbox(label="说话人分离结果")
    label_file_link = gr.File(label="下载标签文件")

    # Button click runs the full pipeline on the target audio only.
    process_button.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[diarization_output, label_file_link],
    )

demo.launch(share=False)