Yusen commited on
Commit
68e03b2
·
1 Parent(s): 8503f66

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +355 -0
utils.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import re
4
+ import sys
5
+ import argparse
6
+ import logging
7
+ import json
8
+ import subprocess
9
+
10
+ import librosa
11
+ import numpy as np
12
+ import torchaudio
13
+ from scipy.io.wavfile import read
14
+ import torch
15
+ import torchvision
16
+ from torch.nn import functional as F
17
+ from commons import sequence_mask
18
+ from hubert import hubert_model
19
+ MATPLOTLIB_FLAG = False
20
+
21
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
22
+ logger = logging
23
+
24
+ f0_bin = 256
25
+ f0_max = 1100.0
26
+ f0_min = 50.0
27
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
28
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
29
+
30
+ def f0_to_coarse(f0):
31
+ is_torch = isinstance(f0, torch.Tensor)
32
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
33
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
34
+
35
+ f0_mel[f0_mel <= 1] = 1
36
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
37
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int)
38
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
39
+ return f0_coarse
40
+
41
+
42
+ def get_hubert_model(rank=None):
43
+
44
+ hubert_soft = hubert_model.hubert_soft("hubert/hubert-soft-0d54a1f4.pt")
45
+ if rank is not None:
46
+ hubert_soft = hubert_soft.cuda(rank)
47
+ return hubert_soft
48
+
49
+ def get_hubert_content(hmodel, y=None, path=None):
50
+ if path is not None:
51
+ source, sr = torchaudio.load(path)
52
+ source = torchaudio.functional.resample(source, sr, 16000)
53
+ if len(source.shape) == 2 and source.shape[1] >= 2:
54
+ source = torch.mean(source, dim=0).unsqueeze(0)
55
+ else:
56
+ source = y
57
+ source = source.unsqueeze(0)
58
+ with torch.inference_mode():
59
+ units = hmodel.units(source)
60
+ return units.transpose(1,2)
61
+
62
+
63
+ def get_content(cmodel, y):
64
+ with torch.no_grad():
65
+ c = cmodel.extract_features(y.squeeze(1))[0]
66
+ c = c.transpose(1, 2)
67
+ return c
68
+
69
+
70
+
71
+ def transform(mel, height): # 68-92
72
+ #r = np.random.random()
73
+ #rate = r * 0.3 + 0.85 # 0.85-1.15
74
+ #height = int(mel.size(-2) * rate)
75
+ tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
76
+ if height >= mel.size(-2):
77
+ return tgt[:, :mel.size(-2), :]
78
+ else:
79
+ silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1)
80
+ silence += torch.randn_like(silence) / 10
81
+ return torch.cat((tgt, silence), 1)
82
+
83
+
84
+ def stretch(mel, width): # 0.5-2
85
+ return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))
86
+
87
+
88
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
89
+ assert os.path.isfile(checkpoint_path)
90
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
91
+ iteration = checkpoint_dict['iteration']
92
+ learning_rate = checkpoint_dict['learning_rate']
93
+ if iteration is None:
94
+ iteration = 1
95
+ if learning_rate is None:
96
+ learning_rate = 0.0002
97
+ if optimizer is not None and checkpoint_dict['optimizer'] is not None:
98
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
99
+ saved_state_dict = checkpoint_dict['model']
100
+ if hasattr(model, 'module'):
101
+ state_dict = model.module.state_dict()
102
+ else:
103
+ state_dict = model.state_dict()
104
+ new_state_dict= {}
105
+ for k, v in state_dict.items():
106
+ try:
107
+ new_state_dict[k] = saved_state_dict[k]
108
+ except:
109
+ logger.info("%s is not in the checkpoint" % k)
110
+ new_state_dict[k] = v
111
+ if hasattr(model, 'module'):
112
+ model.module.load_state_dict(new_state_dict)
113
+ else:
114
+ model.load_state_dict(new_state_dict)
115
+ logger.info("Loaded checkpoint '{}' (iteration {})" .format(
116
+ checkpoint_path, iteration))
117
+ return model, optimizer, learning_rate, iteration
118
+
119
+
120
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
121
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
122
+ iteration, checkpoint_path))
123
+ if hasattr(model, 'module'):
124
+ state_dict = model.module.state_dict()
125
+ else:
126
+ state_dict = model.state_dict()
127
+ torch.save({'model': state_dict,
128
+ 'iteration': iteration,
129
+ 'optimizer': optimizer.state_dict(),
130
+ 'learning_rate': learning_rate}, checkpoint_path)
131
+ clean_ckpt = False
132
+ if clean_ckpt:
133
+ clean_checkpoints(path_to_models='logs/32k/', n_ckpts_to_keep=3, sort_by_time=True)
134
+
135
+ def clean_checkpoints(path_to_models='logs/48k/', n_ckpts_to_keep=2, sort_by_time=True):
136
+ """Freeing up space by deleting saved ckpts
137
+
138
+ Arguments:
139
+ path_to_models -- Path to the model directory
140
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
141
+ sort_by_time -- True -> chronologically delete ckpts
142
+ False -> lexicographically delete ckpts
143
+ """
144
+ ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
145
+ name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
146
+ time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
147
+ sort_key = time_key if sort_by_time else name_key
148
+ x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
149
+ to_del = [os.path.join(path_to_models, fn) for fn in
150
+ (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
151
+ del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
152
+ del_routine = lambda x: [os.remove(x), del_info(x)]
153
+ rs = [del_routine(fn) for fn in to_del]
154
+
155
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
156
+ for k, v in scalars.items():
157
+ writer.add_scalar(k, v, global_step)
158
+ for k, v in histograms.items():
159
+ writer.add_histogram(k, v, global_step)
160
+ for k, v in images.items():
161
+ writer.add_image(k, v, global_step, dataformats='HWC')
162
+ for k, v in audios.items():
163
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
164
+
165
+
166
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
167
+ f_list = glob.glob(os.path.join(dir_path, regex))
168
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
169
+ x = f_list[-1]
170
+ print(x)
171
+ return x
172
+
173
+
174
+ def plot_spectrogram_to_numpy(spectrogram):
175
+ global MATPLOTLIB_FLAG
176
+ if not MATPLOTLIB_FLAG:
177
+ import matplotlib
178
+ matplotlib.use("Agg")
179
+ MATPLOTLIB_FLAG = True
180
+ mpl_logger = logging.getLogger('matplotlib')
181
+ mpl_logger.setLevel(logging.WARNING)
182
+ import matplotlib.pylab as plt
183
+ import numpy as np
184
+
185
+ fig, ax = plt.subplots(figsize=(10,2))
186
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
187
+ interpolation='none')
188
+ plt.colorbar(im, ax=ax)
189
+ plt.xlabel("Frames")
190
+ plt.ylabel("Channels")
191
+ plt.tight_layout()
192
+
193
+ fig.canvas.draw()
194
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
195
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
196
+ plt.close()
197
+ return data
198
+
199
+
200
+ def plot_alignment_to_numpy(alignment, info=None):
201
+ global MATPLOTLIB_FLAG
202
+ if not MATPLOTLIB_FLAG:
203
+ import matplotlib
204
+ matplotlib.use("Agg")
205
+ MATPLOTLIB_FLAG = True
206
+ mpl_logger = logging.getLogger('matplotlib')
207
+ mpl_logger.setLevel(logging.WARNING)
208
+ import matplotlib.pylab as plt
209
+ import numpy as np
210
+
211
+ fig, ax = plt.subplots(figsize=(6, 4))
212
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
213
+ interpolation='none')
214
+ fig.colorbar(im, ax=ax)
215
+ xlabel = 'Decoder timestep'
216
+ if info is not None:
217
+ xlabel += '\n\n' + info
218
+ plt.xlabel(xlabel)
219
+ plt.ylabel('Encoder timestep')
220
+ plt.tight_layout()
221
+
222
+ fig.canvas.draw()
223
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
224
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
225
+ plt.close()
226
+ return data
227
+
228
+
229
+ def load_wav_to_torch(full_path):
230
+ sampling_rate, data = read(full_path)
231
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
232
+
233
+
234
+ def load_filepaths_and_text(filename, split="|"):
235
+ with open(filename, encoding='utf-8') as f:
236
+ filepaths_and_text = [line.strip().split(split) for line in f]
237
+ return filepaths_and_text
238
+
239
+
240
+ def get_hparams(init=True):
241
+ parser = argparse.ArgumentParser()
242
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
243
+ help='JSON file for configuration')
244
+ parser.add_argument('-m', '--model', type=str, required=True,
245
+ help='Model name')
246
+
247
+ args = parser.parse_args()
248
+ model_dir = os.path.join("./logs", args.model)
249
+
250
+ if not os.path.exists(model_dir):
251
+ os.makedirs(model_dir)
252
+
253
+ config_path = args.config
254
+ config_save_path = os.path.join(model_dir, "config.json")
255
+ if init:
256
+ with open(config_path, "r") as f:
257
+ data = f.read()
258
+ with open(config_save_path, "w") as f:
259
+ f.write(data)
260
+ else:
261
+ with open(config_save_path, "r") as f:
262
+ data = f.read()
263
+ config = json.loads(data)
264
+
265
+ hparams = HParams(**config)
266
+ hparams.model_dir = model_dir
267
+ return hparams
268
+
269
+
270
+ def get_hparams_from_dir(model_dir):
271
+ config_save_path = os.path.join(model_dir, "config.json")
272
+ with open(config_save_path, "r") as f:
273
+ data = f.read()
274
+ config = json.loads(data)
275
+
276
+ hparams =HParams(**config)
277
+ hparams.model_dir = model_dir
278
+ return hparams
279
+
280
+
281
+ def get_hparams_from_file(config_path):
282
+ with open(config_path, "r") as f:
283
+ data = f.read()
284
+ config = json.loads(data)
285
+
286
+ hparams =HParams(**config)
287
+ return hparams
288
+
289
+
290
+ def check_git_hash(model_dir):
291
+ source_dir = os.path.dirname(os.path.realpath(__file__))
292
+ if not os.path.exists(os.path.join(source_dir, ".git")):
293
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
294
+ source_dir
295
+ ))
296
+ return
297
+
298
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
299
+
300
+ path = os.path.join(model_dir, "githash")
301
+ if os.path.exists(path):
302
+ saved_hash = open(path).read()
303
+ if saved_hash != cur_hash:
304
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
305
+ saved_hash[:8], cur_hash[:8]))
306
+ else:
307
+ open(path, "w").write(cur_hash)
308
+
309
+
310
+ def get_logger(model_dir, filename="train.log"):
311
+ global logger
312
+ logger = logging.getLogger(os.path.basename(model_dir))
313
+ logger.setLevel(logging.DEBUG)
314
+
315
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
316
+ if not os.path.exists(model_dir):
317
+ os.makedirs(model_dir)
318
+ h = logging.FileHandler(os.path.join(model_dir, filename))
319
+ h.setLevel(logging.DEBUG)
320
+ h.setFormatter(formatter)
321
+ logger.addHandler(h)
322
+ return logger
323
+
324
+
325
+ class HParams():
326
+ def __init__(self, **kwargs):
327
+ for k, v in kwargs.items():
328
+ if type(v) == dict:
329
+ v = HParams(**v)
330
+ self[k] = v
331
+
332
+ def keys(self):
333
+ return self.__dict__.keys()
334
+
335
+ def items(self):
336
+ return self.__dict__.items()
337
+
338
+ def values(self):
339
+ return self.__dict__.values()
340
+
341
+ def __len__(self):
342
+ return len(self.__dict__)
343
+
344
+ def __getitem__(self, key):
345
+ return getattr(self, key)
346
+
347
+ def __setitem__(self, key, value):
348
+ return setattr(self, key, value)
349
+
350
+ def __contains__(self, key):
351
+ return key in self.__dict__
352
+
353
+ def __repr__(self):
354
+ return self.__dict__.__repr__()
355
+