root committed on
Commit
1a39554
1 Parent(s): 2cb7990

add picture generation for background

Files changed (2)
  1. app.py +23 -7
  2. requirements.txt +4 -1
app.py CHANGED
@@ -3,30 +3,46 @@ import spaces
 import gradio as gr
 
 from transformers import MusicgenForConditionalGeneration
-model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
-sampling_rate = model.config.audio_encoder.sampling_rate
+music_gen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+sampling_rate = music_gen_model.config.audio_encoder.sampling_rate
+
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model.to(device)
+music_gen_model.to(device)
 
 from transformers import AutoProcessor
 processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
 
+from diffusers import DiffusionPipeline
+
+# sd_pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+sd_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+sd_pipe.to(device)
+
+
 @spaces.GPU
 def generate_music(desc):
+
     inputs = processor(text=[desc], padding=True, return_tensors="pt")
-    audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)
+    audio_values = music_gen_model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)
     return sampling_rate, audio_values[0][0].cpu().numpy()
 
-# gr.Interface(fn=generate_music, inputs="text", outputs="audio").launch()
+@spaces.GPU
+def generate_pic(desc):
+    return sd_pipe(prompt=desc).images[0]
 
 with gr.Blocks() as app:
     with gr.Row():
         music_desc = gr.TextArea(label="Music Description")
+        music_pic = gr.Image(label="Music Image(StableDiffusion)")
         music_player = gr.Audio(label="Play My Tune")
 
-    gen_btn = gr.Button("Get Some Tune!!")
-    gen_btn.click(fn=generate_music, inputs=[music_desc], outputs=[music_player])
+
+    gen_pic_btn = gr.Button("Gen Picture")
+    gen_music_btn = gr.Button("Get Some Tune!!")
+
+    gen_pic_btn.click(fn=generate_pic, inputs=[music_desc], outputs=[music_pic])
+    gen_music_btn.click(fn=generate_music, inputs=[music_desc], outputs=[music_player])
 
 if __name__ == '__main__':
     app.launch()
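
Note: after this change the Space keeps both MusicGen and the SDXL base pipeline resident on the same device, which can be tight on VRAM for smaller GPUs. A minimal sketch of one mitigation, using diffusers' accelerate-backed CPU offloading; this variant is an illustration under that assumption, not part of the commit:

# Sketch (assumption, not in this commit): keep SDXL submodules on the CPU
# and move each to the GPU only while it is needed during a generation call.
import torch
from diffusers import DiffusionPipeline

sd_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,  # fp16 weights halve VRAM use
    use_safetensors=True,
    variant="fp16",
)
# Requires the accelerate package (added in requirements.txt below);
# this replaces the unconditional sd_pipe.to(device).
sd_pipe.enable_model_cpu_offload()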
requirements.txt CHANGED
@@ -2,4 +2,7 @@ transformers
 torch
 torchvision
 torchaudio
-spaces
+spaces
+accelerate
+safetensors
+diffusers
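
A quick way to confirm the new dependencies resolve in the Space runtime (a hypothetical smoke test, not part of the commit):

# Hypothetical smoke test: import the packages this commit adds and print versions.
import torch
import accelerate
import diffusers
import safetensors

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("accelerate", accelerate.__version__)
print("diffusers", diffusers.__version__)
print("safetensors", safetensors.__version__)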