kevinwang676 commited on
Commit
da9d371
1 Parent(s): 13ce559

Upload 3 files

Browse files
extract_embedding.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ import torch
17
+ import torchaudio
18
+ from tqdm import tqdm
19
+ import onnxruntime
20
+ import torchaudio.compliance.kaldi as kaldi
21
+
22
+
23
+ def main(args):
24
+ utt2wav, utt2spk = {}, {}
25
+ with open('{}/wav.scp'.format(args.dir)) as f:
26
+ for l in f:
27
+ l = l.replace('\n', '').split()
28
+ utt2wav[l[0]] = l[1]
29
+ with open('{}/utt2spk'.format(args.dir)) as f:
30
+ for l in f:
31
+ l = l.replace('\n', '').split()
32
+ utt2spk[l[0]] = l[1]
33
+
34
+ option = onnxruntime.SessionOptions()
35
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
36
+ option.intra_op_num_threads = 1
37
+ providers = ["CPUExecutionProvider"]
38
+ ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
39
+
40
+ utt2embedding, spk2embedding = {}, {}
41
+ for utt in tqdm(utt2wav.keys()):
42
+ audio, sample_rate = torchaudio.load(utt2wav[utt])
43
+ if sample_rate != 16000:
44
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
45
+ feat = kaldi.fbank(audio,
46
+ num_mel_bins=80,
47
+ dither=0,
48
+ sample_frequency=16000)
49
+ feat = feat - feat.mean(dim=0, keepdim=True)
50
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
51
+ utt2embedding[utt] = embedding
52
+ spk = utt2spk[utt]
53
+ if spk not in spk2embedding:
54
+ spk2embedding[spk] = []
55
+ spk2embedding[spk].append(embedding)
56
+ for k, v in spk2embedding.items():
57
+ spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
58
+
59
+ torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60
+ torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61
+
62
+ if __name__ == "__main__":
63
+ parser = argparse.ArgumentParser()
64
+ parser.add_argument('--dir',
65
+ type=str)
66
+ parser.add_argument('--onnx_path',
67
+ type=str)
68
+ args = parser.parse_args()
69
+ main(args)
extract_speech_token.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ import logging
17
+ import torch
18
+ from tqdm import tqdm
19
+ import onnxruntime
20
+ import numpy as np
21
+ import torchaudio
22
+ import whisper
23
+
24
+
25
+ def main(args):
26
+ utt2wav = {}
27
+ with open('{}/wav.scp'.format(args.dir)) as f:
28
+ for l in f:
29
+ l = l.replace('\n', '').split()
30
+ utt2wav[l[0]] = l[1]
31
+
32
+ option = onnxruntime.SessionOptions()
33
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
34
+ option.intra_op_num_threads = 1
35
+ providers = ["CUDAExecutionProvider"]
36
+ ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
37
+
38
+ utt2speech_token = {}
39
+ for utt in tqdm(utt2wav.keys()):
40
+ audio, sample_rate = torchaudio.load(utt2wav[utt])
41
+ if sample_rate != 16000:
42
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
43
+ if audio.shape[1] / 16000 > 30:
44
+ logging.warning('do not support extract speech token for audio longer than 30s')
45
+ speech_token = []
46
+ else:
47
+ feat = whisper.log_mel_spectrogram(audio, n_mels=128)
48
+ speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
49
+ ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
50
+ utt2speech_token[utt] = speech_token
51
+ torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
52
+
53
+
54
+ if __name__ == "__main__":
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument('--dir',
57
+ type=str)
58
+ parser.add_argument('--onnx_path',
59
+ type=str)
60
+ args = parser.parse_args()
61
+ main(args)
make_parquet_list.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ import logging
17
+ import os
18
+ import json
19
+ from tqdm import tqdm
20
+ import pandas as pd
21
+ import multiprocessing
22
+ import time
23
+ import torch
24
+
25
+
26
+ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
27
+ start_time = time.time()
28
+ data_list = []
29
+ for utt in tqdm(utt_list):
30
+ data = open(utt2wav[utt], 'rb').read()
31
+ data_list.append(data)
32
+ wav_list = [utt2wav[utt] for utt in utt_list]
33
+ text_list = [utt2text[utt] for utt in utt_list]
34
+ spk_list = [utt2spk[utt] for utt in utt_list]
35
+ uttembedding_list = [utt2embedding[utt] for utt in utt_list]
36
+ spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
37
+ speech_token_list = [utt2speech_token[utt] for utt in utt_list]
38
+
39
+ # 保存到parquet,utt2parquet_file,spk2parquet_file
40
+ df = pd.DataFrame()
41
+ df['utt'] = utt_list
42
+ df['wav'] = wav_list
43
+ df['audio_data'] = data_list
44
+ df['text'] = text_list
45
+ df['spk'] = spk_list
46
+ df['utt_embedding'] = uttembedding_list
47
+ df['spk_embedding'] = spkembedding_list
48
+ df['speech_token'] = speech_token_list
49
+ df.to_parquet(parquet_file)
50
+ with open(utt2parquet_file, 'w') as f:
51
+ json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
52
+ with open(spk2parquet_file, 'w') as f:
53
+ json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
54
+ logging.info('spend time {}'.format(time.time() - start_time))
55
+
56
+ if __name__ == "__main__":
57
+ parser = argparse.ArgumentParser()
58
+ parser.add_argument('--num_utts_per_parquet',
59
+ type=int,
60
+ default=1000,
61
+ help='num utts per parquet')
62
+ parser.add_argument('--num_processes',
63
+ type=int,
64
+ default=1,
65
+ help='num processes for make parquets')
66
+ parser.add_argument('--src_dir',
67
+ type=str)
68
+ parser.add_argument('--des_dir',
69
+ type=str)
70
+ args = parser.parse_args()
71
+
72
+ utt2wav, utt2text, utt2spk = {}, {}, {}
73
+ with open('{}/wav.scp'.format(args.src_dir)) as f:
74
+ for l in f:
75
+ l = l.replace('\n', '').split()
76
+ utt2wav[l[0]] = l[1]
77
+ with open('{}/text'.format(args.src_dir)) as f:
78
+ for l in f:
79
+ l = l.replace('\n', '').split()
80
+ utt2text[l[0]] = ' '.join(l[1:])
81
+ with open('{}/utt2spk'.format(args.src_dir)) as f:
82
+ for l in f:
83
+ l = l.replace('\n', '').split()
84
+ utt2spk[l[0]] = l[1]
85
+ utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
86
+ spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
87
+ utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
88
+ utts = list(utt2wav.keys())
89
+
90
+ # Using process pool to speedup
91
+ pool = multiprocessing.Pool(processes=args.num_processes)
92
+ parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
93
+ for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
94
+ parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
95
+ utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
96
+ spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
97
+ parquet_list.append(parquet_file)
98
+ utt2parquet_list.append(utt2parquet_file)
99
+ spk2parquet_list.append(spk2parquet_file)
100
+ pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
101
+ pool.close()
102
+ pool.join()
103
+
104
+ with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
105
+ open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
106
+ open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
107
+ for name in parquet_list:
108
+ f1.write(name + '\n')
109
+ for name in utt2parquet_list:
110
+ f2.write(name + '\n')
111
+ for name in spk2parquet_list:
112
+ f3.write(name + '\n')