|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler |
|
|
|
from models.svc.base import SVCInference |
|
from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline |
|
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper |
|
from modules.encoder.condition_encoder import ConditionEncoder |
|
|
|
|
|
class DiffusionInference(SVCInference):
    """Inference driver that samples outputs with a diffusion model.

    The model is a two-element ``torch.nn.ModuleList``: a ``ConditionEncoder``
    followed by a ``DiffusionWrapper``. Sampling is delegated to a
    ``DiffusionInferencePipeline`` driven by a diffusers scheduler selected
    from the inference config.
    """

    # Lower-cased ``cfg.inference.diffusion.scheduler`` -> scheduler class.
    _SCHEDULER_CLASSES = {
        "ddpm": DDPMScheduler,
        "ddim": DDIMScheduler,
        "pndm": PNDMScheduler,
    }

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        """Set up the scheduler and the sampling pipeline.

        Args:
            args: Command-line style namespace; ``diffusion_inference_steps``
                is read from it and handed to the pipeline.
            cfg: Config object; ``model.diffusion.scheduler_settings``,
                ``inference.diffusion.scheduler_settings`` and
                ``inference.diffusion.scheduler`` are read from it.
            infer_type: Forwarded unchanged to ``SVCInference.__init__``.

        Raises:
            NotImplementedError: If the configured scheduler name is not one
                of ``ddpm``, ``ddim`` or ``pndm`` (case-insensitive).
        """
        SVCInference.__init__(self, args, cfg, infer_type)

        # Inference-time settings are unpacked last, so they override the
        # model-time settings on any shared key.
        settings = {
            **cfg.model.diffusion.scheduler_settings,
            **cfg.inference.diffusion.scheduler_settings,
        }
        # The remaining entries are passed verbatim as scheduler constructor
        # keyword arguments, so this entry is stripped first. Its value is
        # never used here, hence a config that omits it must not be fatal.
        settings.pop("num_inference_timesteps", None)

        scheduler_name = cfg.inference.diffusion.scheduler.lower()
        scheduler_cls = self._SCHEDULER_CLASSES.get(scheduler_name)
        if scheduler_cls is None:
            raise NotImplementedError(
                "Unsupported scheduler type: {}".format(scheduler_name)
            )
        self.scheduler = scheduler_cls(**settings)
        self.logger.info("Using {} scheduler.".format(scheduler_name.upper()))

        # NOTE(review): ``self.model`` is set outside this class (presumably
        # by the base class from ``_build_model``), which would make index 1
        # the ``DiffusionWrapper`` -- confirm against ``SVCInference``.
        self.pipeline = DiffusionInferencePipeline(
            self.model[1],
            self.scheduler,
            args.diffusion_inference_steps,
        )

    def _build_model(self):
        """Build the condition encoder and the diffusion acoustic mapper.

        The F0 bounds from the preprocessing config are copied onto the
        condition-encoder config before the encoder is constructed.

        Returns:
            torch.nn.ModuleList: ``[condition_encoder, acoustic_mapper]``.
        """
        encoder_cfg = self.cfg.model.condition_encoder
        encoder_cfg.f0_min = self.cfg.preprocess.f0_min
        encoder_cfg.f0_max = self.cfg.preprocess.f0_max

        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
        return model

    def _inference_each_batch(self, batch_data):
        """Run the sampling pipeline on one batch.

        Args:
            batch_data: Mapping of feature name to value. Every value is
                moved to the accelerator device via ``.to(device)`` (the
                mapping is updated in place) and a ``"mel"`` entry is
                required to shape the initial noise.

        Returns:
            Whatever ``self.pipeline`` returns for the sampled noise and the
            encoded conditions.
        """
        device = self.accelerator.device
        # Only existing keys are reassigned, so the dict size is unchanged
        # and mutating it while iterating is safe.
        for k, v in batch_data.items():
            batch_data[k] = v.to(device)

        conditioner = self.model[0](batch_data)
        # Sampling starts from Gaussian noise with the same shape and dtype
        # as the batch's "mel" entry.
        noise = torch.randn_like(batch_data["mel"], device=device)
        y_pred = self.pipeline(noise, conditioner)
        return y_pred
|
|