import gradio as gr
from miditok import REMI
from transformers import PretrainedConfig, PreTrainedModel
from reformer_pytorch import ReformerLM, Reformer
from axial_positional_embedding import AxialPositionalEmbedding

import math
import os
import subprocess
import pytube
import binascii

import torch
from torch import nn
import torchaudio

yt_dir = "./yt_dir"
midi_dir = "./midi_dir"
os.makedirs(yt_dir, exist_ok=True)
os.makedirs(midi_dir, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

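
# Hugging Face-style configuration for the Reformer encoder-decoder:
# model width, depth, attention heads, maximum sequence lengths, axial
# position-embedding shapes, and the special token ids.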
class ReformerEncoderDecoderConfig(PretrainedConfig):
    def __init__(self,
                 vocab_size=50265,
                 d_model=128,
                 num_heads=8,
                 encoder_layers=6,
                 decoder_layers=6,
                 encoder_max_seq_len=6144,
                 decoder_max_seq_len=4096,
                 encoder_axial_position_shape=(96, 64),
                 decoder_axial_position_shape=(64, 64),
                 pad_token_id=0,
                 bos_token_id=1,
                 eos_token_id=2,
                 **kwargs):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.encoder_max_seq_len = encoder_max_seq_len
        self.decoder_max_seq_len = decoder_max_seq_len
        self.encoder_axial_position_shape = encoder_axial_position_shape
        self.decoder_axial_position_shape = decoder_axial_position_shape
        super().__init__(pad_token_id=pad_token_id,
                         bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

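
# Encoder-decoder model: a Reformer encoder over continuous audio embeddings
# and a causal ReformerLM decoder over MIDI tokens. The encoder output is
# passed to the decoder through its `keys` argument.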
class ReformerEncoderDecoder(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id

        self.encoder = Reformer(
            dim=config.d_model,
            depth=config.encoder_layers,
            heads=config.num_heads,
        )

        self.decoder = ReformerLM(
            dim=config.d_model,
            depth=config.decoder_layers,
            heads=config.num_heads,
            max_seq_len=config.decoder_max_seq_len,
            num_tokens=config.vocab_size,
            axial_position_emb=True,
            axial_position_shape=config.decoder_axial_position_shape,
            causal=True
        )

        self.position_embedding = AxialPositionalEmbedding(
            config.d_model,
            axial_shape=config.encoder_axial_position_shape
        )

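    # Pad `tensor` along `dim` so that `seq_len` becomes a multiple of
    # `multiple`; used to satisfy Reformer's LSH bucket-size requirement.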
    def pad_to_multiple(self, tensor, seq_len, multiple, dim=-1):
        m = seq_len / multiple
        if m.is_integer():
            return tensor

        remainder = math.ceil(m) * multiple - seq_len
        pad_offset = (0,) * (-1 - dim) * 2
        return nn.functional.pad(tensor, (*pad_offset, 0, remainder), value=self.pad_token_id)

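    # When the effective sequence length (tokens + persistent memory +
    # cross-attention keys) exceeds `full_attn_thres`, pad the ids and masks
    # up to a multiple of twice the LSH bucket size, mirroring
    # reformer_pytorch's Autopadder.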
    def auto_padding(self, input_ids, pad_dim, bucket_size, num_mem_kv, full_attn_thres, keys=None, input_mask=None, input_attn_mask=None):
        device = input_ids.device

        batch_size, t = input_ids.shape[:2]

        keys_len = 0 if keys is None else keys.shape[1]
        seq_len = t + num_mem_kv + keys_len

        if seq_len > full_attn_thres:
            if input_mask is None:
                input_mask = torch.full((batch_size, t), True, dtype=torch.bool, device=device)

            input_ids = self.pad_to_multiple(input_ids, seq_len, bucket_size * 2, dim=pad_dim)

            if input_mask is not None:
                input_mask = nn.functional.pad(input_mask, (0, input_ids.shape[1] - input_mask.shape[1]), value=False)

            if input_attn_mask is not None:
                offset = input_ids.shape[1] - input_attn_mask.shape[1]
                input_attn_mask = nn.functional.pad(input_attn_mask, (0, offset, 0, offset), value=False)

        return input_ids, input_mask, input_attn_mask

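    # Prepare decoder inputs during training: shift the label sequence one
    # position to the right, BART-style, with eos as the decoder start token.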
    def shift_tokens_right(self, input_ids):
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = self.eos_token_id

        if self.pad_token_id is None:
            raise ValueError("config.pad_token_id has to be defined.")

        # Replace any -100 label-padding values with the real pad token id.
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

        return shifted_input_ids

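    # Teacher-forcing forward pass: add axial position embeddings to the audio
    # embeddings, encode them, decode with the encoder output as keys, and
    # compute a cross-entropy loss when labels are provided.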
    def forward(self, inputs_embeds, attention_mask=None, decoder_input=None, labels=None):
        if decoder_input is None:
            decoder_input = self.shift_tokens_right(labels)

        encoder_input = inputs_embeds + self.position_embedding(inputs_embeds)

        # The Reformer encoder expects a boolean mask; allow it to be omitted.
        input_mask = attention_mask.bool() if attention_mask is not None else None
        encoder_output = self.encoder(encoder_input, input_mask=input_mask)

        decoder_output = self.decoder(decoder_input, keys=encoder_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(decoder_output.view(-1, self.config.vocab_size), labels.view(-1))
            return {"loss": masked_lm_loss, "logits": decoder_output}

        return {"logits": decoder_output}

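    # Autoregressive generation: encode the audio embeddings once, then sample
    # MIDI tokens one at a time with temperature, top-k, and nucleus (top-p)
    # filtering until eos or `max_length` is reached.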
    @torch.no_grad()
    def generate(self, inputs_embeds, attention_mask=None, max_length=4096, temperature=1.0, top_k=50, top_p=1.0):
        is_training = self.training
        device = inputs_embeds.device

        pad_dim = -1
        bucket_size = self.decoder.reformer.bucket_size
        num_mem_kv = self.decoder.reformer.num_mem_kv
        full_attn_thres = self.decoder.reformer.full_attn_thres

        self.eval()

        encoder_input = inputs_embeds + self.position_embedding(inputs_embeds)
        input_mask = attention_mask.bool() if attention_mask is not None else None
        encoder_keys = self.encoder(encoder_input, input_mask=input_mask)

        generated = torch.tensor([[self.bos_token_id]], device=device)
        decoder_mask = torch.ones_like(generated, dtype=torch.bool)

        for _ in range(max_length):
            # Keep only the most recent tokens the decoder can attend to.
            generated = generated[:, -self.config.decoder_max_seq_len:]
            decoder_mask = decoder_mask[:, -self.config.decoder_max_seq_len:]
            seq_len = generated.shape[1]

            # Pad ids/mask for LSH bucketing, but keep the unpadded sequence
            # around so new tokens are appended after the real tokens.
            padded_ids, padded_mask, _ = self.auto_padding(generated,
                                                           pad_dim,
                                                           bucket_size,
                                                           num_mem_kv,
                                                           full_attn_thres,
                                                           keys=encoder_keys,
                                                           input_mask=decoder_mask)

            # Read the logits at the last real (unpadded) position.
            logits = self.decoder(padded_ids, input_mask=padded_mask, keys=encoder_keys)[:, seq_len - 1, :] / temperature

            if top_k > 0:
                top_k_values, top_k_indices = torch.topk(logits, top_k)
                filtered_logits = torch.full_like(logits, -float('Inf'))
                logits = filtered_logits.scatter(1, top_k_indices, top_k_values)

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

                # Keep the smallest set of tokens whose cumulative probability
                # exceeds top_p (always keep the most likely token).
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0

                sorted_logits[sorted_indices_to_remove] = -float('Inf')
                logits = torch.full_like(logits, -float('Inf')).scatter(1, sorted_indices, sorted_logits)

            probs = nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=-1)
            decoder_mask = torch.cat([decoder_mask, torch.ones_like(next_token, dtype=torch.bool)], dim=-1)

            if next_token.item() == self.eos_token_id:
                break

        self.train(is_training)
        return generated

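
# Instantiate the model with default hyperparameters, load the trained
# weights, and load the trained REMI tokenizer configuration.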
model = ReformerEncoderDecoder(ReformerEncoderDecoderConfig()).to(device)
model.load_state_dict(torch.load("model.pth", map_location=device))
tokenizer = REMI(params="tokenizer.json")

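
# Prepends an arranger-id embedding vector to the mel-spectrogram frames
# along the time axis, so the encoder input carries the selected arranger.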
class ArrangerEmbedding(nn.Module):
    def __init__(self, arranger_ids=256, hidden_size=128):
        super().__init__()
        self.embeddings = nn.Embedding(arranger_ids, hidden_size)

    def forward(self, arranger_id, mel_db):
        return torch.cat([self.embeddings(arranger_id), mel_db], dim=-2)

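
# Alternative constructor that builds the model from a checkpoint path; the
# Gradio app below uses the module-level `model` directly.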
def initialize_model(model_path):
    RedConfig = ReformerEncoderDecoderConfig()

    model = ReformerEncoderDecoder(RedConfig).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))

    return model

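
# Convert the downloaded audio into the encoder input: resample to 22.05 kHz,
# mix down to mono, compute a log-mel spectrogram, transpose it to
# (batch, time, mels), crop/pad it to the encoder length, and prepend the
# arranger embedding. Also builds the matching attention mask.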
def load_input(song_path, arranger_id):
    waveform, sr = torchaudio.load(song_path)
    waveform = torchaudio.transforms.Resample(sr, 22050)(waveform)
    waveform = torch.mean(waveform, dim=0, keepdim=True)

    # The waveform has been resampled, so the mel transform must use 22050 Hz.
    mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=22050, n_fft=4096, hop_length=1024, n_mels=128)
    mel = mel_transform(waveform)
    mel_db = torchaudio.transforms.AmplitudeToDB()(mel)

    # (batch, n_mels, time) -> (batch, time, n_mels); transpose rather than
    # reshape so frames are not scrambled.
    mel_db = mel_db.transpose(1, 2)

    # Leave one position for the arranger embedding (encoder length is 6144).
    if mel_db.shape[1] > 6143:
        mel_db = mel_db[:, :6143]

    num_pad = 6144 - mel_db.shape[1] - 1
    mel_padded = torch.cat([mel_db, torch.zeros((1, num_pad, mel_db.shape[2]))], dim=1)

    embedding = ArrangerEmbedding()
    input_embed = embedding(torch.tensor([[int(arranger_id)]]), mel_padded)

    # Mask covers the arranger token plus the real mel frames; padding is masked out.
    attention_mask = torch.cat([torch.ones((mel_db.shape[0], mel_db.shape[1] + 1), dtype=torch.int32),
                                torch.zeros((mel_db.shape[0], num_pad), dtype=torch.int32)], dim=1)

    return input_embed, attention_mask

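
# Download the audio track of a YouTube video with pytube and convert it to
# MP3 with ffmpeg.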
def download_piano(youtube_link):
    yt = pytube.YouTube(youtube_link)
    # pytube expects the output directory and file name as separate arguments
    # and returns the full path of the downloaded file.
    download_path = yt.streams.filter(only_audio=True).first().download(output_path=yt_dir, filename=yt.title + ".mp4")

    mp3_path = str(download_path).replace(".mp4", ".mp3")
    result = subprocess.run([
        "ffmpeg",
        "-y",
        "-i", download_path,
        mp3_path
    ])

    if result.returncode != 0:
        raise Exception("Failed to convert to mp3")

    return mp3_path

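
# End-to-end Gradio handler: download the audio, build the encoder input,
# generate MIDI tokens, and decode them to a MIDI file.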
def inference(yt_link, arranger_id):
    song_path = download_piano(yt_link)
    input_embed, attention_mask = load_input(song_path, arranger_id)
    generated = model.generate(input_embed.to(device), attention_mask.to(device))
    return post_process(generated)

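
# Decode the generated token ids back into a MIDI object with the REMI
# tokenizer and write it to a uniquely named file in midi_dir.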
def post_process(generated):
    # `generate` already returns token ids, so decode them directly.
    midi = tokenizer.decode(generated.cpu())

    output_midi_path = os.path.join(midi_dir, f"{binascii.hexlify(os.urandom(8)).decode()}.mid")
    midi.dump_midi(output_midi_path)

    return output_midi_path

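
# Gradio UI: takes a YouTube link and an arranger id, returns the generated
# MIDI file for download.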
app = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="YouTube Link"),
        gr.Dropdown([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], label="Arranger ID", value=1)
    ],
    outputs=gr.File(label="MIDI File")
)

app.launch()