import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from transformers import BlipForConditionalGeneration, BlipProcessor

from share_btn import community_icon_html, loading_icon_html, share_js
from spectro import wav_bytes_from_spectrogram_image
# Hugging Face Hub checkpoints: BLIP captions the uploaded image, and the
# caption is then used as the text prompt for the Riffusion pipeline.
model_id = "riffusion/riffusion-model-v1"
blip_model_id = "Salesforce/blip-image-captioning-base"

# Single source of truth for where the models live. torch device strings are
# case-sensitive: "CPU" is rejected, so this must stay lowercase ("cpu"/"cuda").
# NOTE: predict() also moves the BLIP inputs to a device; keep both in sync
# if this value is changed.
device = "cpu"

# Text-to-image diffusion pipeline; its output image is treated as a spectrogram.
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to(device)

# Captioning model and its matching pre/post-processor.
blip_model = BlipForConditionalGeneration.from_pretrained(
    blip_model_id, torch_dtype=torch.float
).to(device)
processor = BlipProcessor.from_pretrained(blip_model_id)
def predict(image):
    """Caption an image with BLIP, then turn the caption into music.

    The caption is fed to the Riffusion pipeline as a text prompt; the image
    the pipeline returns is converted to WAV bytes and written to disk.

    Args:
        image: The value delivered by the ``gr.Image`` input component, or
            ``None`` when the user submitted without providing an image.

    Returns:
        A 5-tuple matching the ``outputs`` of the click handler: the
        generated spectrogram image, the path of the written WAV file, and
        three ``gr.update(visible=True)`` objects that reveal the share
        button and its two icons.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque
        # processor failure.
        raise gr.Error("Please provide an image before generating a spectrogram.")

    # Follow the model's own device/dtype rather than hard-coding them, so the
    # inputs always land where the weights are.
    inputs = processor(image, return_tensors="pt").to(
        blip_model.device, blip_model.dtype
    )
    caption_ids = blip_model.generate(**inputs)
    prompt = processor.decode(caption_ids[0], skip_special_tokens=True)

    spec = pipe(prompt).images[0]

    # The helper's first element is used as an in-memory buffer
    # (presumably an io.BytesIO — verify against spectro.py).
    wav = wav_bytes_from_spectrogram_image(spec)

    # Write each result to its own file so queued requests never overwrite
    # one another's audio. delete=False keeps the file alive for Gradio to
    # serve after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
        wav_file.write(wav[0].getbuffer())
        wav_path = wav_file.name

    return (
        spec,
        wav_path,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )
# Heading markup rendered at the top of the first column via gr.HTML(title).
# NOTE(review): the text describes a text-prompt workflow, while the visible UI
# takes an image input — confirm whether the wording should be updated.
title = """
Riffusion real-time prompt to image and to music generation system
Describe a musical prompt and generate a respective spectrogram image & musical sound associated with.
"""
# "About the model" blurb rendered through gr.HTML(article) in the page layout.
article = """
About the model: Riffusion is a latent text2img diffusion model capable of generating spectrogram images from a given text input prompts. These generated spectrograms are again then utilised to get converted into audio clips.
—
The Riffusion model was created by fine-tuning the Stable-Diffusion-v1-5 checkpoint.
—
The model is intended for research purposes only. Possible research areas and tasks include
generation of artworks, audio, and use in creative processes, applications in educational or creative tools, research on generative models.
"""
# Custom stylesheet handed to gr.Blocks(css=css).
# #col-container / #col-container-2, #share-btn-container and #share-btn match
# the elem_id values assigned in the layout below.
# NOTE(review): div#record_btn and .footer have no matching elem_id in the
# visible part of this file — presumably leftovers or Gradio-internal; verify.
css = '''
#col-container, #col-container-2 {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
div#record_btn > .mt-6 {
margin-top: 0!important;
}
div#record_btn > .mt-6 button {
width: 100%;
height: 40px;
}
.footer {
margin-bottom: 45px;
margin-top: 10px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
# Page layout: an input column (heading, image upload, trigger button) and a
# results column (spectrogram, audio player, share controls, about text).
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(value=title)
        image_input = gr.Image()
        send_btn = gr.Button(
            value="Get a new riffusion spectrogram ! ",
            elem_id="submit-btn",
        )

    with gr.Column(elem_id="col-container-2"):
        spectrogram_output = gr.Image(
            label="riffusion spectrogram image result",
            elem_id="img-out",
        )
        sound_output = gr.Audio(
            type="filepath",
            label="riffusion spectrogram sound",
            elem_id="music-out",
        )

        # Share controls start hidden; predict() reveals them by returning
        # gr.update(visible=True) for each one.
        with gr.Group(elem_id="share-btn-container"):
            community_icon = gr.HTML(value=community_icon_html, visible=False)
            loading_icon = gr.HTML(value=loading_icon_html, visible=False)
            share_button = gr.Button(
                value="Share to community",
                elem_id="share-btn",
                visible=False,
            )

        gr.HTML(value=article)

    # The output order must match the tuple returned by predict().
    send_btn.click(
        fn=predict,
        inputs=[image_input],
        outputs=[
            spectrogram_output,
            sound_output,
            share_button,
            community_icon,
            loading_icon,
        ],
    )
    # No Python callback here: the click only runs the share_js snippet.
    share_button.click(fn=None, inputs=[], outputs=[], _js=share_js)

# Enable request queuing (capped at 250 pending jobs), then start the server.
demo.queue(max_size=250)
demo.launch(debug=True)