import random
from typing import Dict, Iterator, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy

from utils.util import make_pad_mask
from utils.topk_sampling import topk_sampling
from modules.general import PromptedFeatures, Transpose
from modules.encoder import TokenEmbedding
from modules.transformer import SinePositionalEmbedding
from modules.norms import AdaptiveLayerNorm, LayerNorm
from modules.transformer.transformer import TransformerEncoder, TransformerEncoderLayer


class VALLE(nn.Module):
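    """VALL-E: a neural codec language model for TTS.

    An autoregressive (AR) Transformer decoder predicts the tokens of the first
    codec quantizer, and a non-autoregressive (NAR) decoder then predicts the
    remaining ``num_quantizers - 1`` quantizers, one stage at a time.
    """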

    def __init__(
        self,
        cfg,
        decoder_cls=TransformerEncoder,
        decoder_layer_cls=TransformerEncoderLayer,
    ):
        super().__init__()
        decoder_dim = cfg.decoder_dim
        nhead = cfg.nhead
        nar_scale_factor = cfg.nar_scale_factor
        num_quantizers = cfg.num_quantizers
        num_decoder_layers = cfg.num_decoder_layers
        nar_decoder_dim = int(decoder_dim * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(decoder_dim, cfg.text_token_num)
        self.nar_text_embedding = TokenEmbedding(nar_decoder_dim, cfg.text_token_num)

        # The AR audio vocabulary reserves one extra id for EOS and, when
        # enabled, one more for the prepended BOS token.
        self.ar_audio_prepend_bos = cfg.prepend_bos
        self.ar_audio_embedding = TokenEmbedding(
            decoder_dim, cfg.audio_token_num + 1 + int(cfg.prepend_bos)
        )
        self.audio_token_num = cfg.audio_token_num

        if cfg.add_prenet:
            self.ar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(decoder_dim, decoder_dim),
            )

            self.ar_audio_prenet = nn.Sequential(
                nn.Linear(decoder_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, decoder_dim),
            )
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            decoder_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            decoder_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                decoder_dim,
                nhead,
                dim_feedforward=decoder_dim * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=cfg.norm_first,
            ),
            num_layers=num_decoder_layers,
            norm=LayerNorm(decoder_dim) if cfg.norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(
            decoder_dim, cfg.audio_token_num + 1, bias=False
        )

        self.ar_accuracy_metric = MulticlassAccuracy(
            cfg.audio_token_num + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=cfg.audio_token_num,
        )

        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = cfg.prefix_mode
        self.num_quantizers = num_quantizers

        assert num_quantizers >= 1
        if num_quantizers > 1:
            # The first NAR embedding also covers the EOS id emitted by the AR
            # stage; the later stages embed only real codec tokens.
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_decoder_dim, cfg.audio_token_num + 1)]
                + [
                    TokenEmbedding(nar_decoder_dim, cfg.audio_token_num)
                    for i in range(num_quantizers - 1)
                ]
            )

            if cfg.add_prenet:
                self.nar_text_prenet = nn.Sequential(
                    Transpose(),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    Transpose(),
                    nn.Linear(nar_decoder_dim, nar_decoder_dim),
                )
                self.nar_audio_prenet = nn.Sequential(
                    nn.Linear(nar_decoder_dim, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, nar_decoder_dim),
                )
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_decoder_dim,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_decoder_dim,
                dropout=0.1,
                scale=False,
                alpha=False,
            )

            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_decoder_dim,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_decoder_dim * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=cfg.norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_decoder_layers * nar_scale_factor),
                norm=(
                    AdaptiveLayerNorm(
                        nar_decoder_dim, norm=nn.LayerNorm(nar_decoder_dim)
                    )
                    if cfg.norm_first
                    else None
                ),
            )
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_decoder_dim, cfg.audio_token_num, bias=False)
                    for i in range(num_quantizers - 1)
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_decoder_dim, 1) for i in range(num_quantizers - 1)]
            )

            if cfg.share_embedding:
                # Tie the weight of each output projection to the acoustic
                # embedding of the next NAR stage; the last projection has no
                # following stage and stays untied.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
                        j + 2
                    ].weight

            self.nar_accuracy_metric = MulticlassAccuracy(
                cfg.audio_token_num + 1,
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=cfg.audio_token_num,
            )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor, PromptedFeatures],
        y_lens: Union[torch.Tensor, PromptedFeatures],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, S).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, T, num_quantizers).
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `y`
            before padding.
          train_stage:
            0: AR & NAR modules, 1: AR modules, 2: NAR modules
        Returns:
          A tuple of the total cross-entropy loss and a dict with the Top-10
          accuracy metrics.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        y_prompts_codes = None
        if isinstance(y, PromptedFeatures):
            y_prompts_codes, y = y.data
            prompts_len, y_lens = y_lens.data
            assert prompts_len.min() == prompts_len.max()
            assert self.prefix_mode == 4
            y_prompts_codes = y_prompts_codes.type(torch.int64)

        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        x_mask = make_pad_mask(x_lens).to(x.device)
        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_int = y_mask.type(torch.int64)

        text = x
        # Zero out padded positions of the codes before EOS relabelling.
        codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))

        y, targets = self.pad_y_eos(
            codes[..., 0], y_mask_int, eos_id=self.audio_token_num
        )
        self.y_mask_int = y_mask_int

        metrics = {}
        total_loss = 0.0

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
        if self.ar_audio_prepend_bos:
            ar_xy_padding_mask = torch.concat(
                [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
            )
        else:
            ar_xy_padding_mask = xy_padding_mask
        self.xy_padding_mask = xy_padding_mask
        self.ar_xy_padding_mask = ar_xy_padding_mask

        if train_stage in [0, 1]:
            ar_loss, ar_metrics = self._forward_ar_decoder(
                text, x_lens.max(), y, y_lens, targets, x_mask, y_mask, reduction
            )
            total_loss += ar_loss
            metrics["AR_Top10Acc"] = ar_metrics

        if self.ar_audio_prepend_bos:
            y = y[:, 1:]

        if self.num_quantizers > 1 and train_stage in [0, 2]:
            nar_loss, nar_metrics = self._forward_nar_decoder(
                text,
                x_lens,
                y,
                y_lens,
                codes,
                y_prompts_codes,
                x_mask,
                y_mask,
                reduction,
            )
            total_loss += nar_loss
            metrics["NAR_Top10Acc"] = nar_metrics

        # When both modules are trained, report the average of the two losses.
        if train_stage == 0:
            total_loss = total_loss / 2.0

        return total_loss, metrics

    def _forward_ar_decoder(
        self, x, x_len, y, y_lens, targets, x_mask, y_mask, reduction
    ):
        x = self.ar_text_embedding(x)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        y_len = y_lens.max() + int(self.ar_audio_prepend_bos)

        # Text positions attend to the whole text but never to audio; audio
        # positions attend to the whole text and causally to earlier audio.
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # Merge the causal mask with the key-padding mask, one copy per head.
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            self.ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_heads, -1, -1)
            .reshape(bsz * self.num_heads, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)

        xy_pos = torch.concat([x, y_pos], dim=1)

        xy_dec, _ = self.ar_decoder(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        ar_loss = F.cross_entropy(logits, targets, reduction=reduction)

        # Weight the batch accuracy by the total number of target frames.
        ar_metrics = self.ar_accuracy_metric(
            logits.detach(), targets
        ).item() * y_lens.sum().type(torch.float32)

        return ar_loss, ar_metrics

    def _forward_nar_decoder(
        self, x, x_lens, y, y_lens, codes, y_prompts_codes, x_mask, y_mask, reduction
    ):
        # Uniformly sample one NAR stage (quantizer index) to train this step.
        num_nar_layers = self.num_quantizers - 1
        nar_stage = self.rng.choices(
            [_k for _k in range(1, self.num_quantizers)],
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        x = self.nar_text_embedding(x)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb, prefix_len = self._prepare_prompts(
            y, y_lens, codes, nar_stage, y_prompts_codes
        )

        y_len = y_lens.max()
        # Padded positions are relabelled with the ignore_index of the loss.
        targets = codes[..., nar_stage] + self.audio_token_num * self.y_mask_int
        xy_padding_mask = self.xy_padding_mask
        if self.prefix_mode in [2, 4]:
            # The prompt prepended to y_emb lengthens the sequence, so the
            # key-padding mask has to be extended accordingly.
            xy_padding_mask = torch.concat(
                [
                    x_mask,
                    F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
                ],
                dim=1,
            )
        elif self.prefix_mode == 1:
            targets = targets[:, prefix_len:]

        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=xy_padding_mask,
        )
        xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
        if self.prefix_mode == 4:
            prefix_len = 0  # The prompt positions are not part of the targets.
        logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)

        # Rescale the loss so that excluding the prefix positions does not
        # shrink its magnitude relative to the full sequence length.
        total_length = y_lens.sum().type(torch.float32)
        nar_loss = F.cross_entropy(
            logits,
            targets,
            ignore_index=self.audio_token_num,
            reduction=reduction,
        ) * (total_length / (total_length - prefix_len * x.shape[0]))
        nar_metrics = (
            self.nar_accuracy_metric(
                F.pad(
                    logits.detach(),
                    (0, 0, 0, 1, 0, 0),
                    value=logits.min().cpu().item(),
                ),
                targets,
            ).item()
            * total_length
        )
        return nar_loss, nar_metrics

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, num_quantizers).
          enroll_x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in the
            enrollment (prompt) text before padding.
          top_k: (`optional`) int
            The number of highest probability tokens to keep for top-k-filtering.
            Non-positive values disable the filtering. Default to -100.
          temperature: (`optional`) float
            The value used to modulate the next token probabilities. Must be
            strictly positive. Default to 1.0.
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)

        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR decoding starts from the first-quantizer tokens of the prompt.
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=self.audio_token_num + 1)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        while True:
            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = torch.concat([x, y_pos], dim=1)

            y_len = y.shape[1]
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                y.device
            )

            xy_dec, _ = self.ar_decoder(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits, top_k=top_k, top_p=1.0, temperature=temperature
            )

            # Stop on EOS, or once the continuation grows past 16x the text
            # length (a runaway-generation guard).
            if (
                torch.argmax(logits, dim=-1)[0] == self.audio_token_num
                or samples[0, 0] == self.audio_token_num
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise RuntimeError(
                        "A well-trained model should not reach here."
                    )

                break

            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        # NAR stages: start from the first-quantizer embedding of the AR output.
        y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])

        if self.prefix_mode in [2, 4]:
            # Drop the enrolled (prompt) text, keeping only its first token
            # and the continuation text.
            enrolled_len = enroll_x_lens.max().item()
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1 :],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                # Accumulate the embeddings of both the prompt codes and the
                # newly predicted codes for the next stage.
                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            # Embed every prompt quantizer up front, then accumulate only the
            # predicted codes stage by stage.
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)

    def continual(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)
        assert self.num_quantizers == 8

        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()

        # Use at most half of the utterance as the acoustic prompt, capped at
        # 3 * 75 frames (3 seconds at a 75 Hz codec frame rate).
        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

        prompts = y[:, :prefix_len]

        # The first quantizer of the continuation is copied from the input.
        codes = [y[:, prefix_len:, 0]]

        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb = self.nar_audio_embeddings[0](y[..., 0])

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:  # i.e., num_quantizers - 2
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, 8):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:  # i.e., num_quantizers - 2
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == 8
        return torch.stack(codes, dim=-1)

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
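        """Build (input, target) token pairs for the AR stage.

        Padded positions of `y` and one appended final column are set to
        `eos_id`, so every sequence ends with EOS; inputs and targets are then
        offset by one position (with an optional BOS prepended to the inputs).
        """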
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=self.audio_token_num + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]

    def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
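        """Build the NAR input embeddings and the acoustic prompt prefix.

        The behavior depends on `self.prefix_mode`; see the comments on the
        branches below. Returns the summed embeddings `y_emb` together with
        the number of prompt frames `prefix_len`.
        """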
        if self.prefix_mode == 0:
            # No prompt prefix: condition only on the embeddings of the stages
            # below the one being predicted.
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif self.prefix_mode == 1:
            # Prompt: the first 25%-50% of each utterance (at most 225 frames),
            # embedded with all quantizer stages.
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)

            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
            y_emb = torch.concat([y_prompts, y_emb], dim=1)
        elif self.prefix_mode in [2, 4]:
            if self.prefix_mode == 2:
                # Prompt: a random segment cut from each utterance; the cut
                # positions of the current stage are masked out of the targets.
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    codes[b, start : start + prefix_len, nar_stage] = (
                        self.audio_token_num
                    )
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                # Prefix mode 4: the prompt codes are provided by the caller.
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], dim=1)
        else:
            raise ValueError(f"Unsupported prefix_mode: {self.prefix_mode}")

        return y_emb, prefix_len
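

# A minimal smoke-test sketch (an assumption, not part of the training recipe):
# the `cfg` fields below mirror the attributes read in `VALLE.__init__`, the
# sizes are illustrative rather than the project's defaults, and it must be run
# from the repo root so the `modules`/`utils` packages resolve.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        decoder_dim=128,
        nhead=4,
        nar_scale_factor=1.0,
        num_quantizers=8,
        num_decoder_layers=2,
        text_token_num=512,
        audio_token_num=1024,
        prepend_bos=True,
        add_prenet=False,
        norm_first=True,
        prefix_mode=0,
        share_embedding=True,
    )
    model = VALLE(cfg)

    # A batch of two padded text/code sequences with their true lengths.
    x = torch.randint(0, cfg.text_token_num, (2, 20))
    x_lens = torch.tensor([20, 15])
    y = torch.randint(0, cfg.audio_token_num, (2, 100, cfg.num_quantizers))
    y_lens = torch.tensor([100, 80])

    loss, metrics = model(x, x_lens, y, y_lens)
    print(loss, metrics)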