kevinwang676 commited on
Commit
13ce559
1 Parent(s): 9cd1263

Delete tools

Browse files
tools/extract_embedding.py DELETED
@@ -1,69 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import argparse
16
- import torch
17
- import torchaudio
18
- from tqdm import tqdm
19
- import onnxruntime
20
- import torchaudio.compliance.kaldi as kaldi
21
-
22
-
23
- def main(args):
24
- utt2wav, utt2spk = {}, {}
25
- with open('{}/wav.scp'.format(args.dir)) as f:
26
- for l in f:
27
- l = l.replace('\n', '').split()
28
- utt2wav[l[0]] = l[1]
29
- with open('{}/utt2spk'.format(args.dir)) as f:
30
- for l in f:
31
- l = l.replace('\n', '').split()
32
- utt2spk[l[0]] = l[1]
33
-
34
- option = onnxruntime.SessionOptions()
35
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
36
- option.intra_op_num_threads = 1
37
- providers = ["CPUExecutionProvider"]
38
- ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
39
-
40
- utt2embedding, spk2embedding = {}, {}
41
- for utt in tqdm(utt2wav.keys()):
42
- audio, sample_rate = torchaudio.load(utt2wav[utt])
43
- if sample_rate != 16000:
44
- audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
45
- feat = kaldi.fbank(audio,
46
- num_mel_bins=80,
47
- dither=0,
48
- sample_frequency=16000)
49
- feat = feat - feat.mean(dim=0, keepdim=True)
50
- embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
51
- utt2embedding[utt] = embedding
52
- spk = utt2spk[utt]
53
- if spk not in spk2embedding:
54
- spk2embedding[spk] = []
55
- spk2embedding[spk].append(embedding)
56
- for k, v in spk2embedding.items():
57
- spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
58
-
59
- torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60
- torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61
-
62
- if __name__ == "__main__":
63
- parser = argparse.ArgumentParser()
64
- parser.add_argument('--dir',
65
- type=str)
66
- parser.add_argument('--onnx_path',
67
- type=str)
68
- args = parser.parse_args()
69
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/extract_speech_token.py DELETED
@@ -1,61 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import argparse
16
- import logging
17
- import torch
18
- from tqdm import tqdm
19
- import onnxruntime
20
- import numpy as np
21
- import torchaudio
22
- import whisper
23
-
24
-
25
- def main(args):
26
- utt2wav = {}
27
- with open('{}/wav.scp'.format(args.dir)) as f:
28
- for l in f:
29
- l = l.replace('\n', '').split()
30
- utt2wav[l[0]] = l[1]
31
-
32
- option = onnxruntime.SessionOptions()
33
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
34
- option.intra_op_num_threads = 1
35
- providers = ["CUDAExecutionProvider"]
36
- ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
37
-
38
- utt2speech_token = {}
39
- for utt in tqdm(utt2wav.keys()):
40
- audio, sample_rate = torchaudio.load(utt2wav[utt])
41
- if sample_rate != 16000:
42
- audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
43
- if audio.shape[1] / 16000 > 30:
44
- logging.warning('do not support extract speech token for audio longer than 30s')
45
- speech_token = []
46
- else:
47
- feat = whisper.log_mel_spectrogram(audio, n_mels=128)
48
- speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
49
- ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
50
- utt2speech_token[utt] = speech_token
51
- torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
52
-
53
-
54
- if __name__ == "__main__":
55
- parser = argparse.ArgumentParser()
56
- parser.add_argument('--dir',
57
- type=str)
58
- parser.add_argument('--onnx_path',
59
- type=str)
60
- args = parser.parse_args()
61
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/make_parquet_list.py DELETED
@@ -1,112 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import argparse
16
- import logging
17
- import os
18
- import json
19
- from tqdm import tqdm
20
- import pandas as pd
21
- import multiprocessing
22
- import time
23
- import torch
24
-
25
-
26
- def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
27
- start_time = time.time()
28
- data_list = []
29
- for utt in tqdm(utt_list):
30
- data = open(utt2wav[utt], 'rb').read()
31
- data_list.append(data)
32
- wav_list = [utt2wav[utt] for utt in utt_list]
33
- text_list = [utt2text[utt] for utt in utt_list]
34
- spk_list = [utt2spk[utt] for utt in utt_list]
35
- uttembedding_list = [utt2embedding[utt] for utt in utt_list]
36
- spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
37
- speech_token_list = [utt2speech_token[utt] for utt in utt_list]
38
-
39
- # 保存到parquet,utt2parquet_file,spk2parquet_file
40
- df = pd.DataFrame()
41
- df['utt'] = utt_list
42
- df['wav'] = wav_list
43
- df['audio_data'] = data_list
44
- df['text'] = text_list
45
- df['spk'] = spk_list
46
- df['utt_embedding'] = uttembedding_list
47
- df['spk_embedding'] = spkembedding_list
48
- df['speech_token'] = speech_token_list
49
- df.to_parquet(parquet_file)
50
- with open(utt2parquet_file, 'w') as f:
51
- json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
52
- with open(spk2parquet_file, 'w') as f:
53
- json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
54
- logging.info('spend time {}'.format(time.time() - start_time))
55
-
56
- if __name__ == "__main__":
57
- parser = argparse.ArgumentParser()
58
- parser.add_argument('--num_utts_per_parquet',
59
- type=int,
60
- default=1000,
61
- help='num utts per parquet')
62
- parser.add_argument('--num_processes',
63
- type=int,
64
- default=1,
65
- help='num processes for make parquets')
66
- parser.add_argument('--src_dir',
67
- type=str)
68
- parser.add_argument('--des_dir',
69
- type=str)
70
- args = parser.parse_args()
71
-
72
- utt2wav, utt2text, utt2spk = {}, {}, {}
73
- with open('{}/wav.scp'.format(args.src_dir)) as f:
74
- for l in f:
75
- l = l.replace('\n', '').split()
76
- utt2wav[l[0]] = l[1]
77
- with open('{}/text'.format(args.src_dir)) as f:
78
- for l in f:
79
- l = l.replace('\n', '').split()
80
- utt2text[l[0]] = ' '.join(l[1:])
81
- with open('{}/utt2spk'.format(args.src_dir)) as f:
82
- for l in f:
83
- l = l.replace('\n', '').split()
84
- utt2spk[l[0]] = l[1]
85
- utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
86
- spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
87
- utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
88
- utts = list(utt2wav.keys())
89
-
90
- # Using process pool to speedup
91
- pool = multiprocessing.Pool(processes=args.num_processes)
92
- parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
93
- for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
94
- parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
95
- utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
96
- spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
97
- parquet_list.append(parquet_file)
98
- utt2parquet_list.append(utt2parquet_file)
99
- spk2parquet_list.append(spk2parquet_file)
100
- pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
101
- pool.close()
102
- pool.join()
103
-
104
- with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
105
- open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
106
- open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
107
- for name in parquet_list:
108
- f1.write(name + '\n')
109
- for name in utt2parquet_list:
110
- f2.write(name + '\n')
111
- for name in spk2parquet_list:
112
- f3.write(name + '\n')