jonathanjordan21 committed on
Commit
f28083c
1 Parent(s): ad6dc1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -39
app.py CHANGED
@@ -16,6 +16,8 @@ from collections import OrderedDict
16
  from onmt_modules.misc import sequence_mask
17
  from model_autopst import Generator_2 as Predictor
18
  from hparams_autopst import hparams
 
 
19
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
@@ -71,6 +73,10 @@ model = build_model().to(device)
71
  checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
72
  model.load_state_dict(checkpoint["state_dict"])
73
 
 
 
 
 
74
  # for name, sp in spect_vc.items():
75
 
76
  # print(name)
@@ -81,57 +87,164 @@ model.load_state_dict(checkpoint["state_dict"])
81
 
82
 
83
 
84
- def respond(
85
- message,
86
- history: list[tuple[str, str]],
87
- system_message,
88
- max_tokens,
89
- temperature,
90
- top_p,
91
- ):
92
- messages = [{"role": "system", "content": system_message}]
93
 
94
- for val in history:
95
- if val[0]:
96
- messages.append({"role": "user", "content": val[0]})
97
- if val[1]:
98
- messages.append({"role": "assistant", "content": val[1]})
99
 
100
- messages.append({"role": "user", "content": message})
101
 
102
- response = ""
103
 
104
- for message in client.chat_completion(
105
- messages,
106
- max_tokens=max_tokens,
107
- stream=True,
108
- temperature=temperature,
109
- top_p=top_p,
110
- ):
111
- token = message.choices[0].delta.content
112
 
113
- response += token
114
- yield response
115
 
116
  """
117
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
118
  """
119
- demo = gr.ChatInterface(
120
- respond,
121
- additional_inputs=[
122
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
123
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
124
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
125
- gr.Slider(
126
- minimum=0.1,
127
- maximum=1.0,
128
- value=0.95,
129
- step=0.05,
130
- label="Top-p (nucleus sampling)",
131
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  ],
 
133
  )
134
 
135
 
 
136
  if __name__ == "__main__":
137
  demo.launch()
 
16
  from onmt_modules.misc import sequence_mask
17
  from model_autopst import Generator_2 as Predictor
18
  from hparams_autopst import hparams
19
+ from model_sea import Generator
20
+ from hparams_sea import hparams as sea_hparams
21
 
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
 
73
  checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
74
  model.load_state_dict(checkpoint["state_dict"])
75
 
76
+ # sea_checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename='sea.ckpt'), map_location=lambda storage, loc: storage)
77
+ # gen =Generator(sea_hparams)
78
+ # gen.load_state_dict(sea_checkpoint['model'], strict=True)
79
+
80
  # for name, sp in spect_vc.items():
81
 
82
  # print(name)
 
87
 
88
 
89
 
90
+ # def respond(
91
+ # message,
92
+ # history: list[tuple[str, str]],
93
+ # system_message,
94
+ # max_tokens,
95
+ # temperature,
96
+ # top_p,
97
+ # ):
98
+ # messages = [{"role": "system", "content": system_message}]
99
 
100
+ # for val in history:
101
+ # if val[0]:
102
+ # messages.append({"role": "user", "content": val[0]})
103
+ # if val[1]:
104
+ # messages.append({"role": "assistant", "content": val[1]})
105
 
106
+ # messages.append({"role": "user", "content": message})
107
 
108
+ # response = ""
109
 
110
+ # for message in client.chat_completion(
111
+ # messages,
112
+ # max_tokens=max_tokens,
113
+ # stream=True,
114
+ # temperature=temperature,
115
+ # top_p=top_p,
116
+ # ):
117
+ # token = message.choices[0].delta.content
118
 
119
+ # response += token
120
+ # yield response
121
 
122
  """
123
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
124
  """
125
+ # demo = gr.ChatInterface(
126
+ # respond,
127
+ # additional_inputs=[
128
+ # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
129
+ # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
130
+ # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
131
+ # gr.Slider(
132
+ # minimum=0.1,
133
+ # maximum=1.0,
134
+ # value=0.95,
135
+ # step=0.05,
136
+ # label="Top-p (nucleus sampling)",
137
+ # ),
138
+ # ],
139
+ # )
140
+
141
+ import os
142
+ import pickle
143
+ import numpy as np
144
+ import soundfile as sf
145
+ from scipy import signal
146
+ from scipy.signal import get_window
147
+ from librosa.filters import mel
148
+ from numpy.random import RandomState
149
+
150
+
151
def butter_highpass(cutoff, fs, order=5):
    """Design a digital Butterworth high-pass filter.

    Args:
        cutoff: Cutoff frequency, in the same unit as ``fs``.
        fs: Sampling rate of the signal that will be filtered.
        order: Filter order handed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` numerator / denominator coefficient arrays.
    """
    # scipy expects the critical frequency as a fraction of the Nyquist rate.
    nyquist = 0.5 * fs
    return signal.butter(order, cutoff / nyquist, btype='high', analog=False)
156
+
157
+
158
def pySTFT(x, fft_length=1024, hop_length=256):
    """Return the magnitude short-time Fourier transform of ``x``.

    The signal is reflect-padded by ``fft_length // 2`` samples on both ends
    of its last axis, cut into overlapping frames of ``fft_length`` samples
    spaced ``hop_length`` apart, multiplied by a periodic Hann window and
    transformed with a real FFT.

    Args:
        x: Array of samples; framing is done along the last axis.
        fft_length: Frame length and FFT size.
        hop_length: Number of samples between the starts of adjacent frames.

    Returns:
        Non-negative array of magnitudes. For 1-D input the shape is
        ``(fft_length // 2 + 1, n_frames)``; for N-D input the axes come out
        reversed because of the final transpose.
    """
    x = np.asarray(x)
    half = fft_length // 2
    # Pad the time (last) axis only. A scalar pad width would also
    # reflect-pad any leading axes, which the framing below never intends.
    pad_width = [(0, 0)] * (x.ndim - 1) + [(half, half)]
    x = np.pad(x, pad_width, mode='reflect')

    noverlap = fft_length - hop_length
    n_frames = (x.shape[-1] - noverlap) // hop_length
    shape = x.shape[:-1] + (n_frames, fft_length)
    strides = x.strides[:-1] + (hop_length * x.strides[-1], x.strides[-1])
    # Zero-copy view of the overlapping frames; it is only read from, and the
    # windowing multiplication below allocates a fresh array.
    frames = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)

    fft_window = get_window('hann', fft_length, fftbins=True)
    spectrum = np.fft.rfft(fft_window * frames, n=fft_length).T

    return np.abs(spectrum)
172
+
173
+
174
def create_sp(cep_real, spk_emb):
    """Predict a spectrogram from cepstral features for a target speaker.

    Args:
        cep_real: NumPy array indexed as ``(frames, coefficients)``; only the
            first 14 coefficients are fed to the predictor.
        spk_emb: NumPy array holding the target speaker embedding.

    Returns:
        NumPy array with the predicted spectrogram of the single batch item,
        trimmed to the length reported by the predictor.
    """
    # torch.from_numpy keeps the NumPy dtype, so float64 arrays (the NumPy
    # default) would become double tensors. Cast to float32, which is the
    # default parameter dtype of PyTorch modules; this is a no-op for inputs
    # that are already float32.
    cep_real_A = torch.from_numpy(cep_real).float().unsqueeze(0).to(device)
    len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
    real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()

    spk_emb_B = torch.from_numpy(spk_emb).float().unsqueeze(0).to(device)

    # NOTE(review): ``P`` is not defined in the visible part of this file;
    # presumably it is the module-level predictor instance. Confirm.
    with torch.no_grad():
        spect_output, len_spect = P.infer_onmt(
            cep_real_A.transpose(2, 1)[:, :14, :],
            real_mask_A,
            len_real_A,
            spk_emb_B,
        )

    # The output is indexed time-first; keep batch item 0 up to its length.
    return spect_output[:len_spect[0], 0, :].cpu().numpy()
191
+
192
def create_mel(x):
    """Compute mel-spectrogram and normalised mel-cepstrum features.

    All constants assume 16 kHz audio: an 80-band mel filterbank over
    90-7600 Hz on a 1024-point FFT, and a 30 Hz high-pass pre-filter.

    Args:
        x: 1-D NumPy array of audio samples.

    Returns:
        Tuple ``(S, cc_norm)``. ``S`` is the ``(frames, 80)`` mel spectrogram
        in scaled dB, clipped to ``[0, 1]``. ``cc_norm`` is the unclipped
        spectrogram projected through the stored DCT matrix and normalised
        with the stored mean and standard deviation.

    Raises:
        FileNotFoundError: If ``assets/mfcc_stats.pkl`` is missing.
    """
    mel_basis = mel(sr=16000, n_fft=1024, fmin=90, fmax=7600, n_mels=80).T
    min_level = np.exp(-100 / 20 * np.log(10))  # -100 dB floor, i.e. 1e-5
    b, a = butter_highpass(30, 16000, order=5)

    # Bundled asset shipped with the app; never point this at untrusted data,
    # since unpickling can execute arbitrary code.
    with open('assets/mfcc_stats.pkl', 'rb') as stats_file:
        mfcc_mean, mfcc_std, dctmx = pickle.load(stats_file)

    # NOTE(review): lengths that are an exact multiple of 256 (pySTFT's
    # default hop) get one tiny extra sample; the rationale is not stated here.
    if x.shape[0] % 256 == 0:
        x = np.concatenate((x, np.array([1e-06])), axis=0)
    y = signal.filtfilt(b, a, x)
    D = pySTFT(y * 0.96).T
    D_mel = np.dot(D, mel_basis)
    D_db = 20 * np.log10(np.maximum(min_level, D_mel))

    # mel sp
    S = (D_db + 80) / 100

    # mel cep: computed from the unclipped spectrogram, before S is clipped
    cc_tmp = S.dot(dctmx)
    cc_norm = (cc_tmp - mfcc_mean) / mfcc_std
    S = np.clip(S, 0, 1)

    return S, cc_norm
220
+
221
def transcribe(audio, spk):
    """Convert an input recording to the voice of target speaker ``spk``.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by ``gr.Audio``.
            ``samples`` may be integer PCM and may be shaped
            ``(samples, channels)``.
        spk: 1-based target speaker number in ``1..82``.

    Returns:
        ``(16000, waveform)`` where ``waveform`` is a NumPy array, suitable
        for an audio output component.

    Raises:
        gr.Error: If no audio was supplied, the clip is empty, or ``spk`` is
            outside ``1..82``.
    """
    # Local import: the visible imports only bring in ``librosa.filters.mel``,
    # not the ``librosa`` name itself.
    import librosa

    if audio is None:
        raise gr.Error("Please provide an audio clip.")
    sr, y = audio

    # librosa.resample only accepts floating-point audio, so convert before
    # resampling rather than after.
    y = np.asarray(y, dtype=np.float32)
    if y.ndim > 1:
        # Down-mix (samples, channels) to mono; resample works on the last
        # axis and would otherwise run across the channels.
        y = y.mean(axis=1)
    if y.size == 0:
        raise gr.Error("The audio clip is empty.")
    y = librosa.resample(y, orig_sr=sr, target_sr=16000)

    # Peak-normalise, but leave pure silence untouched instead of dividing
    # by zero and filling the signal with NaNs.
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak

    # A slider may deliver a float; array indices must be integers.
    spk_index = int(round(spk)) - 1
    if not 0 <= spk_index < 82:
        raise gr.Error("Speaker must be between 1 and 82.")
    spk_emb = np.zeros((82,), dtype=np.float32)
    spk_emb[spk_index] = 1

    mel_sp, mel_cep = create_mel(y)
    # Hand float32 features to the model; the feature math yields float64.
    sp = create_sp(mel_cep.astype(np.float32), spk_emb)
    waveform = wavegen(model, c=sp)
    # NOTE(review): wavegen's return type is not visible here; accept either
    # a tensor or an array.
    if torch.is_tensor(waveform):
        waveform = waveform.detach().cpu().numpy()
    return 16000, np.asarray(waveform)
234
+
235
+ # return transcriber({"sampling_rate": sr, "raw": y})["text"]
236
+
237
+
238
# Gradio UI: an audio clip plus a speaker selector in, converted audio out.
demo = gr.Interface(
    transcribe,
    [
        gr.Audio(),
        # step=1 keeps the value integral; transcribe uses it as an array
        # index, and without an explicit step the slider can yield fractions.
        # NOTE(review): the "Count" label looks like a leftover; this slider
        # selects the target speaker.
        gr.Slider(1, 82, value=21, step=1, label="Count", info="Choose between 1 and 82"),
    ],
    "audio",
)


if __name__ == "__main__":
    demo.launch()