# Ads67 / app.py -- Hugging Face Space source file (2.21 kB)
# Commit 0593593 (verified) by aach456: "Create app.py"
import gradio as gr
import torch
import numpy as np
from transformers import MusicgenForConditionalGeneration, AutoProcessor
import scipy.io.wavfile
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"

# Filled lazily so the slow checkpoint load happens once per process rather
# than once per request.
_MUSICGEN_CACHE = {}


def _get_model():
    """Return ``(model, device)``, loading the MusicGen checkpoint on first use."""
    if "model" not in _MUSICGEN_CACHE:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        model = MusicgenForConditionalGeneration.from_pretrained(_MUSICGEN_CHECKPOINT)
        model.to(device)
        _MUSICGEN_CACHE["model"] = (model, device)
    return _MUSICGEN_CACHE["model"]


def _get_processor():
    """Return the text processor for the checkpoint, loading it on first use."""
    if "processor" not in _MUSICGEN_CACHE:
        _MUSICGEN_CACHE["processor"] = AutoProcessor.from_pretrained(_MUSICGEN_CHECKPOINT)
    return _MUSICGEN_CACHE["processor"]


def _to_pcm16(audio_values):
    """Convert the first generated sample to a 1-D ``int16`` NumPy array."""
    samples = audio_values[0].cpu().numpy()
    # Presumably (channels, samples) at this point -- see the MusicGen docs;
    # only the first channel is kept, as in the original implementation.
    if samples.ndim > 1:
        samples = samples[0]
    # Clip before scaling so out-of-range floats cannot wrap around in int16.
    samples = np.clip(samples, -1.0, 1.0)
    return (samples * 32767).astype(np.int16)


def generate_music(
    prompt,
    unconditional=False,
    max_new_tokens=256,
    guidance_scale=3,
    audio_file="musicgen_out.wav",
):
    """Generate a short audio clip with MusicGen and save it as a WAV file.

    Args:
        prompt: Text description of the desired music. Ignored when
            ``unconditional`` is true.
        unconditional: When true, sample without any text conditioning.
        max_new_tokens: Number of tokens to generate; controls clip length.
        guidance_scale: Guidance scale passed to ``model.generate`` for
            text-conditioned generation (not used for unconditional sampling).
        audio_file: Path the 16-bit PCM WAV file is written to. The default
            is a fixed name, so each call overwrites the previous output.

    Returns:
        The path of the written WAV file (``audio_file``).
    """
    model, device = _get_model()

    # A missing or blank prompt carries no conditioning information, so treat
    # it like an explicit request for unconditional generation instead of
    # handing an empty/None text to the processor.
    has_prompt = bool(prompt and prompt.strip())

    if unconditional or not has_prompt:
        unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
        audio_values = model.generate(
            **unconditional_inputs,
            do_sample=True,
            max_new_tokens=max_new_tokens,
        )
    else:
        processor = _get_processor()
        inputs = processor(text=prompt, padding=True, return_tensors="pt")
        audio_values = model.generate(
            **inputs.to(device),
            do_sample=True,
            guidance_scale=guidance_scale,
            max_new_tokens=max_new_tokens,
        )

    sampling_rate = model.config.audio_encoder.sampling_rate
    scipy.io.wavfile.write(audio_file, sampling_rate, _to_pcm16(audio_values))
    return audio_file
def interface(prompt, unconditional):
    """Gradio callback: forward the UI inputs to ``generate_music``.

    Returns the path of the generated audio file, exactly as produced by
    ``generate_music``.
    """
    return generate_music(prompt, unconditional)
# Build the Gradio UI: a prompt box and an "unconditional" toggle feed the
# `interface` callback, whose returned file path is shown in an audio player.
# NOTE(review): the original indentation was lost, so which widgets sit inside
# the Row is an assumption (textbox + checkbox) -- confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Music Generation")

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter the Music Prompt")
        unconditional_checkbox = gr.Checkbox(label="Generate Unconditional Music")

    generate_button = gr.Button("Generate Music")
    output_audio = gr.Audio(label="Output Music")

    # Event wiring must happen inside the Blocks context.
    generate_button.click(
        interface,
        inputs=[prompt_input, unconditional_checkbox],
        outputs=output_audio,
        show_progress=True,
    )

# Start the web server; share=True additionally requests a public share link.
demo.launch(share=True)