geneing committed
Commit eb932f8 · 1 Parent(s): 5b93bbf

Added script for testing onnx export.

Files changed (2)
  1. test.ipynb +0 -0
  2. test.py +84 -0
test.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
test.py ADDED
import os

# Enable verbose dynamic-shape and export logging, plus extended debug
# output for the unbacked symbol u0 (the data-dependent total duration
# created below). The original set TORCH_LOGS twice, so the first value
# was overwritten; TORCH_LOGS takes a comma-separated list.
os.environ['TORCH_LOGS'] = '+dynamic,+export'
os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED'] = "u0 >= 0"
# os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CPP'] = "1"
os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL'] = "u0"

from kokoro import phonemize, tokenize, length_to_mask
import torch
import torch.nn.functional as F
from models import build_model

device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL = build_model('kokoro-v0_19.pth', device)
voicepack = torch.load('voices/af.pt', weights_only=True).to(device)

model = MODEL
speed = 1.

text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."

ps = phonemize(text, "a")
tokens = tokenize(ps)

# Wrap the token sequence with the boundary/pad token 0 on both sides.
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)


class StyleTTS2(torch.nn.Module):
    def __init__(self, model, voicepack):
        super().__init__()
        self.model = model
        self.voicepack = voicepack

    def forward(self, tokens):
        speed = 1.
        # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))
        device = tokens.device
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)

        text_mask = length_to_mask(input_lengths).to(device)
        bert_dur = self.model['bert'](tokens, attention_mask=(~text_mask).int())
        d_en = self.model["bert_encoder"](bert_dur).transpose(-1, -2)

        # Style vector indexed by input length; dims 128: condition the
        # predictor, dims :128 are the decoder reference returned below.
        ref_s = self.voicepack[tokens.shape[1]]
        s = ref_s[:, 128:]

        d = self.model["predictor"].text_encoder.inference(d_en, s)
        x, _ = self.model["predictor"].lstm(d)

        duration = self.model["predictor"].duration_proj(x)
        duration = torch.sigmoid(duration).sum(dim=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long()

        # Start/end frame of each token from the cumulative durations.
        c_start = F.pad(pred_dur, (1, 0), "constant").cumsum(dim=1)[0, 0:-1]
        c_end = c_start + pred_dur[0, :]

        # pred_dur.sum() is data-dependent (the unbacked symbol u0 under
        # export); torch._check tells the exporter it is positive. The
        # second argument must be a callable returning the message string,
        # not a print call, which returns None.
        torch._check(pred_dur.sum().item() > 0,
                     lambda: f"expected a positive total duration, got {pred_dur.sum().item()}")
        indices = torch.arange(0, pred_dur.sum().item()).long().to(device)

        # Build the token-to-frame alignment matrix without data-dependent
        # indexing: row i is 1 over the frame range [c_start[i], c_end[i]).
        pred_aln_trg_list = []
        for cs, ce in zip(c_start, c_end):
            row = torch.where((indices >= cs) & (indices < ce), 1., 0.)
            pred_aln_trg_list.append(row)
        pred_aln_trg = torch.vstack(pred_aln_trg_list)

        en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)

        F0_pred, N_pred = self.model["predictor"].F0Ntrain(en, s)
        t_en = self.model["text_encoder"].inference(tokens)
        asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
        return (asr, F0_pred, N_pred, ref_s[:, :128])
        # output = self.model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().detach().cpu().numpy()


style_model = StyleTTS2(model=model, voicepack=voicepack)
(asr, F0_pred, N_pred, ref_s) = style_model(tokens)

# Export with a dynamic batch dimension and a dynamic token length.
token_len = torch.export.Dim("token_len", min=2, max=510)
batch = torch.export.Dim("batch")
dynamic_shapes = {"tokens": {0: batch, 1: token_len}}

# with torch.no_grad():
export_mod = torch.export.export(style_model, args=(tokens,), dynamic_shapes=dynamic_shapes, strict=False)
# export_mod = torch.export.export(style_model, args=(tokens,), strict=False)
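
The script stops at torch.export, while the commit message targets ONNX. A minimal sketch of the likely next step, assuming a PyTorch release (roughly 2.5 or newer) where torch.onnx.export supports dynamo=True and the dynamic_shapes keyword; the output filename is an assumption:

# Hypothetical follow-up, not part of this commit: lower the model to
# ONNX via the dynamo-based exporter. With dynamo=True and no file
# argument, torch.onnx.export returns an ONNXProgram that can be saved.
onnx_prog = torch.onnx.export(style_model, args=(tokens,),
                              dynamic_shapes=dynamic_shapes, dynamo=True)
onnx_prog.save("styletts2.onnx")  # filename is an assumption

Note that the decoder call is commented out in forward(), so any exported graph covers only the front end (alignment plus F0/noise prediction); the decoder would need to be re-enabled or exported separately.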