ikechan8370
commited on
Commit
•
23aa815
1
Parent(s):
b772f7c
fix: add support for gpu
Browse files
models.py
CHANGED
@@ -496,9 +496,10 @@ class SynthesizerTrn(nn.Module):
|
|
496 |
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
497 |
|
498 |
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
|
499 |
-
|
|
|
500 |
if self.n_speakers > 0:
|
501 |
-
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
502 |
else:
|
503 |
g = None
|
504 |
|
|
|
496 |
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
497 |
|
498 |
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
|
499 |
+
device = next(self.parameters()).device # 获取模型所在的设备
|
500 |
+
x, m_p, logs_p, x_mask = self.enc_p(x.to(device), x_lengths.to(device))
|
501 |
if self.n_speakers > 0:
|
502 |
+
g = self.emb_g(sid.to(device)).unsqueeze(-1) # [b, h, 1]
|
503 |
else:
|
504 |
g = None
|
505 |
|