Spaces:
Runtime error
Runtime error
from transformers import MusicgenForConditionalGeneration, AutoProcessor, set_seed | |
import torch | |
import numpy as np | |
import gradio as gr | |
# Load the small MusicGen checkpoint and its processor once at import time so
# every request reuses the same weights.
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")

# Previously hard-coded to "cuda:0", which makes `model.to(device)` raise on
# hosts without a CUDA device. Fall back to CPU so the app can still start.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

# Sample rate reported by the audio codec; returned to Gradio alongside the
# generated samples.
sampling_rate = model.audio_encoder.config.sampling_rate
# Codec frame rate; multiplied by a requested duration to get the number of
# new tokens to generate.
frame_rate = model.audio_encoder.config.frame_rate
# Text encoder, called directly so the prompt and negative prompt can be
# encoded in one batch and handed to `generate` as precomputed outputs.
text_encoder = model.get_text_encoder()
def _to_int16_pcm(waveform):
    """Convert a float waveform batch into 16-bit PCM that ``gr.Audio`` accepts.

    Args:
        waveform: Float array assumed to be laid out as (batch, channels,
            samples), the layout MusicGen's ``generate`` documents for its
            audio output -- confirm if the model changes. Only the first
            batch item is used.

    Returns:
        ``np.int16`` array shaped ``(samples,)`` for mono audio or
        ``(samples, channels)`` otherwise, the two layouts Gradio documents
        for ``type="numpy"`` audio.
    """
    # (channels, samples) -> (samples, channels), as Gradio expects.
    pcm = np.asarray(waveform)[0].T
    if pcm.shape[1] == 1:
        pcm = pcm[:, 0]
    # Clip before scaling: samples slightly outside [-1, 1] would otherwise
    # wrap around on the int16 cast and become loud clicks.
    return (np.clip(pcm, -1.0, 1.0) * 32767).astype(np.int16)


def generate_audio(prompt, negative_prompt, guidance_scale=3, audio_length_in_s=20, seed=0):
    """Generate music for ``prompt`` while steering away from ``negative_prompt``.

    Both texts are encoded together in one padded batch. Only the prompt's
    token ids are passed to ``generate``, but the full two-row encoder output
    is supplied. This relies on transformers' MusicGen classifier-free
    guidance, which uses the second encoder row as the "unconditional" branch.
    The negative prompt therefore only has an effect when guidance is enabled.

    Args:
        prompt: Text describing the desired music.
        negative_prompt: Text describing what to steer away from.
        guidance_scale: Classifier-free guidance strength; must be > 1.
        audio_length_in_s: Requested duration, converted to a token budget via
            the codec frame rate.
        seed: Seed passed to ``set_seed`` for reproducible sampling.

    Returns:
        ``(sampling_rate, pcm)`` where ``pcm`` is an ``np.int16`` array, the
        tuple form used by ``gr.Audio(type="numpy")``.

    Raises:
        ValueError: If ``guidance_scale <= 1``. Without guidance, the two-row
            encoder batch would not match the single-row decoder batch.
    """
    if guidance_scale <= 1:
        raise ValueError(
            f"guidance_scale must be > 1 to apply the negative prompt, got {guidance_scale}"
        )
    inputs = processor(
        text=[prompt, negative_prompt],
        padding=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        encoder_outputs = text_encoder(**inputs)
    max_new_tokens = int(frame_rate * audio_length_in_s)
    # set_seed also seeds NumPy, whose legacy seeding rejects float seeds
    # (UI sliders may deliver floats), so normalise to int first.
    set_seed(int(seed))
    audio_values = model.generate(
        inputs.input_ids[0][None, :],
        attention_mask=inputs.attention_mask,
        encoder_outputs=encoder_outputs,
        do_sample=True,
        guidance_scale=guidance_scale,
        max_new_tokens=max_new_tokens,
    )
    return (sampling_rate, _to_int16_pcm(audio_values.cpu().numpy()))
# Demo rows in the order generate_audio takes its arguments:
# [prompt, negative prompt, guidance scale, length in s, seed].
# Only the guidance scale varies, so the examples compare guidance strengths.
# The first row previously used 1.01, below the guidance slider's minimum of
# 1.5; it now uses 1.5 so every example is a value the UI can produce.
EXAMPLES = [
    ["80s pop track with synth and instrumentals", "drums", scale, 15, 0]
    for scale in (1.5, 3, 5, 7, 10)
]
# Build the web UI. The input components are listed in the same order as
# generate_audio's positional parameters: prompt, negative_prompt,
# guidance_scale, audio_length_in_s, seed.
prompt_box = gr.Text(label="Prompt", value="80s pop track with synth and instrumentals")
negative_prompt_box = gr.Text(label="Negative prompt", value="drums")
guidance_slider = gr.Slider(1.5, 10, value=3, step=0.5, label="Guidance scale")
length_slider = gr.Slider(5, 30, value=15, step=5, label="Audio length in s")
seed_slider = gr.Slider(0, 10, value=0, step=1, label="Seed")

demo = gr.Interface(
    fn=generate_audio,
    inputs=[prompt_box, negative_prompt_box, guidance_slider, length_slider, seed_slider],
    outputs=[gr.Audio(label="Generated Music", type="numpy")],
    examples=EXAMPLES,
)
# Launched unconditionally at import, exactly as before.
demo.launch()