asigalov61 commited on
Commit
3dd0e68
·
1 Parent(s): 94706c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -21
app.py CHANGED
@@ -23,7 +23,7 @@ def generate(
23
  seq_len,
24
  max_seq_len = 2048,
25
  temperature = 0.9,
26
- verbose=True,
27
  return_prime=False,
28
  ):
29
 
@@ -34,25 +34,30 @@ def generate(
34
  if verbose:
35
  print("Generating sequence of max length:", seq_len)
36
 
37
- for s in range(seq_len):
38
- x = out[:, -max_seq_len:]
39
-
40
- torch_in = x.tolist()[0]
41
-
42
- logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
43
 
44
- filtered_logits = logits
45
-
46
- probs = F.softmax(filtered_logits / temperature, dim=-1)
47
-
48
- sample = torch.multinomial(probs, 1)
49
-
50
- out = torch.cat((out, sample), dim=-1)
51
-
52
- if verbose:
53
- if s % 32 == 0:
54
- print(s, '/', seq_len)
55
-
 
 
 
 
 
 
 
 
 
56
  if return_prime:
57
  return out[:, :]
58
 
@@ -134,7 +139,7 @@ def GenerateMIDI():
134
 
135
  midi_data = TMIDIX.score2midi(output, text_encoding)
136
 
137
- with open(f"Allegro-Music-Transformer-Music-Composition", 'wb') as f:
138
  f.write(midi_data)
139
 
140
  audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
@@ -191,8 +196,10 @@ if __name__ == "__main__":
191
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
192
  opt = parser.parse_args()
193
 
 
194
  session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=['CUDAExecutionProvider'])
195
-
 
196
  load_javascript()
197
  app = gr.Blocks()
198
  with app:
 
23
  seq_len,
24
  max_seq_len = 2048,
25
  temperature = 0.9,
26
+ verbose=False,
27
  return_prime=False,
28
  ):
29
 
 
34
  if verbose:
35
  print("Generating sequence of max length:", seq_len)
36
 
37
+ max_len = seq_len
38
+ cur_len = 0
 
 
 
 
39
 
40
+ bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
41
+ with bar:
42
+ while cur_len < max_len:
43
+
44
+ x = out[:, -max_seq_len:]
45
+
46
+ torch_in = x.tolist()[0]
47
+
48
+ logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
49
+
50
+ filtered_logits = logits
51
+
52
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
53
+
54
+ sample = torch.multinomial(probs, 1)
55
+
56
+ out = torch.cat((out, sample), dim=-1)
57
+
58
+ cur_len += 1
59
+ bar.update(1)
60
+
61
  if return_prime:
62
  return out[:, :]
63
 
 
139
 
140
  midi_data = TMIDIX.score2midi(output, text_encoding)
141
 
142
+ with open(f"Allegro-Music-Transformer-Music-Composition.mid", 'wb') as f:
143
  f.write(midi_data)
144
 
145
  audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
 
196
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
197
  opt = parser.parse_args()
198
 
199
+ print('Loading model...')
200
  session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=['CUDAExecutionProvider'])
201
+ print('Done!')
202
+
203
  load_javascript()
204
  app = gr.Blocks()
205
  with app: