# Spaces: Runtime error
import torch | |
from torchaudio.models.wav2vec2.utils import import_fairseq_model | |
from fairseq import checkpoint_utils | |
from onnxexport.model_onnx import SynthesizerTrn | |
import utils | |
def get_hubert_model():
    """Load the HuBERT checkpoint and return it in evaluation mode.

    The checkpoint is read from ``hubert/checkpoint_best_legacy_500.pt``
    via fairseq's ``load_model_ensemble_and_task``; only the first model of
    the returned ensemble is used.

    Returns:
        The first model of the loaded ensemble, after ``eval()`` was called
        on it.
    """
    checkpoint_path = "hubert/checkpoint_best_legacy_500.pt"
    print(f"load model(s) from {checkpoint_path}")
    # The saved config and task are returned as well but are not needed here.
    ensemble, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path],
        suffix="",
    )
    hubert = ensemble[0]
    hubert.eval()
    return hubert
def main(HubertExport, NetExport, path="SoVits4.0"):
    """Export the synthesizer checkpoint under ``checkpoints/{path}`` to ONNX.

    Args:
        HubertExport: Accepted for backward compatibility only. The HuBERT
            export code was disabled in the original script (it was kept as
            an inert string literal), so this flag has no effect.
        NetExport: When truthy, read ``checkpoints/{path}/config.json`` and
            ``checkpoints/{path}/model.pth`` and write
            ``checkpoints/{path}/model.onnx``. When falsy, nothing happens.
        path: Sub-directory of ``checkpoints/`` holding the config and the
            weights. Defaults to the previously hard-coded ``"SoVits4.0"``.

    Returns:
        None.
    """
    if not NetExport:
        return

    device = torch.device("cpu")
    checkpoint_dir = f"checkpoints/{path}"

    hps = utils.get_hparams_from_file(f"{checkpoint_dir}/config.json")
    net = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)
    utils.load_checkpoint(f"{checkpoint_dir}/model.pth", net, None)

    # Put the network in eval mode once and freeze it before tracing.
    net.eval().to(device)
    for param in net.parameters():
        param.requires_grad = False

    # Dummy inputs used only for tracing. The time-like axis of every tensor
    # is declared dynamic below, so the concrete length is arbitrary.
    # The creation order (rand, rand, randn) matches the original script so
    # the random number generator is consumed identically.
    n_frames = 10
    input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
    output_names = ["audio", ]
    dummy_inputs = (
        torch.rand(1, n_frames, 256),                           # c
        torch.rand(1, n_frames),                                # f0
        torch.arange(n_frames, dtype=torch.long).unsqueeze(0),  # mel2ph
        torch.ones(1, n_frames, dtype=torch.float32),           # uv
        torch.randn(1, 192, n_frames),                          # noise
        torch.LongTensor([0]),                                  # sid
    )

    torch.onnx.export(
        net,
        tuple(tensor.to(device) for tensor in dummy_inputs),
        f"{checkpoint_dir}/model.onnx",
        # Axes listed here stay variable-sized in the exported graph.
        dynamic_axes={
            "c": [0, 1],
            "f0": [1],
            "mel2ph": [1],
            "uv": [1],
            "noise": [2],
        },
        do_constant_folding=False,
        opset_version=16,
        verbose=False,
        input_names=input_names,
        output_names=output_names)
# Script entry point: run with the HubertExport flag off and NetExport on.
if __name__ == "__main__":
    main(HubertExport=False, NetExport=True)