File size: 696 Bytes
1df74c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from einops import rearrange
from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ


def quantize(
    quantizer: GroupedResidualFSQ,
    audio_latents: torch.Tensor,  # (batch_size, audio_len, audio_dim=1024)
) -> tuple[torch.Tensor, torch.Tensor]:
    # feat shape (batch_size, audio_len, audio_dim)
    # ind shape (GFSQ.G, batch_size, audio_len, GFSQ.R)
    # num_vq=GFSQ.G*GFSQ.R
    feat, ind = quantizer(audio_latents)
    audio_quantized_latents = feat  # (batch_size, audio_len, audio_dim)
    audio_input_ids = rearrange(  # (batch_size, audio_len, num_vq)
        ind,
        "g b t r ->b t (g r)",
    )
    return audio_quantized_latents, audio_input_ids