Attentionless VOcoder Streaming

Usage

from huggingface_hub import hf_hub_download
import soundfile
import torch
from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig
from torch import nn
import torch.nn.functional as F



class Voc(Wav2Vec2PreTrainedModel):

    '''Call Voc._flush() before switching to a different batch size.
    '''

    def __init__(self, config=PretrainedConfig()):
        super().__init__(config=config)
        self.encoder_transformer = VocTransformer()
        self.decoder_transformer = VocTransformer()
        self.encoder = SEANetEncoder()
        self.decoder = SEANetDecoder()
        self.sample_rate = 24000
        self.quantizer = SplitResidualVectorQuantizer()
        self.downsample = BufferConv1d(512, 512, kernel_size=4, stride=2, groups=1, bias=False)
        upsample_channel_wise_bug = True
        self.upsample = BufferConvTranspose1d(512, 512, kernel_size=4,
                                              groups=512 if upsample_channel_wise_bug else 1,
                                              stride=2, bias=False)
        self.frame_rate = 12.5
        self.encode_buffer = None

    def _flush(self):
        '''Stream buffers still hold tensors of the old batch size; call _flush() to clear them.
        '''
        self.encode_buffer = None  # holds an incomplete window (< 1920 samples); 1920 samples are needed to produce 1 token
        if self.downsample.previous is not None:
            self.downsample.previous = None
        if self.upsample.partial is not None:
            self.upsample.partial = None
        for arch in [self.encoder, self.decoder]:
            for _m in arch.model:
                if type(_m) is SEANetResnetBlock:
                    for _b in _m.block:
                        if type(_b) is BufferConv1d:
                            if _b.previous is not None:
                                _b.previous = None
                if type(_m) is BufferConv1d:
                    if _m.previous is not None:
                        _m.previous = None
                if type(_m) is BufferConvTranspose1d:
                    if _m.partial is not None:
                        _m.partial = None

    @torch.no_grad()
    def encode(self, x):
        '''24 kHz audio to codes
           x : [bs, 1, n_samples]   audio sampled at 24 kHz
           c : [bs, n_q, time]      every 1920 audio samples produce 1 time frame (one code per codebook)
        '''
        if self.encode_buffer is not None:
            x = torch.cat([self.encode_buffer, x], 2)
        _bs, _1, _len = x.shape
        num_frames = _len // 1920
        leftover = x[:, :, num_frames * 1920:]  # samples that do not fill a complete 1920-sample window
        if leftover.shape[2] > 0:
            self.encode_buffer = leftover
        else:
            self.encode_buffer = None
            torch.cuda.empty_cache()
        if num_frames > 0:
            c = []
            for n in range(num_frames):
                e = self.encoder(x[:, :, n * 1920:(n + 1) * 1920])
                e = self.encoder_transformer(e)
                e = self.downsample(e)
                _c = self.quantizer.encode(e)
                c.append(_c)
            c = torch.cat(c, 2)
        else:
            # num_frames == 0: fewer than 1920 samples; they stay in encode_buffer, no token can be produced yet
            c = torch.empty(_bs, len(self.quantizer._acoustic_books) + 1, 0, dtype=torch.long, device=x.device)
        return c

    @torch.no_grad()
    def decode(self, c):
        '''codes to 24 kHz audio
           c: [bs, n_q, n_tokens]
           x: [bs, 1, n_tokens * 1920]
        '''
        _hidden = []
        for i in range(c.shape[2]):
            x = self.quantizer.decode(c[:, :, i:i+1])
            x = self.upsample(x)
            x = self.decoder_transformer(x)
            x = self.decoder(x)
            _hidden.append(x)
        return torch.cat(_hidden, 2)  # [bs, 1, 24KHz]
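
# Frame arithmetic, for reference: sample_rate / frame_rate = 24000 / 12.5 = 1920, so every 1920
# input samples yield one token of 1 + len(quantizer._acoustic_books) = 22 codes, and every token
# decodes back to 1920 samples.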


class SEANetResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        kernel_sizes=[3, 1],
    ):
        super().__init__()

        block = []
        for i, kernel_size in enumerate(kernel_sizes):

            block += [
                nn.ELU(),
                BufferConv1d(
                    dim      if i == 0 else dim // 2,
                    dim // 2 if i == 0 else dim,
                    kernel_size=kernel_size,
                    bias=True,
                ),
            ]

        self.block = nn.Sequential(*block)

    def forward(self, x):
        return x + self.block(x)


class SEANetEncoder(nn.Module):
    def __init__(
        self,
        channels=1, # DOES NOT SUPPORT STEREO
        dimension=512,
        n_filters=64,
        ratios=[8, 6, 5, 4],
        kernel_size=7,
        last_kernel_size=3,
    ):
        super().__init__()
        self.ratios = list(reversed(ratios))
        del ratios
        mult = 1
        model=[
            BufferConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                bias=True
            )
        ]
        for i, ratio in enumerate(self.ratios):
            model += [SEANetResnetBlock(mult * n_filters),
                      nn.ELU(),
                      BufferConv1d(mult * n_filters,
                                      mult * n_filters * 2,
                                      kernel_size=ratio * 2,
                                      stride=ratio,
                                      bias=True)]
            mult *= 2
        # ENDFOR
        model += [nn.ELU(),
                  BufferConv1d(mult * n_filters,
                                    dimension,
                                    last_kernel_size,
                                    bias=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
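
# The encoder's strided convolutions downsample by prod(ratios) = 8 * 6 * 5 * 4 = 960; the extra
# stride-2 Voc.downsample then gives the overall factor of 1920 samples per quantized frame.
# The SEANetDecoder below mirrors this, preceded by the stride-2 Voc.upsample.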


class SEANetDecoder(nn.Module):

    def __init__(
        self,
        channels=1,
        dimension=512,
        n_filters=64,
        ratios=[8, 6, 5, 4],
        kernel_size=7,
        last_kernel_size=3):

        super().__init__()
        mult = int(2 ** len(ratios))
        model = [BufferConv1d(dimension,
                                 mult * n_filters,
                                 kernel_size,
                                 bias=True)]
        #UP
        for i, ratio in enumerate(ratios):
            model += [nn.ELU(),
                      BufferConvTranspose1d(mult * n_filters,
                                        mult * n_filters // 2,
                                        kernel_size=ratio * 2,
                                        stride=ratio,
                                        bias=True),
                      SEANetResnetBlock(mult * n_filters // 2)]
            mult //= 2
        # LAST
        model += [
            nn.ELU(),
            BufferConv1d(
                n_filters,
                channels,
                last_kernel_size,
                bias=True
            ),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class BufferConv1d(nn.Conv1d):
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.previous = None

    def forward(self, x):
        k = self.kernel_size[0]

        if self.previous is not None:
            x = torch.cat([self.previous, x], 2)

        else:  # first chunk (no stream buffer yet): causal left-padding with replicated edge values

            if k == 3:     # convs inside SEANetResnetBlock and the final encoder/decoder convs
                x = F.pad(x, (2, 0), mode='replicate')

            elif k == 4:   # Voc.downsample, the first conv after the encoder transformer; replicate avoids an onset pulse
                x = F.pad(x, (3, 0), mode='replicate')

            elif k == 7:   # first conv of the encoder / decoder
                x = F.pad(x, (6, 0), mode='replicate')

            elif k == 16:  # ratio-8 strided encoder conv; constant padding also works without a pulse
                x = F.pad(x, (2, 0), mode='replicate')

        # +1 because the kernel starts at the left edge of x and then makes (L - k) / stride jumps
        num_frames = int((x.shape[2] - self.kernel_size[0]) / self.stride[0]) + 1
        offset = num_frames * self.stride[0]
        self.previous = x[..., offset:]
        return super().forward(x)
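
# How BufferConv1d streams: the very first chunk is causally left-padded; the conv emits outputs
# only for complete kernel windows, and the unconsumed tail of the input is kept in self.previous
# and prepended to the next chunk, so chunked and offline outputs match.
# Minimal sanity check (illustrative only, not part of the model):
_conv = BufferConv1d(1, 1, kernel_size=3, bias=False)
_a, _b = torch.randn(1, 1, 10), torch.randn(1, 1, 10)
_streamed = torch.cat([_conv(_a), _conv(_b)], 2)  # two chunks through the stream buffer
_conv.previous = None                             # reset the stream buffer
_offline = _conv(torch.cat([_a, _b], 2))          # same signal in one pass
assert torch.allclose(_streamed, _offline, atol=1e-6)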


class BufferConvTranspose1d(nn.ConvTranspose1d):
    # kernel 5 has only 1 pixel for input (cloned)
    # https://distill.pub/2016/deconv-checkerboard/
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args,
                         **kwargs)
        self.partial = None
        
    def forward(self, x):
        out = super().forward(x)
        OT = out.shape[2]
        invalid_steps = self.kernel_size[0] - self.stride[0]
        if self.partial is not None:
            PT = self.partial.shape[-1]
            if self.bias is not None:
                # self.partial already contains the bias; subtract it once so it is not counted twice
                out[..., :PT] += self.partial - self.bias[:, None]
            else:
                out[..., :PT] += self.partial  # bias-free case (e.g. Voc.upsample)
        self.partial = out[..., OT - invalid_steps:]  # tail still missing the next frame's contribution
        out = out[..., :OT - invalid_steps]
        return out
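
# How BufferConvTranspose1d streams: a transposed conv on one frame produces kernel_size - stride
# trailing samples that are still incomplete (they need contributions from the next frame). They
# are held back in self.partial and overlap-added onto the next call's output.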


class CodeBook(nn.Module):
    def __init__(self, dim, codebook_size):
        super().__init__()
        self.register_buffer('_e', torch.zeros(codebook_size, dim))

    def encode(self, x):
        dist = torch.cdist(
            x.transpose(1, 2), # [bs, time, 256]
            self._e[None, :, :]  # [1, 2048, 256]
        )
        codes = dist.argmin(2)
        return codes

    def decode(self, codes):
        quantized = F.embedding(codes, self._e)
        return quantized.transpose(1, 2)  # [bs, 256, time]
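
# CodeBook.encode is a nearest-neighbour lookup: torch.cdist gives the distance from each time step
# to all 2048 codebook vectors and argmin picks the closest index; decode is an embedding lookup of
# those indices back to 256-dim vectors.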


class SplitResidualVectorQuantizer(nn.Module):

    def __init__(self,
                 n_q=None):
        super().__init__()
        self.in_proj_s  = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.in_proj_a  = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.out_proj_s = torch.nn.Conv1d(256, 512, 1, bias=False)
        self.out_proj_a = torch.nn.Conv1d(256, 512, 1, bias=False)  # shared by all _acoustic_books
        self.layers = nn.ModuleList([CodeBook(dim=256, codebook_size=2048) for _ in range(18)])
        # self._acoustic_books = range(1, 16)  # official Mimi
        # Codebook 0 is excluded here because it uses a different projection (in_proj_s).
        # The remaining RVQ codebooks are re-used (book 17 is repeated) for higher fidelity.
        self._acoustic_books = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 17, 17, 17, 17]

    def encode(self, x):
        indices = self.layers[0].encode(self.in_proj_s(x))  # integers
        all_indices = [ indices[:, None, :], ]
        x = self.in_proj_a(x)
        for _cb in self._acoustic_books:
            indices = self.layers[_cb].encode(x)
            x = x - self.layers[_cb].decode(indices)
            all_indices.append(indices[:, None, :])
        codes = torch.cat(all_indices, 1)
        return codes

    def decode(self, codes):
        _s = self.layers[0].decode(codes[:, 0, :])
        _a = torch.zeros([1, 1], device=codes.device)
        for i, _cb in enumerate(self._acoustic_books):
            _a = _a + self.layers[_cb].decode(codes[:, i+1, :])
        return self.out_proj_s(_s) + self.out_proj_a(_a)   # [bs, 512, time]
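
# Residual VQ in short: codebook 0 quantizes the "semantic" projection (in_proj_s); each acoustic
# book then quantizes whatever residual is left of the "acoustic" projection after subtracting the
# previous book's reconstruction. decode() sums the per-book embeddings and projects back to 512
# channels via out_proj_s / out_proj_a. Re-using book 17 several times just runs extra residual
# steps with the same codebook.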


class VocAttention(nn.Module):

    def __init__(self,
                 embed_dim):

        super().__init__()
        self.fused_proj = nn.Parameter(torch.zeros(embed_dim, embed_dim))

    def forward(self, x):
        '''attentionless bypass: mean over time followed by a single fused projection'''
        if x.shape[1] > 1:
            x = x.mean(1, keepdim=True)
        x = torch.matmul(x, self.fused_proj)
        return x  # [bs, 1, d]; broadcasts back over the time axis in the residual add
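
# This is what makes the vocoder "attentionless": self-attention is replaced by a mean over the
# time axis plus one learned projection, so no attention state has to be carried across frames
# while streaming.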


class VocTransformerLayer(nn.Module):

    def __init__(self, d_model=512, dim_feedforward=2048):
        super().__init__()
        self.self_attn = VocAttention(embed_dim=d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

    def forward(self, x):
        x = x + self.self_attn(self.norm1(x))
        return x + self.linear2(F.gelu(self.linear1(self.norm2(x))))


class VocTransformer(nn.Module):

    def __init__(self):

        super().__init__()
        self.layers = nn.ModuleList(VocTransformerLayer() for _ in range(8))

    def forward(self, x):
        x = x.transpose(1, 2)
        for la in self.layers:
            x = la(x)
        return x.transpose(1, 2)

device = 'cpu'  #'cuda:0'
model = Voc.from_pretrained('ivao0/voc').to(device)
x, _ = soundfile.read(hf_hub_download(repo_id='ivao0/voc', filename='true.wav'))  # 24 KHz
x = torch.from_numpy(x[None, None, :]).to(dtype=torch.float, device=device)
codes = model.encode(x) # [bs, len(_acoustic_books) + 1, T]
y = model.decode(codes) # 24 kHz audio signal
soundfile.write('reconstruct.wav', y[0, 0, :].cpu().numpy(), 24000)
model._flush()  # call before running encode()/decode() with a different batch size
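
# Streaming sketch (illustrative; the 4000-sample chunk length is arbitrary): encode() buffers
# incomplete 1920-sample windows internally, so audio can be fed in chunks of any length and the
# decoded chunks concatenate into a continuous signal. Trailing samples smaller than one window
# simply stay in the buffer.
model._flush()                          # start a fresh stream (buffers must be empty)
decoded = []
for chunk in torch.split(x, 4000, dim=2):
    c = model.encode(chunk)             # may yield 0 tokens for a short chunk
    if c.shape[2] > 0:
        decoded.append(model.decode(c))
y_stream = torch.cat(decoded, 2)
soundfile.write('reconstruct_stream.wav', y_stream[0, 0, :].cpu().numpy(), 24000)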

Model tree for ivao0/voc

Base model: kyutai/mimi (this model is a finetune of it)