Text-to-Speech
English
geneing committed · commit b8db573 · 1 parent: eb932f8

Enabled onnx conversion

Files changed (6)
  1. models.py +15 -16
  2. models_onnx.py +616 -0
  3. onnx_export.py +120 -0
  4. plbert.py +3 -2
  5. test.ipynb +0 -0
  6. test.py +2 -1
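
The new onnx_export.py (shown below) wraps the Kokoro/StyleTTS2 front end (bert, bert_encoder, predictor, text_encoder) into a single module and exports it to style_model.onnx with one "tokens" input and four outputs. As a quick sanity check of the exported file, something along these lines should work; the onnx and onnxruntime packages are an assumption here and are not part of this commit:

# Sketch only: verifies the file written by onnx_export.py below.
# onnx / onnxruntime are assumed to be installed; they are not used in this commit.
import onnx
import onnxruntime as ort

onnx.checker.check_model(onnx.load("style_model.onnx"))

sess = ort.InferenceSession("style_model.onnx")
print([i.name for i in sess.get_inputs()])   # expected: ['tokens']
print([o.name for o in sess.get_outputs()])  # expected: ['asr', 'F0_pred', 'N_pred', 'ref_s']
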
models.py CHANGED
@@ -1,4 +1,5 @@
 # https://github.com/yl4579/StyleTTS2/blob/main/models.py
+from ast import Tuple
 from istftnet import Decoder
 from munch import Munch
 from pathlib import Path
@@ -299,7 +300,7 @@ class TextEncoder(nn.Module):
 
         x = x.transpose(1, 2) # [B, T, chn]
 
-        self.lstm.flatten_parameters()
+        # self.lstm.flatten_parameters()
         x, _ = self.lstm(x)
 
         x = x.transpose(-1, -2)
@@ -404,6 +405,7 @@ class AdaLayerNorm(nn.Module):
         x = (1 + gamma) * x + beta
         return x.transpose(1, -1).transpose(-1, -2)
 
+
 class ProsodyPredictor(nn.Module):
 
     def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
@@ -461,21 +463,17 @@ class ProsodyPredictor(nn.Module):
 
         return duration.squeeze(-1), en
 
-    def F0Ntrain(self, x, s):
+
+    def F0Ntrain(self, x: torch.Tensor, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         x1 = x.transpose(-1, -2)
-        torch._check(x1.dim() == 3, lambda: print(f"Expected 3D tensor, got {x1.dim()}D tensor"))
-        torch._check(x1.shape[1] > 1, lambda: print(f"Shape 2, got {x1.size(1)}"))
-        torch._check(x1.shape[2] > 1, lambda: print(f"Shape 2, got {x1.size(2)}"))
-        torch._check(x.shape[2] > 0, lambda: print(f"Shape 2, got {x.size(2)}"))
-        x, _ = self.shared(x1)
-        # torch._check(x.shape[2] > 0, lambda: print(f"Shape 2, got {x.size(2)}"))
-
-        F0 = x.transpose(-1, -2)
+        x2, _temp = self.shared(x1)
+
+        F0 = x2.transpose(-1, -2)
         for block in self.F0:
             F0 = block(F0, s)
         F0 = self.F0_proj(F0)
 
-        N = x.transpose(-1, -2)
+        N = x2.transpose(-1, -2)
         for block in self.N:
             N = block(N, s)
         N = self.N_proj(N)
@@ -511,7 +509,7 @@ class DurationEncoder(nn.Module):
 
         x = x.permute(2, 0, 1)
         s = style.expand(x.shape[0], x.shape[1], -1)
-        x = torch.cat([x, s], axis=-1)
+        x = torch.cat([x, s], dim=-1)
         x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
 
         x = x.transpose(0, 1)
@@ -520,7 +518,7 @@ class DurationEncoder(nn.Module):
         for block in self.lstms:
             if isinstance(block, AdaLayerNorm):
                 x = block(x.transpose(-1, -2), style).transpose(-1, -2)
-                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
+                x = torch.cat([x, s.permute(1, -1, 0)], dim=1)
                 x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
             else:
                 x = x.transpose(-1, -2)
@@ -553,11 +551,11 @@ class DurationEncoder(nn.Module):
         for block in self.lstms:
             if isinstance(block, AdaLayerNorm):
                 x = block(x.transpose(-1, -2), style).transpose(-1, -2)
-                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
+                x = torch.cat([x, s.permute(1, -1, 0)], dim=1)
             else:
                 x = x.transpose(-1, -2)
 
-                block.flatten_parameters()
+                # block.flatten_parameters()
                 x, _ = block(x)
 
                 x = F.dropout(x, p=self.dropout, training=self.training)
@@ -578,7 +576,8 @@ def recursive_munch(d):
     else:
         return d
 
-def build_model(path, device):
+
+def build_model(path: str, device: str):
     config = Path(__file__).parent / 'config.json'
     assert config.exists(), f'Config path incorrect: config.json not found at {config}'
     with open(config, 'r') as r:
models_onnx.py ADDED
@@ -0,0 +1,616 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/models.py
2
+ from ast import Tuple
3
+ from istftnet import Decoder
4
+ from munch import Munch
5
+ from pathlib import Path
6
+ from plbert import load_plbert
7
+ from torch.nn.utils import weight_norm, spectral_norm
8
+ import json
9
+ import numpy as np
10
+ import os.path as osp
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ class LearnedDownSample(nn.Module):
16
+ def __init__(self, layer_type, dim_in):
17
+ super().__init__()
18
+ self.layer_type = layer_type
19
+
20
+ if self.layer_type == 'none':
21
+ self.conv = nn.Identity()
22
+ elif self.layer_type == 'timepreserve':
23
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
24
+ elif self.layer_type == 'half':
25
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
26
+ else:
27
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
28
+
29
+ def forward(self, x):
30
+ return self.conv(x)
31
+
32
+ class LearnedUpSample(nn.Module):
33
+ def __init__(self, layer_type, dim_in):
34
+ super().__init__()
35
+ self.layer_type = layer_type
36
+
37
+ if self.layer_type == 'none':
38
+ self.conv = nn.Identity()
39
+ elif self.layer_type == 'timepreserve':
40
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
41
+ elif self.layer_type == 'half':
42
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
43
+ else:
44
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
45
+
46
+
47
+ def forward(self, x):
48
+ return self.conv(x)
49
+
50
+ class DownSample(nn.Module):
51
+ def __init__(self, layer_type):
52
+ super().__init__()
53
+ self.layer_type = layer_type
54
+
55
+ def forward(self, x):
56
+ if self.layer_type == 'none':
57
+ return x
58
+ elif self.layer_type == 'timepreserve':
59
+ return F.avg_pool2d(x, (2, 1))
60
+ elif self.layer_type == 'half':
61
+ if x.shape[-1] % 2 != 0:
62
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
63
+ return F.avg_pool2d(x, 2)
64
+ else:
65
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
66
+
67
+
68
+ class UpSample(nn.Module):
69
+ def __init__(self, layer_type):
70
+ super().__init__()
71
+ self.layer_type = layer_type
72
+
73
+ def forward(self, x):
74
+ if self.layer_type == 'none':
75
+ return x
76
+ elif self.layer_type == 'timepreserve':
77
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
78
+ elif self.layer_type == 'half':
79
+ return F.interpolate(x, scale_factor=2, mode='nearest')
80
+ else:
81
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
82
+
83
+
84
+ class ResBlk(nn.Module):
85
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
86
+ normalize=False, downsample='none'):
87
+ super().__init__()
88
+ self.actv = actv
89
+ self.normalize = normalize
90
+ self.downsample = DownSample(downsample)
91
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
92
+ self.learned_sc = dim_in != dim_out
93
+ self._build_weights(dim_in, dim_out)
94
+
95
+ def _build_weights(self, dim_in, dim_out):
96
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
97
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
98
+ if self.normalize:
99
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
100
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
101
+ if self.learned_sc:
102
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
103
+
104
+ def _shortcut(self, x):
105
+ if self.learned_sc:
106
+ x = self.conv1x1(x)
107
+ if self.downsample:
108
+ x = self.downsample(x)
109
+ return x
110
+
111
+ def _residual(self, x):
112
+ if self.normalize:
113
+ x = self.norm1(x)
114
+ x = self.actv(x)
115
+ x = self.conv1(x)
116
+ x = self.downsample_res(x)
117
+ if self.normalize:
118
+ x = self.norm2(x)
119
+ x = self.actv(x)
120
+ x = self.conv2(x)
121
+ return x
122
+
123
+ def forward(self, x):
124
+ x = self._shortcut(x) + self._residual(x)
125
+ return x / np.sqrt(2) # unit variance
126
+
127
+ class LinearNorm(torch.nn.Module):
128
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
129
+ super(LinearNorm, self).__init__()
130
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
131
+
132
+ torch.nn.init.xavier_uniform_(
133
+ self.linear_layer.weight,
134
+ gain=torch.nn.init.calculate_gain(w_init_gain))
135
+
136
+ def forward(self, x):
137
+ return self.linear_layer(x)
138
+
139
+ class Discriminator2d(nn.Module):
140
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
141
+ super().__init__()
142
+ blocks = []
143
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
144
+
145
+ for lid in range(repeat_num):
146
+ dim_out = min(dim_in*2, max_conv_dim)
147
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
148
+ dim_in = dim_out
149
+
150
+ blocks += [nn.LeakyReLU(0.2)]
151
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
152
+ blocks += [nn.LeakyReLU(0.2)]
153
+ blocks += [nn.AdaptiveAvgPool2d(1)]
154
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
155
+ self.main = nn.Sequential(*blocks)
156
+
157
+ def get_feature(self, x):
158
+ features = []
159
+ for l in self.main:
160
+ x = l(x)
161
+ features.append(x)
162
+ out = features[-1]
163
+ out = out.view(out.size(0), -1) # (batch, num_domains)
164
+ return out, features
165
+
166
+ def forward(self, x):
167
+ out, features = self.get_feature(x)
168
+ out = out.squeeze() # (batch)
169
+ return out, features
170
+
171
+ class ResBlk1d(nn.Module):
172
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
173
+ normalize=False, downsample='none', dropout_p=0.2):
174
+ super().__init__()
175
+ self.actv = actv
176
+ self.normalize = normalize
177
+ self.downsample_type = downsample
178
+ self.learned_sc = dim_in != dim_out
179
+ self._build_weights(dim_in, dim_out)
180
+ self.dropout_p = dropout_p
181
+
182
+ if self.downsample_type == 'none':
183
+ self.pool = nn.Identity()
184
+ else:
185
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
186
+
187
+ def _build_weights(self, dim_in, dim_out):
188
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
189
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
190
+ if self.normalize:
191
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
192
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
193
+ if self.learned_sc:
194
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
195
+
196
+ def downsample(self, x):
197
+ if self.downsample_type == 'none':
198
+ return x
199
+ else:
200
+ if x.shape[-1] % 2 != 0:
201
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
202
+ return F.avg_pool1d(x, 2)
203
+
204
+ def _shortcut(self, x):
205
+ if self.learned_sc:
206
+ x = self.conv1x1(x)
207
+ x = self.downsample(x)
208
+ return x
209
+
210
+ def _residual(self, x):
211
+ if self.normalize:
212
+ x = self.norm1(x)
213
+ x = self.actv(x)
214
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
215
+
216
+ x = self.conv1(x)
217
+ x = self.pool(x)
218
+ if self.normalize:
219
+ x = self.norm2(x)
220
+
221
+ x = self.actv(x)
222
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
223
+
224
+ x = self.conv2(x)
225
+ return x
226
+
227
+ def forward(self, x):
228
+ x = self._shortcut(x) + self._residual(x)
229
+ return x / np.sqrt(2) # unit variance
230
+
231
+ class LayerNorm(nn.Module):
232
+ def __init__(self, channels, eps=1e-5):
233
+ super().__init__()
234
+ self.channels = channels
235
+ self.eps = eps
236
+
237
+ self.gamma = nn.Parameter(torch.ones(channels))
238
+ self.beta = nn.Parameter(torch.zeros(channels))
239
+
240
+ def forward(self, x):
241
+ x = x.transpose(1, -1)
242
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
243
+ return x.transpose(1, -1)
244
+
245
+ class TextEncoder(nn.Module):
246
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
247
+ super().__init__()
248
+ self.embedding = nn.Embedding(n_symbols, channels)
249
+
250
+ padding = (kernel_size - 1) // 2
251
+ self.cnn = nn.ModuleList()
252
+ for _ in range(depth):
253
+ self.cnn.append(nn.Sequential(
254
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
255
+ LayerNorm(channels),
256
+ actv,
257
+ nn.Dropout(0.2),
258
+ ))
259
+ # self.cnn = nn.Sequential(*self.cnn)
260
+
261
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
262
+
263
+ def forward(self, x, input_lengths, m):
264
+ x = self.embedding(x) # [B, T, emb]
265
+ x = x.transpose(1, 2) # [B, emb, T]
266
+ m = m.to(input_lengths.device).unsqueeze(1)
267
+ x.masked_fill_(m, 0.0)
268
+
269
+ for c in self.cnn:
270
+ x = c(x)
271
+ x.masked_fill_(m, 0.0)
272
+
273
+ x = x.transpose(1, 2) # [B, T, chn]
274
+
275
+ x = nn.utils.rnn.pack_padded_sequence(
276
+ x, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
277
+
278
+ self.lstm.flatten_parameters()
279
+ x, _ = self.lstm(x)
280
+ x, _ = nn.utils.rnn.pad_packed_sequence(
281
+ x, batch_first=True)
282
+
283
+ x = x.transpose(-1, -2)
284
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
285
+
286
+ x_pad[:, :, :x.shape[-1]] = x
287
+ x = x_pad.to(x.device)
288
+
289
+ x.masked_fill_(m, 0.0)
290
+
291
+ return x
292
+
293
+ def inference(self, x):
294
+ x = self.embedding(x) # [B, T, emb]
295
+ x = x.transpose(1, 2) # [B, emb, T]
296
+
297
+ for c in self.cnn:
298
+ x = c(x)
299
+
300
+ x = x.transpose(1, 2) # [B, T, chn]
301
+
302
+ # self.lstm.flatten_parameters()
303
+ x, _ = self.lstm(x)
304
+
305
+ x = x.transpose(-1, -2)
306
+
307
+ return x
308
+
309
+ def length_to_mask(self, lengths):
310
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
311
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
312
+ return mask
313
+
314
+
315
+
316
+ class AdaIN1d(nn.Module):
317
+ def __init__(self, style_dim, num_features):
318
+ super().__init__()
319
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
320
+ self.fc = nn.Linear(style_dim, num_features*2)
321
+
322
+ def forward(self, x, s):
323
+ h = self.fc(s)
324
+ h = h.view(h.size(0), h.size(1), 1)
325
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
326
+ return (1 + gamma) * self.norm(x) + beta
327
+
328
+ class UpSample1d(nn.Module):
329
+ def __init__(self, layer_type):
330
+ super().__init__()
331
+ self.layer_type = layer_type
332
+
333
+ def forward(self, x):
334
+ if self.layer_type == 'none':
335
+ return x
336
+ else:
337
+ return F.interpolate(x, scale_factor=2, mode='nearest')
338
+
339
+ class AdainResBlk1d(nn.Module):
340
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
341
+ upsample='none', dropout_p=0.0):
342
+ super().__init__()
343
+ self.actv = actv
344
+ self.upsample_type = upsample
345
+ self.upsample = UpSample1d(upsample)
346
+ self.learned_sc = dim_in != dim_out
347
+ self._build_weights(dim_in, dim_out, style_dim)
348
+ self.dropout = nn.Dropout(dropout_p)
349
+
350
+ if upsample == 'none':
351
+ self.pool = nn.Identity()
352
+ else:
353
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
354
+
355
+
356
+ def _build_weights(self, dim_in, dim_out, style_dim):
357
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
358
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
359
+ self.norm1 = AdaIN1d(style_dim, dim_in)
360
+ self.norm2 = AdaIN1d(style_dim, dim_out)
361
+ if self.learned_sc:
362
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
363
+
364
+ def _shortcut(self, x):
365
+ x = self.upsample(x)
366
+ if self.learned_sc:
367
+ x = self.conv1x1(x)
368
+ return x
369
+
370
+ def _residual(self, x, s):
371
+ x = self.norm1(x, s)
372
+ x = self.actv(x)
373
+ x = self.pool(x)
374
+ x = self.conv1(self.dropout(x))
375
+ x = self.norm2(x, s)
376
+ x = self.actv(x)
377
+ x = self.conv2(self.dropout(x))
378
+ return x
379
+
380
+ def forward(self, x, s):
381
+ out = self._residual(x, s)
382
+ out = (out + self._shortcut(x)) / np.sqrt(2)
383
+ return out
384
+
385
+ class AdaLayerNorm(nn.Module):
386
+ def __init__(self, style_dim, channels, eps=1e-5):
387
+ super().__init__()
388
+ self.channels = channels
389
+ self.eps = eps
390
+
391
+ self.fc = nn.Linear(style_dim, channels*2)
392
+
393
+ def forward(self, x, s):
394
+ x = x.transpose(-1, -2)
395
+ x = x.transpose(1, -1)
396
+
397
+ h = self.fc(s)
398
+ h = h.view(h.size(0), h.size(1), 1)
399
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
400
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
401
+
402
+
403
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
404
+ x = (1 + gamma) * x + beta
405
+ return x.transpose(1, -1).transpose(-1, -2)
406
+
407
+
408
+ class ProsodyPredictor(nn.Module):
409
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
410
+ super().__init__()
411
+
412
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
413
+ d_model=d_hid,
414
+ nlayers=nlayers,
415
+ dropout=dropout)
416
+
417
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
418
+ self.duration_proj = LinearNorm(d_hid, max_dur)
419
+
420
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
421
+
422
+ self.F0 = nn.ModuleList()
423
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
424
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
425
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
426
+
427
+ self.N = nn.ModuleList()
428
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
429
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
430
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
431
+
432
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
433
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
434
+
435
+
436
+ def forward(self, texts, style, text_lengths, alignment, m):
437
+ d = self.text_encoder(texts, style, text_lengths, m)
438
+
439
+ batch_size = d.shape[0]
440
+ text_size = d.shape[1]
441
+
442
+ # predict duration
443
+ input_lengths = text_lengths
444
+ x = nn.utils.rnn.pack_padded_sequence(
445
+ d, input_lengths, batch_first=True, enforce_sorted=False)
446
+
447
+ m = m.to(text_lengths.device).unsqueeze(1)
448
+
449
+ self.lstm.flatten_parameters()
450
+ x, _ = self.lstm(x)
451
+ x, _ = nn.utils.rnn.pad_packed_sequence(
452
+ x, batch_first=True)
453
+
454
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
455
+
456
+ x_pad[:, :x.shape[1], :] = x
457
+ x = x_pad.to(x.device)
458
+
459
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
460
+
461
+ en = (d.transpose(-1, -2) @ alignment)
462
+
463
+ return duration.squeeze(-1), en
464
+
465
+ def F0Ntrain(self, x: torch.Tensor, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
466
+ x1 = x.transpose(-1, -2)
467
+ x2, _temp = self.shared(x1)
468
+
469
+ F0 = x2.transpose(-1, -2)
470
+ for block in self.F0:
471
+ F0 = block(F0, s)
472
+ F0 = self.F0_proj(F0)
473
+
474
+ N = x2.transpose(-1, -2)
475
+ for block in self.N:
476
+ N = block(N, s)
477
+ N = self.N_proj(N)
478
+
479
+ return F0.squeeze(1), N.squeeze(1)
480
+
481
+ def length_to_mask(self, lengths):
482
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
483
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
484
+ return mask
485
+
486
+ class DurationEncoder(nn.Module):
487
+
488
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
489
+ super().__init__()
490
+ self.lstms = nn.ModuleList()
491
+ for _ in range(nlayers):
492
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
493
+ d_model // 2,
494
+ num_layers=1,
495
+ batch_first=True,
496
+ bidirectional=True,
497
+ dropout=dropout))
498
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
499
+
500
+
501
+ self.dropout = dropout
502
+ self.d_model = d_model
503
+ self.sty_dim = sty_dim
504
+
505
+ def forward(self, x, style, text_lengths, m):
506
+ masks = m.to(text_lengths.device)
507
+
508
+ x = x.permute(2, 0, 1)
509
+ s = style.expand(x.shape[0], x.shape[1], -1)
510
+ x = torch.cat([x, s], dim=-1)
511
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
512
+
513
+ x = x.transpose(0, 1)
514
+ x = x.transpose(-1, -2)
515
+
516
+ for block in self.lstms:
517
+ if isinstance(block, AdaLayerNorm):
518
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
519
+ x = torch.cat([x, s.permute(1, -1, 0)], dim=1)
520
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
521
+ else:
522
+ x = x.transpose(-1, -2)
523
+
524
+ x = nn.utils.rnn.pack_padded_sequence(
525
+ x, text_lengths.cpu(), batch_first=True, enforce_sorted=False)
526
+ block.flatten_parameters()
527
+ x, _ = block(x)
528
+ x, _ = nn.utils.rnn.pad_packed_sequence(
529
+ x, batch_first=True)
530
+ x = F.dropout(x, p=self.dropout, training=self.training)
531
+ x = x.transpose(-1, -2)
532
+
533
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
534
+
535
+ x_pad[:, :, :x.shape[-1]] = x
536
+ x = x_pad.to(x.device)
537
+
538
+ return x.transpose(-1, -2)
539
+
540
+ def inference(self, x: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
541
+
542
+ x = x.permute(2, 0, 1)
543
+ s = style.expand(x.shape[0], x.shape[1], -1)
544
+ x = torch.cat([x, s], axis=-1)
545
+
546
+ x = x.transpose(0, 1)
547
+ x = x.transpose(-1, -2)
548
+
549
+ for block in self.lstms:
550
+ if isinstance(block, AdaLayerNorm):
551
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
552
+ x = torch.cat([x, s.permute(1, -1, 0)], dim=1)
553
+ else:
554
+ x = x.transpose(-1, -2)
555
+
556
+ # block.flatten_parameters()
557
+ x, _ = block(x)
558
+
559
+ x = F.dropout(x, p=self.dropout, training=self.training)
560
+ x = x.transpose(-1, -2)
561
+ return x.transpose(-1, -2)
562
+
563
+ def length_to_mask(self, lengths):
564
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
565
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
566
+ return mask
567
+
568
+ # https://github.com/yl4579/StyleTTS2/blob/main/utils.py
569
+ def recursive_munch(d):
570
+ if isinstance(d, dict):
571
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
572
+ elif isinstance(d, list):
573
+ return [recursive_munch(v) for v in d]
574
+ else:
575
+ return d
576
+
577
+ def build_model(path: str, device: str):
578
+ config = Path(__file__).parent / 'config.json'
579
+ assert config.exists(), f'Config path incorrect: config.json not found at {config}'
580
+ with open(config, 'r') as r:
581
+ args = recursive_munch(json.load(r))
582
+ assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
583
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
584
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
585
+ upsample_rates = args.decoder.upsample_rates,
586
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
587
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
588
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
589
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
590
+
591
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
592
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
593
+ bert = load_plbert()
594
+ bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
595
+
596
+ for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
597
+ for child in parent.children():
598
+ if isinstance(child, nn.RNNBase):
599
+ child.flatten_parameters()
600
+
601
+ model = Munch(
602
+ bert=bert.to(device).eval(),
603
+ bert_encoder=bert_encoder.to(device).eval(),
604
+ predictor=predictor.to(device).eval(),
605
+ decoder=decoder.to(device).eval(),
606
+ text_encoder=text_encoder.to(device).eval(),
607
+ )
608
+
609
+ for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
610
+ assert key in model, key
611
+ try:
612
+ model[key].load_state_dict(state_dict)
613
+ except:
614
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
615
+ model[key].load_state_dict(state_dict, strict=False)
616
+ return model
onnx_export.py ADDED
@@ -0,0 +1,120 @@
+import os
+# os.environ['TORCH_LOGS'] = '+dynamic'
+# os.environ['TORCH_LOGS'] = '+export'
+# os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']="u0 >= 0"
+# os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CPP']="1"
+# os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL']="u0"
+
+
+from kokoro import phonemize, tokenize, length_to_mask
+import torch.nn.functional as F
+from models_scripting import build_model
+import torch
+from typing import Dict
+
+device = "cpu" #'cuda' if torch.cuda.is_available() else 'cpu'
+
+model = build_model('kokoro-v0_19.pth', device)
+
+voicepack = torch.load('voices/af.pt', weights_only=True).to(device)
+
+speed = 1.
+
+text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."
+
+ps = phonemize(text, "a")
+tokens = tokenize(ps)
+
+tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
+
+class StyleTTS2(torch.nn.Module):
+    def __init__(self, model, voicepack):
+        super().__init__()
+        # self.model = model
+        self.bert = model.bert
+        self.bert_encoder = model.bert_encoder
+        self.predictor = model.predictor
+        self.decoder = model.decoder
+        self.text_encoder = model.text_encoder
+        self.voicepack = voicepack
+
+    def forward(self, tokens : torch.Tensor):
+        speed = 1.
+        # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))
+        device = tokens.device
+        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+
+        text_mask = length_to_mask(input_lengths).to(device)
+        bert_dur = self.bert(tokens)
+
+        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
+
+        ref_s = self.voicepack[tokens.shape[1]]
+        s = ref_s[:, 128:]
+
+        d = self.predictor.text_encoder.inference(d_en, s)
+        x, _ = self.predictor.lstm(d)
+
+        duration = self.predictor.duration_proj(x)
+        duration = torch.sigmoid(duration).sum(axis=-1) / speed
+        pred_dur = torch.round(duration).clamp(min=1).long()
+
+        c_start = F.pad(pred_dur,(1,0), "constant").cumsum(dim=1)[0,0:-1]
+        c_end = c_start + pred_dur[0,:]
+
+        # torch._check(pred_dur.sum().item()>0, lambda: print(f"Got {pred_dur.sum().item()}"))
+        indices = torch.arange(0, pred_dur.sum().item()).long().to(device)
+
+        pred_aln_trg_list=[]
+        for cs, ce in zip(c_start, c_end):
+            row = torch.where((indices>=cs) & (indices<ce), 1., 0.)
+            pred_aln_trg_list.append(row)
+        pred_aln_trg=torch.vstack(pred_aln_trg_list)
+
+        en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
+
+        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
+        t_en = self.text_encoder.inference(tokens)
+        asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
+        return (asr, F0_pred, N_pred, ref_s[:, :128])
+        # output = self.model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().detach().cpu().numpy()
+
+
+# bert = torch.jit.script(model.bert)
+# bert_encoder = torch.jit.script(model.bert_encoder)
+# predictor = torch.jit.script(model.predictor)
+# text_encoder = torch.jit.script(model.text_encoder)
+
+# model["bert"] = torch.jit.trace(model["bert"], (tokens, ))
+# # model["decoder"] = torch.jit.script(model["decoder"])
+# bert_dur = model["bert"](tokens)
+# model["bert_encoder"] = torch.jit.trace(model["bert_encoder"], (bert_dur,))
+# model["predictor"] = torch.compile(model["predictor"], backend=backend)
+# model["text_encoder"] = torch.compile(model["text_encoder"], backend=backend)
+
+style_model = StyleTTS2(model=model, voicepack=voicepack)
+style_model.eval()
+# style_model = torch.jit.trace_module(style_model.eval(), inputs={'forward': (tokens, )})
+# style_model.model["predictor"].F0Ntrain = torch.jit.script(style_model.model["predictor"].F0Ntrain)
+(asr, F0_pred, N_pred, ref_s) = style_model(tokens)
+print(asr.shape, F0_pred.shape, N_pred.shape, ref_s.shape)
+
+# scripted_style_model = torch.jit.script(style_model)
+
+# (asr, F0_pred, N_pred, ref_s) = scripted_style_model(tokens)
+# print(asr.shape, F0_pred.shape, N_pred.shape, ref_s.shape)
+
+# torch.onnx.export(scripted_style_model, ( tokens, ), "style_model.onnx", verbose=True, opset_version=17, input_names=["tokens"], output_names=["asr", "F0_pred", "N_pred", "ref_s"])
+# token_len = torch.export.Dim("token_len", min=2, max=510)
+# batch = torch.export.Dim("batch")
+# dynamic_shapes = {"tokens":{ 1:token_len}}
+dynamic_shapes = {"tokens":{ 1:"token_len"}}
+print(f"{tokens.shape=}")
+torch.onnx.export(model=style_model, args=( tokens, ), dynamic_axes=dynamic_shapes, input_names=["tokens"], f="style_model.onnx",
+                  output_names=["asr", "F0_pred", "N_pred", "ref_s"], opset_version=13, verbose=False, dynamo=False)
+
+
+# with torch.no_grad():
+# torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=False)
+
+# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)
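
The exported graph stops before the istftnet decoder (the decoder call is left commented out above), so audio still has to come from the PyTorch decoder. A minimal sketch of stitching the two halves together, assuming onnxruntime is installed and reusing build_model from the models_onnx.py added in this commit (the script itself imports build_model from models_scripting, which is not part of this diff):

# Sketch only: run the exported front end with onnxruntime, then decode with the
# PyTorch istftnet decoder. onnxruntime and the import of build_model from
# models_onnx are assumptions; adjust to the local setup.
import numpy as np
import onnxruntime as ort
import torch
from kokoro import phonemize, tokenize
from models_onnx import build_model

model = build_model('kokoro-v0_19.pth', 'cpu')

tokens = tokenize(phonemize("A short test sentence.", "a"))
tokens = np.array([[0, *tokens, 0]], dtype=np.int64)

sess = ort.InferenceSession("style_model.onnx")
asr, F0_pred, N_pred, ref_s = sess.run(
    ["asr", "F0_pred", "N_pred", "ref_s"], {"tokens": tokens})

# ref_s is already sliced to the first 128 style dims inside the exported forward,
# matching the commented-out decoder call in onnx_export.py.
with torch.no_grad():
    audio = model.decoder(torch.from_numpy(asr), torch.from_numpy(F0_pred),
                          torch.from_numpy(N_pred), torch.from_numpy(ref_s))
    audio = audio.squeeze().cpu().numpy()
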
plbert.py CHANGED
@@ -1,10 +1,11 @@
 # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
+import torch
 from transformers import AlbertConfig, AlbertModel
 
 class CustomAlbert(AlbertModel):
-    def forward(self, *args, **kwargs):
+    def forward(self, tokens: torch.Tensor):
         # Call the original forward method
-        outputs = super().forward(*args, **kwargs)
+        outputs = super().forward(tokens)
         # Only return the last_hidden_state
         return outputs.last_hidden_state
 
test.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
test.py CHANGED
@@ -80,5 +80,6 @@ batch = torch.export.Dim("batch")
 dynamic_shapes = {"tokens":{0:batch, 1:token_len}}
 
 # with torch.no_grad():
-export_mod = torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=False)
+export_mod = torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=True)
+
 # export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)