Spaces: jonathanjordan21 (Space status: Runtime error)
Commit: f28083c
Parent(s): ad6dc1b
Update app.py
app.py CHANGED
@@ -16,6 +16,8 @@ from collections import OrderedDict
 from onmt_modules.misc import sequence_mask
 from model_autopst import Generator_2 as Predictor
 from hparams_autopst import hparams
+from model_sea import Generator
+from hparams_sea import hparams as sea_hparams
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -71,6 +73,10 @@ model = build_model().to(device)
 checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
 model.load_state_dict(checkpoint["state_dict"])
 
+# sea_checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename='sea.ckpt'), map_location=lambda storage, loc: storage)
+# gen =Generator(sea_hparams)
+# gen.load_state_dict(sea_checkpoint['model'], strict=True)
+
 # for name, sp in spect_vc.items():
 
 #     print(name)
@@ -81,57 +87,164 @@ model.load_state_dict(checkpoint["state_dict"])
 
 
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
+# def respond(
+#     message,
+#     history: list[tuple[str, str]],
+#     system_message,
+#     max_tokens,
+#     temperature,
+#     top_p,
+# ):
+#     messages = [{"role": "system", "content": system_message}]
 
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
+#     for val in history:
+#         if val[0]:
+#             messages.append({"role": "user", "content": val[0]})
+#         if val[1]:
+#             messages.append({"role": "assistant", "content": val[1]})
 
-    messages.append({"role": "user", "content": message})
+#     messages.append({"role": "user", "content": message})
 
-    response = ""
+#     response = ""
 
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
+#     for message in client.chat_completion(
+#         messages,
+#         max_tokens=max_tokens,
+#         stream=True,
+#         temperature=temperature,
+#         top_p=top_p,
+#     ):
+#         token = message.choices[0].delta.content
 
-        response += token
-        yield response
+#         response += token
+#         yield response
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
+# demo = gr.ChatInterface(
+#     respond,
+#     additional_inputs=[
+#         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+#         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+#         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+#         gr.Slider(
+#             minimum=0.1,
+#             maximum=1.0,
+#             value=0.95,
+#             step=0.05,
+#             label="Top-p (nucleus sampling)",
+#         ),
+#     ],
+# )
+
+import os
+import pickle
+import numpy as np
+import soundfile as sf
+from scipy import signal
+from scipy.signal import get_window
+from librosa.filters import mel
+from numpy.random import RandomState
+
+
+def butter_highpass(cutoff, fs, order=5):
+    nyq = 0.5 * fs
+    normal_cutoff = cutoff / nyq
+    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
+    return b, a
+
+
+def pySTFT(x, fft_length=1024, hop_length=256):
+
+    x = np.pad(x, int(fft_length//2), mode='reflect')
+
+    noverlap = fft_length - hop_length
+    shape = x.shape[:-1]+((x.shape[-1]-noverlap)//hop_length, fft_length)
+    strides = x.strides[:-1]+(hop_length*x.strides[-1], x.strides[-1])
+    result = np.lib.stride_tricks.as_strided(x, shape=shape,
+                                             strides=strides)
+
+    fft_window = get_window('hann', fft_length, fftbins=True)
+    result = np.fft.rfft(fft_window * result, n=fft_length).T
+
+    return np.abs(result)
+
+
+def create_sp(cep_real, spk_emb):
+    # cep_real, spk_emb = dict_test[uttr[0]][uttr[2]]
+    cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)
+    len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
+    real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()
+
+    # _, spk_emb = dict_test[uttr[1]][uttr[2]]
+    spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2,1)[:,:14,:],
+                                               real_mask_A,
+                                               len_real_A,
+                                               spk_emb_B)
+
+    uttr_tgt = spect_output[:len_spect[0],0,:].cpu().numpy()
+    return uttr_tgt
+
+def create_mel(x):
+    mel_basis = mel(sr=16000, n_fft=1024, fmin=90, fmax=7600, n_mels=80).T
+    min_level = np.exp(-100 / 20 * np.log(10))
+    b, a = butter_highpass(30, 16000, order=5)
+
+    mfcc_mean, mfcc_std, dctmx = pickle.load(open('assets/mfcc_stats.pkl', 'rb'))
+    spk2emb = pickle.load(open('assets/spk2emb_82.pkl', 'rb'))
+
+    if x.shape[0] % 256 == 0:
+        x = np.concatenate((x, np.array([1e-06])), axis=0)
+    y = signal.filtfilt(b, a, x)
+    D = pySTFT(y * 0.96).T
+    D_mel = np.dot(D, mel_basis)
+    D_db = 20 * np.log10(np.maximum(min_level, D_mel))
+
+    # mel sp
+    S = (D_db + 80) / 100
+
+    # mel cep
+    cc_tmp = S.dot(dctmx)
+    cc_norm = (cc_tmp - mfcc_mean) / mfcc_std
+    S = np.clip(S, 0, 1)
+
+    # teacher code
+    # cc_torch = torch.from_numpy(cc_norm[:,0:20].astype(np.float32)).unsqueeze(0).to(device)
+    # with torch.no_grad():
+    #     codes = gen.encode(cc_torch, torch.ones_like(cc_torch[:,:,0])).squeeze(0)
+    return S, cc_norm
+
+def transcribe(audio, spk):
+    sr, y = audio
+    y = librosa.resample(y, orig_sr=sr, target_sr=16000)
+    y = y.astype(np.float32)
+    y /= np.max(np.abs(y))
+
+    spk_emb = np.zeros((82,))
+    spk_emb[spk-1] = 1
+
+    mel_sp, mel_cep = create_mel(y)
+    sp = create_sp(mel_cep, spk_emb)
+    waveform = wavegen(model, c=sp)
+    return 16000, waveform.numpy()
+
+    # return transcriber({"sampling_rate": sr, "raw": y})["text"]
+
+
+demo = gr.Interface(
+    transcribe,
+    [
+        gr.Audio(),
+        gr.Slider(1, 82, value=21, label="Count", info="Choose between 1 and 82")
     ],
+    "audio",
 )
 
 
+
 if __name__ == "__main__":
     demo.launch()
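Note on the new code: transcribe calls librosa.resample, but the only librosa import this commit adds is from librosa.filters import mel. Unless import librosa already appears earlier in app.py (outside the hunks shown here), that call raises NameError at inference time, which would be consistent with the Space's "Runtime error" status. A minimal sketch of the assumed fix, on dummy audio:

# Sketch only; assumes app.py has no top-level "import librosa" elsewhere.
import numpy as np
import librosa  # needed for librosa.resample, not just librosa.filters.mel

sr = 48000
y = np.random.randn(sr).astype(np.float32)  # one second of dummy audio

# Resample to the 16 kHz rate the AutoPST pipeline expects. librosa.resample
# also requires floating-point input, so resampling before the
# astype(np.float32) cast (the order the commit uses) would fail on the
# int16 arrays gr.Audio typically delivers.
y16 = librosa.resample(y, orig_sr=sr, target_sr=16000)
print(y16.shape)  # (16000,)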
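For reference, pySTFT builds its overlapping frame matrix with np.lib.stride_tricks.as_strided instead of copying each window. A self-contained check, using the same 1024/256 defaults as the commit, that the strided view matches naive slicing:

import numpy as np

fft_length, hop_length = 1024, 256
x = np.pad(np.random.randn(16000), fft_length // 2, mode='reflect')

noverlap = fft_length - hop_length
n_frames = (x.shape[-1] - noverlap) // hop_length

# Zero-copy view: row k starts hop_length samples after row k-1.
frames = np.lib.stride_tricks.as_strided(
    x,
    shape=(n_frames, fft_length),
    strides=(hop_length * x.strides[-1], x.strides[-1]),
)

# Naive equivalent built from explicit overlapping slices.
naive = np.stack([x[k * hop_length : k * hop_length + fft_length]
                  for k in range(n_frames)])
assert np.array_equal(frames, naive)
print(frames.shape)  # (63, 1024)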
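The constants in create_mel amount to a dB floor plus rescaling: min_level = np.exp(-100 / 20 * np.log(10)) is exactly 1e-5, i.e. a -100 dB floor applied before the log, and (D_db + 80) / 100 followed by np.clip(S, 0, 1) maps the -80 dB..+20 dB range onto 0..1. A quick numeric confirmation:

import numpy as np

min_level = np.exp(-100 / 20 * np.log(10))
print(np.isclose(min_level, 1e-5))  # True
print(20 * np.log10(min_level))     # -100.0 (up to float rounding)

# (D_db + 80) / 100 sends -80 dB to 0.0 and +20 dB to 1.0;
# the final clip discards anything quieter than -80 dB.
D_db = np.array([-100.0, -80.0, 0.0, 20.0])
print(np.clip((D_db + 80) / 100, 0, 1))  # [0.  0.  0.8 1. ]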