sanchit-gandhi (HF staff) committed
Commit
91a1e69
1 Parent(s): 7e248e7

Update app.py

Files changed (1)
  1. app.py +7 -2
app.py CHANGED
@@ -1,5 +1,6 @@
 from transformers import MusicgenForConditionalGeneration, AutoProcessor, set_seed
 import torch
+import numpy as np
 import gradio as gr
 
 model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
@@ -9,10 +10,11 @@ device = "cuda:0"
 model.to(device)
 
 sampling_rate = model.audio_encoder.config.sampling_rate
+frame_rate = model.audio_encoder.config.frame_rate
 text_encoder = model.get_text_encoder()
 
 
-def generate_audio(prompt, negative_prompt, guidance_scale=3, seed=0):
+def generate_audio(prompt, negative_prompt, guidance_scale=3, audio_length_in_s=20, seed=0):
     inputs = processor(
         text=[prompt, negative_prompt],
         padding=True,
@@ -22,8 +24,10 @@ def generate_audio(prompt, negative_prompt, guidance_scale=3, seed=0):
     with torch.no_grad():
         encoder_outputs = text_encoder(**inputs)
 
+    max_new_tokens = int(frame_rate * audio_length_in_s)
+
     set_seed(seed)
-    audio_values = model.generate(inputs.input_ids[0][None, :], attention_mask=inputs.attention_mask, encoder_outputs=encoder_outputs, do_sample=True, guidance_scale=guidance_scale, max_new_tokens=1028)
+    audio_values = model.generate(inputs.input_ids[0][None, :], attention_mask=inputs.attention_mask, encoder_outputs=encoder_outputs, do_sample=True, guidance_scale=guidance_scale, max_new_tokens=max_new_tokens)
 
     audio_values = (audio_values.cpu().numpy() * 32767).astype(np.int16)
     return (sampling_rate, audio_values)
@@ -35,6 +39,7 @@ gr.Interface(
         gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
         gr.Text(label="Negative prompt", value="drums"),
         gr.Slider(1.5, 10, value=3, step=0.5, label="Guidance scale"),
+        gr.Slider(5, 30, value=15, step=5, label="Audio length in s"),
         gr.Slider(0, 10, value=0, step=1, label="Seed"),
     ],
     outputs=[
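For reference, the arithmetic behind the new max_new_tokens line: MusicGen emits codec tokens at the audio encoder's frame rate, so the requested duration converts directly into a generation budget. A minimal sketch of that mapping, assuming the 50 Hz frame rate commonly reported for facebook/musicgen-small's EnCodec encoder (app.py reads the real value from model.audio_encoder.config.frame_rate):

frame_rate = 50  # assumed codec frames per second for musicgen-small;
                 # in app.py this comes from model.audio_encoder.config.frame_rate

for audio_length_in_s in (5, 15, 30):  # the new slider's minimum, default, and maximum
    max_new_tokens = int(frame_rate * audio_length_in_s)
    print(f"{audio_length_in_s:>2} s -> {max_new_tokens} tokens")

# 5 s -> 250 tokens, 15 s -> 750 tokens, 30 s -> 1500 tokens.
# The old hard-coded max_new_tokens=1028 pinned every clip at roughly 20.6 s.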
 
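The unchanged output conversion is also worth a note: gr.Audio accepts a (sample_rate, numpy_array) tuple and treats an integer array as PCM samples, so the float waveform in [-1.0, 1.0] is scaled up to the int16 range. A standalone sketch of the same conversion on a made-up test tone:

import numpy as np

float_audio = np.sin(np.linspace(0.0, 2 * np.pi * 440, 32000))  # fabricated 1 s, 440 Hz tone at 32 kHz
pcm16 = (float_audio * 32767).astype(np.int16)  # rescale [-1.0, 1.0] floats to 16-bit PCM
assert pcm16.dtype == np.int16  # integer dtype signals PCM samples to gr.Audio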
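Finally, a note on the negative-prompt wiring this commit leaves in place: processor encodes [prompt, negative_prompt] as a batch of two, both text encodings are passed via encoder_outputs, but only the first row of input_ids is handed to generate. With guidance_scale > 1, classifier-free guidance then contrasts the prompt branch against the negative-prompt branch instead of an empty one, which is the trick the demo relies on. A hypothetical local call to the updated function (the squeeze hint at the end is an assumption, not part of app.py):

sampling_rate, audio = generate_audio(
    prompt="80s pop track with synth and instrumentals",
    negative_prompt="drums",   # steered away from via classifier-free guidance
    guidance_scale=3,          # > 1 enables guidance; higher follows the prompt more strictly
    audio_length_in_s=15,      # 750 codec tokens at a 50 Hz frame rate
    seed=0,                    # set_seed() makes the sampling reproducible
)
# Returns the (rate, int16 array) tuple that gr.Audio expects; audio.squeeze()
# yields a plain 1-D waveform if you want to save or play it elsewhere.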