from transformers import JukeboxModel, JukeboxTokenizer

import gradio as gr
import torch as t

model_id = 'openai/jukebox-1b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']
sample_rate = 44100
total_duration_in_seconds = 200
raw_to_tokens = 128
chunk_size = 32
max_batch_size = 16
cache_path = '~/.cache/'

def tokens_to_seconds(tokens, level = 2):
  return tokens * raw_to_tokens / sample_rate / 4 ** (2 - level)

def seconds_to_tokens(sec, level = 2):
  tokens = sec * sample_rate // raw_to_tokens
  # Round up to a whole number of chunks
  tokens = ( (tokens // chunk_size) + 1 ) * chunk_size
  # For levels 1 and 0, multiply by 4 and 16 respectively
  tokens *= 4 ** (2 - level)
  return int(tokens)
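
# Sanity check of the conversions above, worked by hand with the constants here:
#   seconds_to_tokens(1)          -> 44100 // 128 = 344, rounded up to 352 (level 2)
#   seconds_to_tokens(1, level=1) -> 352 * 4  = 1408
#   seconds_to_tokens(1, level=0) -> 352 * 16 = 5632
#   tokens_to_seconds(352)        -> 352 * 128 / 44100 ≈ 1.02 seconds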

# Init is run on server startup
# Load your model to GPU as a global variable here using the variable name "model"
def init():
  global model

  print(f"Loading model from/to {cache_path}...")
  model = JukeboxModel.from_pretrained(
    model_id,
    device_map = "auto",
    torch_dtype = t.float16,
    cache_dir = f"{cache_path}/jukebox/models",
    resume_download = True,
    min_duration = 0
  ).eval()
  print("Model loaded: ", model)

# Inference is run for every server call
# Reference your preloaded global model variable here.
def inference(artist, genres, lyrics):
  global model, zs

  n_samples = 4
  generation_length = seconds_to_tokens(1)
  offset = 0
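  # Note: transformers orders the Jukebox priors top-first, so level 0 below is
  # the top-level (lyrics-conditioned) prior, whereas the helpers above follow
  # the OpenAI convention in which level 2 is the top. With raw_to_tokens = 128
  # at the top level, the two conventions agree on the token counts used here.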
  level = 0

  model.total_length = seconds_to_tokens(total_duration_in_seconds)

  sampling_kwargs = dict(
    temp = 0.98,
    chunk_size = chunk_size,
  )

  metas = dict(
    artist = artist,
    genres = genres,
    lyrics = lyrics,
  )

  tokenizer = JukeboxTokenizer.from_pretrained(model_id)
  labels = tokenizer(**metas)['input_ids'][level].repeat(n_samples, 1).cuda()
  print(f"Labels: {labels.shape}")

  zs = [ t.zeros(n_samples, 0, dtype=t.long, device='cuda') for _ in range(3) ]
  print(f"Zs: {[z.shape for z in zs]}")

  zs = model.sample_partial_window(
    zs, labels, offset, sampling_kwargs, level = level, tokens_to_sample = generation_length, max_batch_size = max_batch_size
  )
  print(f"Zs after sampling: {[z.shape for z in zs]}")

  # sample_partial_window returns a list of one tensor per level;
  # return the level we actually sampled as a numpy array
  return zs[level].cpu().numpy()
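
# A possible follow-up, not wired into the UI: decoding sampled tokens back to
# raw audio with the VQ-VAE. A sketch only; it assumes model.vqvae.decode
# accepts the zs list produced above. Check the exact slicing and level
# ordering your transformers version expects before relying on it.
def decode_to_audio(zs, level = 0):
  # Decode from `level` down through the VQ-VAE decoder stack to a waveform
  with t.no_grad():
    return model.vqvae.decode(zs[level:], start_level = level).cpu().numpy()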


with gr.Blocks() as ui:

  # Define UI components
  title = gr.Textbox(lines=1, label="Title")  # display only; not passed to the model
  artist = gr.Textbox(lines=1, label="Artist")
  genres = gr.Textbox(lines=1, label="Genre(s)", placeholder="Separate with spaces")
  lyrics = gr.Textbox(lines=5, label="Lyrics", placeholder="Shift+Enter for new line")
  submit = gr.Button("Generate")

  output_zs = gr.Dataframe(label="zs")

  submit.click(
    inference,
    inputs = [ artist, genres, lyrics ],
    outputs = output_zs,
  )

if __name__ == "__main__":

  init()

  ui.launch()