import numpy as np
import torch
from bert_vits2 import commons
from bert_vits2 import utils as bert_vits2_utils
from bert_vits2.models import SynthesizerTrn
from bert_vits2.text import *
from bert_vits2.text.cleaner import clean_text
from bert_vits2.utils import process_legacy_versions
from contants import ModelType
from utils import classify_language, get_hparams_from_file, lang_dict
from utils.sentence import sentence_split_and_markup, cut
# NOTE: assumed import path; BertHandler is used in __init__ but was never
# imported in this file. Adjust to wherever BertHandler is actually defined.
from bert_vits2.bert_handler import BertHandler


class Bert_VITS2:
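    """Inference wrapper around a Bert-VITS2 SynthesizerTrn checkpoint.

    Loads hyperparameters, applies legacy-version compatibility fixes,
    builds BERT features for input text, and exposes sentence-level
    (infer) and request-level (get_audio) synthesis.
    """
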
def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
        self.speakers = [
            item[0]
            for item in sorted(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items(), key=lambda x: x[1])
        ]
self.symbols = symbols
# Compatible with legacy versions
self.version = process_legacy_versions(self.hps_ms)
if self.version in ["1.0", "1.0.0", "1.0.1"]:
self.symbols = symbols_legacy
self.hps_ms.model.n_layers_trans_flow = 3
elif self.version in ["1.1.0-transition"]:
self.hps_ms.model.n_layers_trans_flow = 3
elif self.version in ["1.1", "1.1.0", "1.1.1"]:
self.hps_ms.model.n_layers_trans_flow = 6
key = f"{ModelType.BERT_VITS2.value}_v{self.version}" if self.version else ModelType.BERT_VITS2.value
self.lang = lang_dict.get(key, ["unknown"])
self.bert_handler = BertHandler(self.lang)
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
self.net_g = SynthesizerTrn(
len(self.symbols),
self.hps_ms.data.filter_length // 2 + 1,
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
n_speakers=self.hps_ms.data.n_speakers,
symbols=self.symbols,
**self.hps_ms.model).to(device)
_ = self.net_g.eval()
self.device = device
self.load_model(model)

    def load_model(self, model):
        bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, version=self.version)

    def get_speakers(self):
        return self.speakers

    @property
    def sampling_rate(self):
        return self.hps_ms.data.sampling_rate
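
    # Convert raw text into the aligned tensors the synthesizer consumes:
    # phoneme ids, tone ids, per-phoneme language ids, and BERT features
    # whose last dimension equals the phoneme sequence length.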
def get_text(self, text, language_str, hps):
norm_text, phone, tone, word2ph = clean_text(text, language_str)
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)
        if hps.data.add_blank:
            # Interleave a blank token (id 0) around every symbol, so the
            # sequences grow from n to 2n + 1 elements
            phone = commons.intersperse(phone, 0)
            tone = commons.intersperse(tone, 0)
            language = commons.intersperse(language, 0)
            # Double the per-word phoneme counts to match, and credit the
            # one leftover blank to the first word
            for i in range(len(word2ph)):
                word2ph[i] = word2ph[i] * 2
            word2ph[0] += 1
        bert = self.bert_handler.get_bert(norm_text, word2ph, language_str)
        del word2ph
        assert bert.shape[-1] == len(phone), f"BERT feature length {bert.shape[-1]} != phoneme length {len(phone)}"
        # The model expects a 1024-dim Chinese BERT stream and a 768-dim
        # Japanese BERT stream; the stream for the inactive language is
        # zero-filled.
        if language_str == "zh":
            ja_bert = torch.zeros(768, len(phone))
        elif language_str == "ja":
            ja_bert = bert
            bert = torch.zeros(1024, len(phone))
        else:
            bert = torch.zeros(1024, len(phone))
            ja_bert = torch.zeros(768, len(phone))
assert bert.shape[-1] == len(
phone
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
phone = torch.LongTensor(phone)
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return bert, ja_bert, phone, tone, language
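
    # Synthesize a single sentence: tensors gain a batch dimension, the
    # generator runs under no_grad, and the waveform is returned as a
    # float32 numpy array at self.sampling_rate.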
def infer(self, text, lang, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
bert, ja_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms)
with torch.no_grad():
x_tst = phones.to(self.device).unsqueeze(0)
tones = tones.to(self.device).unsqueeze(0)
lang_ids = lang_ids.to(self.device).unsqueeze(0)
bert = bert.to(self.device).unsqueeze(0)
ja_bert = ja_bert.to(self.device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
speakers = torch.LongTensor([int(sid)]).to(self.device)
            audio = self.net_g.infer(
                x_tst, x_tst_lengths, speakers, tones, lang_ids, bert, ja_bert,
                sdp_ratio=sdp_ratio, noise_scale=noise_scale,
                noise_scale_w=noise_scale_w, length_scale=length_scale,
            )[0][0, 0].data.cpu().float().numpy()
torch.cuda.empty_cache()
return audio
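
    # Request-level entry point: read synthesis parameters from a dict,
    # auto-detect the language if requested, split long text into chunks,
    # synthesize each chunk, and concatenate the waveforms.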
def get_audio(self, voice, auto_break=False):
text = voice.get("text", None)
lang = voice.get("lang", "auto")
sdp_ratio = voice.get("sdp_ratio", 0.2)
noise_scale = voice.get("noise", 0.5)
noise_scale_w = voice.get("noisew", 0.6)
length_scale = voice.get("length", 1)
sid = voice.get("id", 0)
        max_length = voice.get("max", 50)  # renamed from `max` to avoid shadowing the builtin
        # sentence_list = sentence_split_and_markup(text, max_length, "ZH", ["zh"])
        if lang == "auto":
            lang = classify_language(text, target_languages=self.lang)
        sentence_list = cut(text, max_length)
audios = []
for sentence in sentence_list:
audio = self.infer(sentence, lang, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid)
audios.append(audio)
audio = np.concatenate(audios)
return audio
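

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The checkpoint
    # path, config path, and speaker id below are placeholder assumptions;
    # point them at your own model files.
    from scipy.io import wavfile

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tts = Bert_VITS2("logs/G_latest.pth", "configs/config.json", device=device)
    wav = tts.get_audio({"text": "你好，世界。", "id": 0, "lang": "auto"})
    wavfile.write("output.wav", tts.sampling_rate, wav)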