Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
commited on
Commit
·
3dd0e68
1
Parent(s):
94706c2
Update app.py
Browse files
app.py
CHANGED
@@ -23,7 +23,7 @@ def generate(
|
|
23 |
seq_len,
|
24 |
max_seq_len = 2048,
|
25 |
temperature = 0.9,
|
26 |
-
verbose=
|
27 |
return_prime=False,
|
28 |
):
|
29 |
|
@@ -34,25 +34,30 @@ def generate(
|
|
34 |
if verbose:
|
35 |
print("Generating sequence of max length:", seq_len)
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
torch_in = x.tolist()[0]
|
41 |
-
|
42 |
-
logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
if return_prime:
|
57 |
return out[:, :]
|
58 |
|
@@ -134,7 +139,7 @@ def GenerateMIDI():
|
|
134 |
|
135 |
midi_data = TMIDIX.score2midi(output, text_encoding)
|
136 |
|
137 |
-
with open(f"Allegro-Music-Transformer-Music-Composition", 'wb') as f:
|
138 |
f.write(midi_data)
|
139 |
|
140 |
audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
|
@@ -191,8 +196,10 @@ if __name__ == "__main__":
|
|
191 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
192 |
opt = parser.parse_args()
|
193 |
|
|
|
194 |
session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=['CUDAExecutionProvider'])
|
195 |
-
|
|
|
196 |
load_javascript()
|
197 |
app = gr.Blocks()
|
198 |
with app:
|
|
|
23 |
seq_len,
|
24 |
max_seq_len = 2048,
|
25 |
temperature = 0.9,
|
26 |
+
verbose=False,
|
27 |
return_prime=False,
|
28 |
):
|
29 |
|
|
|
34 |
if verbose:
|
35 |
print("Generating sequence of max length:", seq_len)
|
36 |
|
37 |
+
max_len = seq_len
|
38 |
+
cur_len = 0
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
|
41 |
+
with bar:
|
42 |
+
while cur_len < max_len:
|
43 |
+
|
44 |
+
x = out[:, -max_seq_len:]
|
45 |
+
|
46 |
+
torch_in = x.tolist()[0]
|
47 |
+
|
48 |
+
logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
|
49 |
+
|
50 |
+
filtered_logits = logits
|
51 |
+
|
52 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
53 |
+
|
54 |
+
sample = torch.multinomial(probs, 1)
|
55 |
+
|
56 |
+
out = torch.cat((out, sample), dim=-1)
|
57 |
+
|
58 |
+
cur_len += 1
|
59 |
+
bar.update(1)
|
60 |
+
|
61 |
if return_prime:
|
62 |
return out[:, :]
|
63 |
|
|
|
139 |
|
140 |
midi_data = TMIDIX.score2midi(output, text_encoding)
|
141 |
|
142 |
+
with open(f"Allegro-Music-Transformer-Music-Composition.mid", 'wb') as f:
|
143 |
f.write(midi_data)
|
144 |
|
145 |
audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
|
|
|
196 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
197 |
opt = parser.parse_args()
|
198 |
|
199 |
+
print('Loading model...')
|
200 |
session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=['CUDAExecutionProvider'])
|
201 |
+
print('Done!')
|
202 |
+
|
203 |
load_javascript()
|
204 |
app = gr.Blocks()
|
205 |
with app:
|