jukwi-vqvae / app.py
vovahimself's picture
first try
631e673
raw
history blame
3.21 kB
# A simple gradio app that converts music tokens to and from audio using JukeboxVQVAE as the model and Gradio as the UI
import os
import sys
import tempfile

import gradio as gr
import torch as t
from transformers import JukeboxVQVAE
model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']
# Choose where downloaded model weights are cached.
if 'google.colab' in sys.modules:
    # Running inside Colab: cache on the mounted Google Drive.
    cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:"string"}
    # Connect to your Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
else:
    # Expand '~' here: nothing guarantees the model loader expands it, and an
    # unexpanded '~' would be treated as a directory literally named '~'.
    cache_path = os.path.expanduser('~/.cache/')
class Convert:
    """Converters between the representations handled by the UI.

    Each nested class groups conversions *from* one representation:
    ``TokenList`` (an in-memory list of token tensors), ``TokensFile``
    (a ``.jt`` file written with ``torch.save``) and ``Audio``.
    The encoding/decoding converters use the module-level ``model``
    created by ``init()``.
    """

    class TokenList:
        @staticmethod
        def to_tokens_file(tokens_list):
            """Validate ``tokens_list``, save it to a new temporary ``.jt`` file and return the file path."""
            tokens_list = validate_tokens_list(tokens_list)
            # mkstemp atomically creates a uniquely named file. The previous
            # f"tmp/{t.randint(0, 1000000)}.jt" could not work: torch.randint
            # requires a `size` argument, and the relative `tmp/` directory is
            # not guaranteed to exist.
            fd, filename = tempfile.mkstemp(suffix=".jt")
            with os.fdopen(fd, "wb") as f:
                t.save(tokens_list, f)
            return filename

        @staticmethod
        def to_audio(tokens_list):
            """Decode the level-2 entry of a validated ``tokens_list`` to audio using ``model``."""
            # Only the last entry (index 2) is decoded; start_level=2 tells the
            # decoder which level that single tensor belongs to.
            # TODO: Implement converting other levels besides 2
            # NOTE(review): this returns a torch tensor; gr.Audio outputs
            # normally expect a (sample_rate, array) tuple or a file path — confirm.
            return model.decode(validate_tokens_list(tokens_list)[2:], start_level=2).squeeze(-1)

    class TokensFile:
        @staticmethod
        def to_tokens_list(file):
            """Load a tokens file saved with ``torch.save`` and validate its contents."""
            # SECURITY NOTE(review): t.load unpickles its input, and this file
            # comes from a user upload; consider t.load(file, weights_only=True)
            # if the installed torch version supports it.
            return validate_tokens_list(t.load(file))

        @staticmethod
        def to_audio(file):
            """Load a tokens file and decode it to audio."""
            return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))

    class Audio:
        @staticmethod
        def to_tokens_list(audio):
            """Encode ``audio`` into a validated list of token tensors, one per level."""
            # Encode every level rather than start_level=2: encode returns one
            # tensor per level from start_level upward, so start_level=2 gave a
            # single tensor, which validate_tokens_list (3 tensors) rejected.
            # NOTE(review): assumes `audio` is a torch tensor without a batch
            # dimension (one is added by unsqueeze); gr.Audio inputs arrive as a
            # (sample_rate, numpy array) tuple by default — confirm the wiring.
            return validate_tokens_list(model.encode(audio.unsqueeze(0)))

        @staticmethod
        def to_tokens_file(audio):
            """Encode ``audio`` and save the tokens to a temporary ``.jt`` file, returning its path."""
            return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))
def init():
    """Load the JukeboxVQVAE for ``model_id`` into the module-level ``model``.

    The encode/decode converters in ``Convert`` read ``model``, so this must
    run before they are used. Weights load in float16 with
    ``device_map="auto"`` and are cached under ``<cache_path>/jukebox/models``.
    """
    global model
    # os.path.join avoids the double slash the old f-string produced when
    # cache_path ends in '/'. expanduser turns a '~'-prefixed path into a real
    # home path, in case the loader does not expand it itself.
    cache_dir = os.path.join(os.path.expanduser(cache_path), "jukebox", "models")
    model = JukeboxVQVAE.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=t.float16,
        cache_dir=cache_dir,
    )
def validate_tokens_list(tokens_list):
    """Check that ``tokens_list`` has the layout of a music tokens file and return it unchanged.

    The expected layout is a sequence of exactly 3 tensors that:
    - all have the same number of dimensions (at least 2),
    - agree on the size of dimension 0,
    - have a dimension-1 size that decreases (or stays the same) from index 0 to 2.

    Raises:
        ValueError: if any condition above is violated. These are explicit
            raises rather than ``assert`` statements, which Python strips when
            run with ``-O``.
    """
    try:
        count = len(tokens_list)
        shapes = [tuple(tensor.shape) for tensor in tokens_list]
    except (TypeError, AttributeError):
        # Not a sized sequence, or an element has no `.shape` (not a tensor).
        raise ValueError("Invalid file format: expecting a list of 3 tensors") from None
    if count != 3:
        raise ValueError("Invalid file format: expecting a list of 3 tensors")
    if not len(shapes[0]) == len(shapes[1]) == len(shapes[2]):
        raise ValueError("Invalid file format: each tensor in the list should have the same number of dimensions")
    # Dimensions 0 and 1 are compared below; lower-rank tensors would otherwise
    # fail with an IndexError instead of a readable message.
    if len(shapes[0]) < 2:
        raise ValueError("Invalid file format: each tensor in the list should have at least 2 dimensions")
    if not shapes[0][0] == shapes[1][0] == shapes[2][0]:
        raise ValueError("Invalid file format: the shape along dimension 0 should be the same for all tensors in the list")
    if not shapes[0][1] >= shapes[1][1] >= shapes[2][1]:
        raise ValueError("Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2")
    return tokens_list
with gr.Blocks() as ui:
    # File input to upload or download the music tokens file
    tokens = gr.File(label='music_tokens_file')
    # Audio output to play or upload the generated audio
    audio = gr.Audio(label='audio')
    # Buttons to convert from music tokens to audio (primary) and vice versa (secondary).
    # gr.Button takes its caption as the first positional `value` and its emphasis
    # via `variant`; the former `label=`/`primary=` arguments set neither.
    gr.Button("Convert tokens to audio", variant="primary").click(
        Convert.TokensFile.to_audio, tokens, audio
    )
    gr.Button("Convert audio to tokens", variant="secondary").click(
        Convert.Audio.to_tokens_file, audio, tokens
    )
def _main():
    """Load the model, then start serving the Gradio UI."""
    init()
    ui.launch()


if __name__ == '__main__':
    _main()