Spaces:
Runtime error
Runtime error
File size: 3,559 Bytes
de0ac05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import torch
from torchaudio.models.wav2vec2.utils import import_fairseq_model
from fairseq import checkpoint_utils
from onnxexport.model_onnx import SynthesizerTrn
import utils
def get_hubert_model():
vec_path = "hubert/checkpoint_best_legacy_500.pt"
print("load model(s) from {}".format(vec_path))
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[vec_path],
suffix="",
)
model = models[0]
model.eval()
return model
def main(HubertExport, NetExport):
path = "SoVits4.0"
'''if HubertExport:
device = torch.device("cpu")
vec_path = "hubert/checkpoint_best_legacy_500.pt"
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[vec_path],
suffix="",
)
original = models[0]
original.eval()
model = original
test_input = torch.rand(1, 1, 16000)
model(test_input)
torch.onnx.export(model,
test_input,
"hubert4.0.onnx",
export_params=True,
opset_version=16,
do_constant_folding=True,
input_names=['source'],
output_names=['embed'],
dynamic_axes={
'source':
{
2: "sample_length"
},
}
)'''
if NetExport:
device = torch.device("cpu")
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
SVCVITS = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model)
_ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
_ = SVCVITS.eval().to(device)
for i in SVCVITS.parameters():
i.requires_grad = False
test_hidden_unit = torch.rand(1, 10, 256)
test_pitch = torch.rand(1, 10)
test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
test_uv = torch.ones(1, 10, dtype=torch.float32)
test_noise = torch.randn(1, 192, 10)
test_sid = torch.LongTensor([0])
input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
output_names = ["audio", ]
SVCVITS.eval()
torch.onnx.export(SVCVITS,
(
test_hidden_unit.to(device),
test_pitch.to(device),
test_mel2ph.to(device),
test_uv.to(device),
test_noise.to(device),
test_sid.to(device)
),
f"checkpoints/{path}/model.onnx",
dynamic_axes={
"c": [0, 1],
"f0": [1],
"mel2ph": [1],
"uv": [1],
"noise": [2],
},
do_constant_folding=False,
opset_version=16,
verbose=False,
input_names=input_names,
output_names=output_names)
if __name__ == '__main__':
main(False, True)
|