jonathanjordan21 committed on
Commit
06220ce
1 Parent(s): 7ea64f7

Update app.py

Files changed (1)
  app.py  +76 -2
app.py CHANGED
@@ -1,10 +1,84 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
+ from huggingface_hub import hf_hub_download
+

  """
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+ import os
+ import pickle
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from collections import OrderedDict
+ from AutoPST.onmt_modules.misc import sequence_mask
+ from AutoPST.model_autopst import Generator_2 as Predictor
+ from AutoPST.hparams_autopst import hparams
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ P = Predictor(hparams).eval().to(device)
+
+ checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename='580000-P.ckpt'), map_location=lambda storage, loc: storage)
+ P.load_state_dict(checkpoint['model'], strict=True)
+ print('Loaded predictor .....................................................')
+
+ dict_test = pickle.load(open('./AutoPST/assets/test_vctk.meta', 'rb'))
+
+ spect_vc = OrderedDict()
+
+ uttrs = [('p231', 'p270', '001'),
+          ('p270', 'p231', '001'),
+          ('p231', 'p245', '003001'),
+          ('p245', 'p231', '003001'),
+          ('p239', 'p270', '024002'),
+          ('p270', 'p239', '024002')]
+
+
+ for uttr in uttrs:
+
+     cep_real, spk_emb = dict_test[uttr[0]][uttr[2]]
+     cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)
+     len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
+     real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()
+
+     _, spk_emb = dict_test[uttr[1]][uttr[2]]
+     spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2,1)[:,:14,:],
+                                                real_mask_A,
+                                                len_real_A,
+                                                spk_emb_B)
+
+     uttr_tgt = spect_output[:len_spect[0],0,:].cpu().numpy()
+
+     spect_vc[f'{uttr[0]}_{uttr[1]}_{uttr[2]}'] = uttr_tgt
+
+ # spectrogram to waveform
+ # Feel free to use other vocoders
+ # This cell requires some preparation to work, please see the corresponding part in AutoVC
+ import torch
+ import librosa
+ import pickle
+ import os
+ from AutoPST.synthesis import build_model
+ from AutoPST.synthesis import wavegen
+
+ model = build_model().to(device)
+ checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
+ model.load_state_dict(checkpoint["state_dict"])
+
+ # for name, sp in spect_vc.items():
+
+ #     print(name)
+ #     waveform = wavegen(model, c=sp)
+
+ #     librosa.output.write_wav('./assets/'+name+'.wav', waveform, sr=16000)
+
+


  def respond(
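The commit loads the WaveNet vocoder checkpoint but leaves the spectrogram-to-waveform loop commented out. Below is a minimal sketch of how that loop could be run offline, assuming the `model` and `spect_vc` objects built in app.py above; `soundfile.write` is swapped in for `librosa.output.write_wav`, which was removed in librosa 0.8.

import os
import soundfile as sf
from AutoPST.synthesis import wavegen

os.makedirs('./assets', exist_ok=True)

for name, sp in spect_vc.items():
    print(name)
    # wavegen samples the autoregressive WaveNet one step at a time, so this is slow on CPU
    waveform = wavegen(model, c=sp)
    # write 16 kHz mono audio; soundfile stands in for the removed librosa.output.write_wav
    sf.write(os.path.join('./assets', name + '.wav'), waveform, samplerate=16000)

Keeping this loop commented out avoids running the slow autoregressive vocoder every time the Space starts.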