# ZeroRVC / model.py — JacobLinCool, commit 3a010aa ("feat: infer"), 686 bytes
from accelerate import Accelerator
from infer.lib.rmvpe import RMVPE
from fairseq.checkpoint_utils import load_model_ensemble_and_task
# Let Hugging Face Accelerate choose the execution device and report it.
accelerator = Accelerator()
device = accelerator.device
print("Using device: {}".format(device))

# fp16 is True only when Accelerate's mixed_precision setting is exactly
# "fp16"; this flag drives half-precision loading of the models below.
fp16 = (accelerator.mixed_precision == "fp16")
print("Using fp16: {}".format(fp16))
# Load the RMVPE model (presumably the pitch/F0 estimator used for RVC
# inference — confirm) onto the device and precision selected above.
# The checkpoint path is relative, so it resolves against the current
# working directory.
rmvpe_model_path = "assets/rmvpe/rmvpe.pt"
rmvpe = RMVPE(rmvpe_model_path, device=device, is_half=fp16)
print("RMVPE model loaded.")
# Load the HuBERT checkpoint through fairseq. The call returns
# (models, cfg, task); only the first model of the returned ensemble is used.
# The checkpoint path is relative, so it resolves against the current
# working directory.
hubert_model_path = "assets/hubert/hubert_base.pt"
models, hubert_cfg, _ = load_model_ensemble_and_task([hubert_model_path])
hubert = models[0].to(device)
if fp16:
    # Cast weights to half precision only when Accelerate runs in fp16 mode,
    # matching the is_half setting passed to RMVPE.
    hubert = hubert.half()
# Switch to evaluation mode: this module is loaded for inference only.
hubert.eval()
print("Hubert model loaded.")