# Hugging Face Spaces page status when this file was captured: Runtime error
import os
import random
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from spectro import wav_bytes_from_spectrogram_image
# Use the GPU with half-precision weights when CUDA is available; otherwise
# run on the CPU with full-precision weights.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32


def _load_pipeline(repo_id):
    """Load the Stable Diffusion pipeline *repo_id* with ``dtype`` and move it to ``device``."""
    loaded = StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
    return loaded.to(device)


# Text-to-image model; get_bg_image uses it for the video background.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = _load_pipeline(model_id)

# Riffusion model; get_music decodes its output image as a spectrogram.
model_id2 = "riffusion/riffusion-model-v1"
pipe2 = _load_pipeline(model_id2)
# (start, end) colour pairs for the waveform bars, chosen at random per video.
# Each pair goes from one primary colour to the next, wrapping blue -> red.
_PRIMARY_HEX = ["#ff0000", "#00ff00", "#0000ff"]
COLORS = [
    [start, end]
    for start, end in zip(_PRIMARY_HEX, _PRIMARY_HEX[1:] + _PRIMARY_HEX[:1])
]
# HTML header shown at the top of the app.
# The <h1> previously declared font-weight twice (950, then bold); only the
# last declaration takes effect, so the dead one has been dropped.
title = """
<div style="text-align: center; max-width: 650px; margin: 0 auto 10px;">
<div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
<h1 style="margin-bottom: 7px; color: #000; font-weight: bold;">Riffusion and Stable Diffusion</h1>
</div>
<p style="margin-bottom: 10px; font-size: 98%; color: #666;">Text to music player.</p>
</div>
"""
def get_bg_image(prompt):
    """Generate a background image for *prompt* with the text-to-image pipeline.

    Returns the first image from the pipeline output. If the pipeline's safety
    checker flagged that image, returns the ``nsfw_placeholder.jpg`` image instead.
    """
    result = pipe(prompt)
    print("Image generated!")
    # nsfw_content_detected is None when no safety check was performed, so
    # guard before indexing instead of raising TypeError.
    nsfw_flags = result.nsfw_content_detected
    if nsfw_flags and nsfw_flags[0]:
        return Image.open("nsfw_placeholder.jpg")
    return result.images[0]
def get_music(prompt, output_path="output.wav"):
    """Generate audio for *prompt* and write it to *output_path* as a WAV file.

    The Riffusion pipeline's output image is treated as a spectrogram.
    ``wav_bytes_from_spectrogram_image`` converts it to WAV data.

    Args:
        prompt: Text prompt for the Riffusion pipeline.
        output_path: Destination of the WAV file. Defaults to ``"output.wav"``
            (the previously hard-coded path). Concurrent callers should pass
            distinct paths so they do not overwrite each other's audio.

    Returns:
        ``output_path``, the path of the written WAV file.
    """
    spectrogram = pipe2(prompt).images[0]
    print("Spectrogram generated!")
    # Element 0 of the helper's result is the WAV buffer; the code relies on it
    # exposing getbuffer() (presumably an io.BytesIO -- confirm in spectro).
    wav_buffer = wav_bytes_from_spectrogram_image(spectrogram)[0]
    with open(output_path, "wb") as f:
        f.write(wav_buffer.getbuffer())
    return output_path
def infer(prompt):
    """Build a waveform video for *prompt*: generated music over a generated image.

    Returns:
        A 1-tuple with the path of the video produced by ``gr.make_waveform``,
        one element per Gradio output component.
    """
    image = get_bg_image(prompt)
    audio = get_music(prompt)
    # gr.make_waveform takes bg_image as a file path (it opens it with
    # PIL.Image.open), not an in-memory PIL image. Save the image to a scratch
    # directory that is removed once the video has been rendered.
    with tempfile.TemporaryDirectory() as scratch_dir:
        bg_path = os.path.join(scratch_dir, "background.png")
        image.save(bg_path)
        video = gr.make_waveform(
            audio,
            bg_image=bg_path,
            # bars_color is documented as a (start, end) tuple.
            bars_color=tuple(random.choice(COLORS)),
        )
    return (video,)
# Stylesheet passed to gr.Blocks. The selectors match the elem_id values given
# to the layout components (col-container, prompt-in, btn-container, submit-btn).
css = (
    "\n"
    "#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}\n"
    "#prompt-in {\n"
    "border: 2px solid #666;\n"
    "border-radius: 2px;\n"
    "padding: 8px;\n"
    "}\n"
    "#btn-container {\n"
    "display: flex;\n"
    "align-items: center;\n"
    "justify-content: center;\n"
    "width: calc(15% - 16px);\n"
    "height: calc(15% - 16px);\n"
    "}\n"
    "/* Style the submit button */\n"
    "#submit-btn {\n"
    "background-color: #382a1d;\n"
    "color: #fff;\n"
    "border: 1px solid #000;\n"
    "border-radius: 4px;\n"
    "padding: 8px;\n"
    "font-size: 16px;\n"
    "cursor: pointer;\n"
    "}\n"
)
# UI: the HTML header, then a centred column holding the prompt box, the Send
# button and the video player that receives infer()'s output.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col-container"):
        prompt_input = gr.Textbox(
            elem_id="prompt-in",
            placeholder="a cat diva singing in a New York jazz club",
            show_label=False,
        )
        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(elem_id="submit-btn", value="Send")
        video_output = gr.Video()
    send_btn.click(fn=infer, inputs=[prompt_input], outputs=[video_output])

# Route events through Gradio's request queue, then start the server in debug mode.
demo.queue()
demo.launch(debug=True)