Hua Ruochen committed on
Commit
2cb7990
1 Parent(s): 56798b8

add space gpu

Browse files
Files changed (2) hide show
  1. app.py +14 -2
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import gradio as gr
3
 
4
  from transformers import MusicgenForConditionalGeneration
@@ -11,10 +12,21 @@ model.to(device)
11
  from transformers import AutoProcessor
12
  processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
13
 
14
-
15
  def generate_music(desc):
16
  inputs = processor(text=[desc], padding=True, return_tensors="pt")
17
  audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)
18
  return sampling_rate, audio_values[0][0].cpu().numpy()
19
 
20
- gr.Interface(fn=generate_music, inputs="text", outputs="audio").launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import spaces
3
  import gradio as gr
4
 
5
  from transformers import MusicgenForConditionalGeneration
 
12
  from transformers import AutoProcessor
13
  processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
14
 
15
@spaces.GPU
def generate_music(desc):
    """Generate a short music clip from a free-text description.

    Args:
        desc: Natural-language description of the desired music.

    Returns:
        A ``(sampling_rate, waveform)`` tuple suitable for a Gradio Audio
        component, where ``waveform`` is a 1-D numpy array on the CPU.
        NOTE(review): ``sampling_rate`` (like ``processor``, ``model`` and
        ``device``) is defined elsewhere in this module — presumably taken
        from the model config; confirm against the full file.
    """
    # Tokenize the prompt and move the tensors to the model's device.
    tokenized = processor(text=[desc], padding=True, return_tensors="pt")
    tokenized = tokenized.to(device)
    # Sample with classifier-free guidance; 256 new tokens ≈ a few seconds
    # of audio for MusicGen-small.
    generated = model.generate(
        **tokenized,
        do_sample=True,
        guidance_scale=3,
        max_new_tokens=256,
    )
    # generated is (batch, channels, samples); take the first clip's first
    # channel and hand it back as a CPU numpy array.
    waveform = generated[0][0].cpu().numpy()
    return sampling_rate, waveform
20
 
21
# Build the demo UI: a text prompt on one side, the generated audio on the
# other, and a button that triggers generation.
with gr.Blocks() as app:
    with gr.Row():
        music_desc = gr.TextArea(label="Music Description")
        music_player = gr.Audio(label="Play My Tune")

    gen_btn = gr.Button("Get Some Tune!!")
    # Wire the button to the generator: the description feeds generate_music,
    # whose (sampling_rate, waveform) return value populates the audio player.
    gen_btn.click(fn=generate_music, inputs=[music_desc], outputs=[music_player])

if __name__ == '__main__':
    app.launch()
requirements.txt CHANGED
@@ -2,3 +2,4 @@ transformers
2
  torch
3
  torchvision
4
  torchaudio
 
 
2
  torch
3
  torchvision
4
  torchaudio
5
+ spaces