"""Gradio demo: generate music with Riffusion and show it as a waveform video
over a Stable Diffusion background image generated from the same prompt."""
import random
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

# NOTE(review): `wav_bytes_from_spectrogram_image` is called in get_music() but is
# neither defined nor imported in this file, so the first request raises NameError.
# It must come from a spectrogram-to-audio helper shipped with this app (e.g. a
# local `spectro` module) -- add the import once its location is confirmed.

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU only; CPU inference runs in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# Text-to-image pipeline used for the video background.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, revision="fp16")
pipe = pipe.to(device)

# Riffusion pipeline: its output image is treated as a spectrogram and converted to audio.
model_id2 = "riffusion/riffusion-model-v1"
pipe2 = StableDiffusionPipeline.from_pretrained(model_id2, torch_dtype=dtype)
pipe2 = pipe2.to(device)

# (start, end) gradient colours for the waveform bars; one pair is picked per request.
COLORS = [
    ("#ff0000", "#00ff00"),
    ("#00ff00", "#0000ff"),
    ("#0000ff", "#ff0000"),
]

# NOTE(review): rendered with gr.HTML but contains no markup -- heading tags may
# have been lost from the original; confirm the intended title HTML.
title = """

Riffusion and Stable Diffusion

"""


def get_bg_image(prompt: str) -> Image.Image:
    """Generate a background image for *prompt* with Stable Diffusion.

    Returns the first generated PIL image, or the placeholder loaded from
    ``nsfw_placeholder.jpg`` when the pipeline's safety checker flags it.
    """
    result = pipe(prompt)
    print("Image generated!")
    # diffusers reports None here when no safety check was performed.
    flags = result.nsfw_content_detected
    if flags and flags[0]:
        return Image.open("nsfw_placeholder.jpg")
    return result.images[0]


def get_music(prompt: str) -> str:
    """Generate audio for *prompt* with Riffusion and return the path of a WAV file.

    Each call writes to its own temporary file so concurrent requests cannot
    overwrite each other's audio.
    """
    spec = pipe2(prompt, height=512, width=512).images[0]
    wav = wav_bytes_from_spectrogram_image(spec)
    # assumes wav[0] is a BytesIO-like object holding the encoded WAV -- TODO confirm
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
        wav_file.write(wav[0].getbuffer())
    return wav_file.name


def infer(prompt: str):
    """Produce a waveform video for *prompt*: Riffusion audio over an SD background.

    Returns a 1-tuple matching the single ``gr.Video`` output of the UI.
    """
    image = get_bg_image(prompt)
    audio_path = get_music(prompt)
    # gr.make_waveform takes bg_image as a file path, so persist the PIL image first.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as bg_file:
        image.save(bg_file, format="PNG")
    return (
        gr.make_waveform(audio_path, bg_image=bg_file.name, bars_color=random.choice(COLORS)),
    )


css = """
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
#prompt-in {
    border: 2px solid #666;
    border-radius: 2px;
    padding: 8px;
}
#btn-container {
    display: flex;
    align-items: center;
    justify-content: center;
    width: calc(15% - 16px);
    height: calc(15% - 16px);
}
/* Style the submit button */
#submit-btn {
    background-color: #382a1d;
    color: #fff;
    border: 1px solid #000;
    border-radius: 4px;
    padding: 8px;
    font-size: 16px;
    cursor: pointer;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col-container"):
        prompt_input = gr.Textbox(
            placeholder="The Beatles playing for the queen",
            elem_id="prompt-in",
            label="Enter your music prompt",
        )
        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(value="Send", elem_id="submit-btn")
        video_output = gr.Video()
    send_btn.click(infer, inputs=[prompt_input], outputs=[video_output])

if __name__ == "__main__":
    # launch() blocks and then returns a tuple, so `.debug(True)` on its result
    # would fail; debug mode is requested through the keyword argument instead.
    demo.launch(debug=True)