# sd-riffusion / app.py
import random
from PIL import Image
from diffusers import StableDiffusionPipeline
import gradio as gr
import torch
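# Local helper that converts a Riffusion spectrogram image into WAV audio bytes.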
from spectro import wav_bytes_from_spectrogram_image
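# Prefer the GPU with half precision; fall back to the CPU with full precision.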
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
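# Stable Diffusion v1.5 generates the background image for the waveform video.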
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)
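# Riffusion generates a spectrogram image that is later converted into audio.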
model_id2 = "riffusion/riffusion-model-v1"
pipe2 = StableDiffusionPipeline.from_pretrained(model_id2, torch_dtype=dtype)
pipe2 = pipe2.to(device)
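# Color pairs for the waveform bars; one pair is picked at random per request.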
COLORS = [
["#ff0000", "#00ff00"],
["#00ff00", "#0000ff"],
["#0000ff", "#ff0000"],
]
title = """
<div style="text-align: center; max-width: 650px; margin: 0 auto 10px;">
    <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
        <h1 style="font-weight: bold; margin-bottom: 7px; color: #000;">Riffusion and Stable Diffusion</h1>
    </div>
    <p style="margin-bottom: 10px; font-size: 98%; color: #666;">Text-to-music player.</p>
</div>
"""
def get_bg_image(prompt):
    images = pipe(prompt)
    print("Image generated!")
    # Fall back to a placeholder image if the safety checker flags the result.
    image_output = images.images[0] if not images.nsfw_content_detected[0] else Image.open("nsfw_placeholder.jpg")
    return image_output
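# Generate audio for the given prompt: Riffusion produces a spectrogram, which is converted to a WAV file.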
def get_music(prompt):
    spec = pipe2(prompt).images[0]
    print(spec)
    # The helper returns (audio buffer, duration); write the buffer out as a WAV file.
    wav = wav_bytes_from_spectrogram_image(spec)
    with open("output.wav", "wb") as f:
        f.write(wav[0].getbuffer())
    return 'output.wav'
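# Full inference: generate the image and the audio, then combine them into a waveform video.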
def infer(prompt):
    image = get_bg_image(prompt)
    audio = get_music(prompt)
    # Render the audio as a waveform video over the generated image, with a random bar color pair.
    return (
        gr.make_waveform(audio, bg_image=image, bars_color=random.choice(COLORS)),
    )
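# Custom CSS for the Gradio layout.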
css = """
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
#prompt-in {
    border: 2px solid #666;
    border-radius: 2px;
    padding: 8px;
}
#btn-container {
    display: flex;
    align-items: center;
    justify-content: center;
    width: calc(15% - 16px);
    height: calc(15% - 16px);
}
/* Style the submit button */
#submit-btn {
    background-color: #382a1d;
    color: #fff;
    border: 1px solid #000;
    border-radius: 4px;
    padding: 8px;
    font-size: 16px;
    cursor: pointer;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col-container"):
        prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club",
                                  elem_id="prompt-in",
                                  show_label=False)
        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(value="Send", elem_id="submit-btn")
        video_output = gr.Video()
        send_btn.click(infer, inputs=[prompt_input], outputs=[video_output])
demo.queue().launch(debug=True)