from typing import Literal, Optional

import torch
import torch.nn as nn

from mmaudio.ext.autoencoder.vae import VAE, get_my_vae
from mmaudio.ext.bigvgan import BigVGAN
from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
from mmaudio.model.utils.distributions import DiagonalGaussianDistribution


class AutoEncoderModule(nn.Module):
    """Frozen wrapper around the mel-spectrogram VAE and a BigVGAN vocoder."""

    def __init__(self,
                 *,
                 vae_ckpt_path: str,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k'],
                 need_vae_encoder: bool = True):
        super().__init__()
        self.vae: VAE = get_my_vae(mode).eval()
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        self.vae.load_state_dict(vae_state_dict, strict=False)
        self.vae.remove_weight_norm()

        if mode == '16k':
            # 16 kHz mode requires a local BigVGAN (v1) checkpoint.
            assert vocoder_ckpt_path is not None
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        elif mode == '44k':
            # 44.1 kHz mode loads the pretrained BigVGAN v2 from the Hugging Face Hub.
            self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
                                                     use_cuda_kernel=False)
            self.vocoder.remove_weight_norm()
        else:
            raise ValueError(f'Unknown mode: {mode}')

        # Inference-only module: freeze every parameter.
        for param in self.parameters():
            param.requires_grad = False

        # Decode-only deployments can drop the encoder to save memory.
        if not need_vae_encoder:
            del self.vae.encoder

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
        return self.vae.encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.vae.decode(z)

    @torch.inference_mode()
    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        return self.vocoder(spec)
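

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal round trip through the wrapper. The checkpoint paths below are
# hypothetical placeholders, and the dummy spectrogram shape
# (batch, n_mels, frames) is an assumption about what get_my_vae('16k')
# expects; consult the actual data pipeline for the real values. Also assumes
# DiagonalGaussianDistribution exposes a .sample() method, as common
# implementations of that class do.
if __name__ == '__main__':
    ae = AutoEncoderModule(vae_ckpt_path='weights/vae_16k.pth',  # hypothetical path
                           vocoder_ckpt_path='weights/bigvgan_16k.pt',  # hypothetical path
                           mode='16k')
    mel = torch.randn(1, 80, 312)  # dummy mel spectrogram; shape is an assumption
    dist = ae.encode(mel)  # posterior distribution over latents
    z = dist.sample()  # draw a latent; a deterministic variant could use the mean
    recon = ae.decode(z)  # latent -> spectrogram
    wav = ae.vocode(recon)  # spectrogram -> waveform via BigVGAN
    print(z.shape, recon.shape, wav.shape)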