File size: 2,917 Bytes
27bce5c
3e79bbd
 
 
 
 
 
cbb24c2
 
3e79bbd
 
27bce5c
3e79bbd
 
 
 
 
 
 
27bce5c
 
 
 
 
 
 
3e79bbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33d0b35
3e79bbd
27bce5c
 
 
 
3e79bbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f18cdd2
27bce5c
3e79bbd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import random
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from spectro import wav_bytes_from_spectrogram_image

# Pick the execution device once; half precision is only used on CUDA,
# while the CPU path stays in float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# Text-to-image model used for the video background.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)

# Riffusion model: produces spectrogram images that are later decoded to audio.
model_id2 = "riffusion/riffusion-model-v1"
pipe2 = StableDiffusionPipeline.from_pretrained(model_id2, torch_dtype=dtype).to(device)

# (start, end) colour pairs for the waveform-bar gradient; one pair is chosen
# at random per generated video. Tuples match the documented
# ``bars_color: str | tuple[str, str]`` parameter of gr.make_waveform.
COLORS = (
    ("#ff0000", "#00ff00"),  # red -> green
    ("#00ff00", "#0000ff"),  # green -> blue
    ("#0000ff", "#ff0000"),  # blue -> red
)

# HTML header rendered at the top of the page. The <h1> previously declared
# font-weight twice (950, then bold); only the effective one is kept.
title = """
        <div style="text-align: center; max-width: 650px; margin: 0 auto 10px;">
          <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
            <h1 style="margin-bottom: 7px; color: #000; font-weight: bold;">Riffusion and Stable Diffusion</h1>
          </div>
          <p style="margin-bottom: 10px; font-size: 98%; color: #666;">Text to music player.</p>
        </div>
        """


def get_bg_image(prompt):
    """Generate a background image for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt passed straight to ``pipe``.

    Returns:
        The first generated image. If the pipeline flagged that image as NSFW,
        the image loaded from ``nsfw_placeholder.jpg`` is returned instead.
    """
    result = pipe(prompt)
    print("Image generated!")
    # nsfw_content_detected is documented as possibly None (no safety check
    # performed); indexing it unconditionally would raise TypeError.
    flags = result.nsfw_content_detected
    if flags and flags[0]:
        return Image.open("nsfw_placeholder.jpg")
    return result.images[0]

def get_music(prompt):
    """Generate audio for *prompt* with the Riffusion pipeline and save it as WAV.

    ``pipe2`` produces a spectrogram image, which
    ``wav_bytes_from_spectrogram_image`` converts to audio. The first element
    of that helper's result is assumed to be a BytesIO-like buffer (inferred
    from the ``.getbuffer()`` call — confirm against ``spectro``).

    Args:
        prompt: Text prompt passed straight to ``pipe2``.

    Returns:
        Path of the written ``.wav`` file.
    """
    spec = pipe2(prompt).images[0]
    wav = wav_bytes_from_spectrogram_image(spec)
    # Use a unique file per request: a shared fixed path could be overwritten
    # by another request before the caller reads it.
    # NOTE: these files are not deleted here; they live in the system temp dir.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(wav[0].getbuffer())
    return f.name

def infer(prompt):
    """Build a waveform video for *prompt*.

    Generates a background image and an audio clip from the same prompt, then
    renders them with ``gr.make_waveform`` using a randomly chosen bar colour
    pair from ``COLORS``.

    Returns:
        A 1-tuple holding the video produced by ``gr.make_waveform``.
    """
    background = get_bg_image(prompt)
    wav_path = get_music(prompt)
    palette = random.choice(COLORS)
    video = gr.make_waveform(wav_path, bg_image=background, bars_color=palette)
    return (video,)

# Custom styles for the Blocks layout; each selector targets an elem_id
# assigned to a component in the gr.Blocks UI. Indentation inside the rules is
# normalized to 4 spaces (CSS whitespace is insignificant).
css = """
    #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
    #prompt-in {
        border: 2px solid #666;
        border-radius: 2px;
        padding: 8px;
    }
    #btn-container {
        display: flex;
        align-items: center;
        justify-content: center;
        width: calc(15% - 16px);
        height: calc(15% - 16px);
    }
    /* Style the submit button */
    #submit-btn {
        background-color: #382a1d;
        color: #fff;
        border: 1px solid #000;
        border-radius: 4px;
        padding: 8px;
        font-size: 16px;
        cursor: pointer;
    }
"""
# Page layout: header, prompt box, Send button and the output video.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col-container"):
        prompt_input = gr.Textbox(
            placeholder="a cat diva singing in a New York jazz club",
            elem_id="prompt-in",
            show_label=False,
        )
        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(value="Send", elem_id="submit-btn")
        video_output = gr.Video()
        # Send runs the full pipeline; infer's single return value fills the video.
        send_btn.click(infer, inputs=[prompt_input], outputs=[video_output])

# Queue requests so long-running generations are processed in order.
demo.queue().launch(debug=True)