import torch


class AttentionMLP(torch.nn.Module):
    """A two-layer MLP that computes attention weights, normalized with a
    softmax over the third dimension (e.g., the codebook dimension).

    Arguments
    ---------
    input_dim : int
        The size of the input features.
    hidden_dim : int
        The size of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super(AttentionMLP, self).__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, x):
        """Computes attention weights for the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input of shape (Batch x Time x num_codebooks x input_dim).

        Returns
        -------
        att_w : torch.Tensor
            Attention weights of shape (Batch x Time x num_codebooks x 1),
            normalized over the codebook dimension.
        """
        x = self.layers(x)
        att_w = torch.nn.functional.softmax(x, dim=2)
        return att_w


class Discrete_EmbeddingLayer(torch.nn.Module):
    """This class handles embedding layers for discrete tokens.

    Arguments
    ---------
    num_codebooks : int
        Number of codebooks of the tokenizer.
    vocab_size : int
        Size of the dictionary of embeddings.
    emb_dim : int
        The size of each embedding vector.
    pad_index : int (default: 0)
        If specified, the entries at pad_index do not contribute to the gradient.
    init : boolean (default: False)
        If set to True, initialize the embedding with the tokenizer embedding,
        otherwise initialize it randomly.
    freeze : boolean (default: False)
        If True, the embedding is frozen. If False, the model will be trained
        alongside the rest of the pipeline.

    Example
    -------
    >>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec
    >>> model_hub = "facebook/encodec_24khz"
    >>> save_path = "savedir"
    >>> model = Encodec(model_hub, save_path)
    >>> audio = torch.randn(4, 1000)
    >>> length = torch.tensor([1.0, .5, .75, 1.0])
    >>> tokens, emb = model.encode(audio, length)
    >>> print(tokens.shape)
    torch.Size([4, 4, 2])
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
    >>> in_emb = emb(tokens)
    >>> print(in_emb.shape)
    torch.Size([4, 4, 2, 1024])
    """

    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
    ):
        super(Discrete_EmbeddingLayer, self).__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init

    def init_embedding(self, weights):
        """Initializes the embedding table from pre-trained weights.

        Arguments
        ---------
        weights : torch.Tensor
            Pre-trained weights of shape (num_codebooks * vocab_size, emb_dim).
        """
        with torch.no_grad():
            self.embedding.weight = torch.nn.Parameter(weights)

    def forward(self, in_tokens):
        """Computes the embeddings for discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            A (Batch x Time x num_codebooks) tensor of discrete token IDs.

        Returns
        -------
        in_embs : torch.Tensor
            The token embeddings, with shape
            (Batch x Time x num_codebooks x emb_dim).
        """
        with torch.set_grad_enabled(not self.freeze):
            # Offset the token IDs so that each codebook maps to its own
            # range of the shared embedding table (codebook i is shifted
            # by i * vocab_size).
            in_tokens = in_tokens + torch.arange(
                0,
                self.num_codebooks * self.vocab_size,
                self.vocab_size,
                device=in_tokens.device,
            )
            # Forward pass through the shared embedding table.
            in_embs = self.embedding(in_tokens)
            return in_embs
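

# --- Usage sketch (illustration only, not part of the library) ---
# A minimal, hedged example of how the two modules above could be combined:
# random token IDs stand in for the output of a real tokenizer such as Encodec,
# and the attention-weighted pooling over codebooks is an assumed usage pattern,
# not something defined in this file. All shapes and hyperparameters below are
# arbitrary.
if __name__ == "__main__":
    B, T, num_codebooks, vocab_size, emb_dim = 4, 10, 2, 1024, 64

    # Fake (Batch x Time x num_codebooks) token IDs in place of tokenizer output.
    tokens = torch.randint(0, vocab_size, (B, T, num_codebooks))

    emb_layer = Discrete_EmbeddingLayer(num_codebooks, vocab_size, emb_dim)
    att = AttentionMLP(emb_dim, hidden_dim=32)

    # Per-codebook embeddings: (B, T, num_codebooks, emb_dim).
    in_embs = emb_layer(tokens)
    # Attention weights: (B, T, num_codebooks, 1), softmax over the codebook dim.
    att_w = att(in_embs)
    # Weighted sum of the per-codebook embeddings -> (B, T, emb_dim).
    pooled = torch.matmul(att_w.transpose(2, -1), in_embs).squeeze(-2)

    print(in_embs.shape, att_w.shape, pooled.shape)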