File size: 1,682 Bytes
dbac20f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from typing import Literal, Optional
import torch
import torch.nn as nn
from mmaudio.ext.autoencoder.vae import VAE, get_my_vae
from mmaudio.ext.bigvgan import BigVGAN
from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
class AutoEncoderModule(nn.Module):
    """Frozen bundle of a VAE and a vocoder selected by ``mode``.

    The VAE is built with ``get_my_vae(mode)``, put in eval mode, loaded from
    ``vae_ckpt_path`` and stripped of weight norm.  The vocoder is a
    ``BigVGAN`` loaded from ``vocoder_ckpt_path`` for ``'16k'``, or the
    pretrained ``nvidia/bigvgan_v2_44khz_128band_512x`` ``BigVGANv2`` for
    ``'44k'``.  Every parameter has ``requires_grad`` set to ``False`` and all
    public methods run under ``torch.inference_mode()``.
    """

    def __init__(self,
                 *,
                 vae_ckpt_path,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k']):
        """Build and freeze the VAE and the vocoder.

        Args:
            vae_ckpt_path: Checkpoint handed to ``torch.load`` (with
                ``weights_only=True``) to obtain the VAE state dict.
            vocoder_ckpt_path: Checkpoint passed to ``BigVGAN``.  Required
                when ``mode`` is ``'16k'``; not used when ``mode`` is
                ``'44k'``.
            mode: Which VAE/vocoder pair to build, ``'16k'`` or ``'44k'``.

        Raises:
            ValueError: If ``mode`` is not ``'16k'`` or ``'44k'``, or if
                ``mode`` is ``'16k'`` and ``vocoder_ckpt_path`` is ``None``.
        """
        # Validate arguments up front so a bad call fails before any
        # checkpoint is read.  Explicit raises are used instead of `assert`,
        # which is stripped when Python runs with -O.
        if mode not in ('16k', '44k'):
            raise ValueError(f'Unknown mode: {mode}')
        if mode == '16k' and vocoder_ckpt_path is None:
            raise ValueError("vocoder_ckpt_path is required when mode is '16k'")

        super().__init__()

        self.vae: VAE = get_my_vae(mode).eval()
        # Load onto CPU; weights_only=True restricts what torch.load unpickles.
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        self.vae.load_state_dict(vae_state_dict)
        self.vae.remove_weight_norm()

        if mode == '16k':
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        else:  # mode == '44k', guaranteed by the validation above
            self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
                                                     use_cuda_kernel=False)
            self.vocoder.remove_weight_norm()

        # Freeze everything: no parameter of this module receives gradients.
        for param in self.parameters():
            param.requires_grad = False

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
        """Return ``self.vae.encode(x)``, evaluated in inference mode."""
        return self.vae.encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Return ``self.vae.decode(z)``, evaluated in inference mode."""
        return self.vae.decode(z)

    @torch.inference_mode()
    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        """Return ``self.vocoder(spec)``, evaluated in inference mode.

        NOTE(review): ``spec`` is presumably a spectrogram as produced by
        ``decode`` -- confirm against the callers.
        """
        return self.vocoder(spec)
|