Spaces:
Running
Running
asigalov61
commited on
Commit
•
0304307
1
Parent(s):
9d6cd5b
Update app.py
Browse files
app.py
CHANGED
@@ -4,12 +4,8 @@ import time as reqtime
|
|
4 |
import datetime
|
5 |
from pytz import timezone
|
6 |
|
7 |
-
import torch
|
8 |
-
|
9 |
-
import spaces
|
10 |
import gradio as gr
|
11 |
|
12 |
-
from x_transformer_1_23_2 import *
|
13 |
import random
|
14 |
import tqdm
|
15 |
|
@@ -22,48 +18,11 @@ in_space = os.getenv("SYSTEM") == "spaces"
|
|
22 |
|
23 |
# =================================================================================================
|
24 |
|
25 |
-
@spaces.GPU
|
26 |
def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type, input_strip_notes):
|
27 |
print('=' * 70)
|
28 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
29 |
start_time = reqtime.time()
|
30 |
|
31 |
-
print('Loading model...')
|
32 |
-
|
33 |
-
SEQ_LEN = 8192 # Models seq len
|
34 |
-
PAD_IDX = 707 # Models pad index
|
35 |
-
DEVICE = 'cuda' # 'cuda'
|
36 |
-
|
37 |
-
# instantiate the model
|
38 |
-
|
39 |
-
model = TransformerWrapper(
|
40 |
-
num_tokens = PAD_IDX+1,
|
41 |
-
max_seq_len = SEQ_LEN,
|
42 |
-
attn_layers = Decoder(dim = 2048, depth = 4, heads = 16, attn_flash = True)
|
43 |
-
)
|
44 |
-
|
45 |
-
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
46 |
-
|
47 |
-
model.to(DEVICE)
|
48 |
-
print('=' * 70)
|
49 |
-
|
50 |
-
print('Loading model checkpoint...')
|
51 |
-
|
52 |
-
model.load_state_dict(
|
53 |
-
torch.load('Chords_Progressions_Transformer_Small_2048_Trained_Model_12947_steps_0.9316_loss_0.7386_acc.pth',
|
54 |
-
map_location=DEVICE))
|
55 |
-
print('=' * 70)
|
56 |
-
|
57 |
-
model.eval()
|
58 |
-
|
59 |
-
if DEVICE == 'cpu':
|
60 |
-
dtype = torch.bfloat16
|
61 |
-
else:
|
62 |
-
dtype = torch.float16
|
63 |
-
|
64 |
-
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
65 |
-
|
66 |
-
print('Done!')
|
67 |
print('=' * 70)
|
68 |
|
69 |
fn = os.path.basename(input_midi.name)
|
@@ -363,15 +322,12 @@ if __name__ == "__main__":
|
|
363 |
|
364 |
app = gr.Blocks()
|
365 |
with app:
|
366 |
-
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>
|
367 |
-
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>
|
368 |
gr.Markdown(
|
369 |
-
"![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.
|
370 |
-
"
|
371 |
-
"Check out [
|
372 |
-
"[Open In Colab]"
|
373 |
-
"(https://colab.research.google.com/github/asigalov61/Chords-Progressions-Transformer/blob/main/Chords_Progressions_Transformer.ipynb)"
|
374 |
-
" for faster execution and endless generation"
|
375 |
)
|
376 |
gr.Markdown("## Upload your MIDI or select a sample example MIDI")
|
377 |
|
|
|
4 |
import datetime
|
5 |
from pytz import timezone
|
6 |
|
|
|
|
|
|
|
7 |
import gradio as gr
|
8 |
|
|
|
9 |
import random
|
10 |
import tqdm
|
11 |
|
|
|
18 |
|
19 |
# =================================================================================================
|
20 |
|
|
|
21 |
def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type, input_strip_notes):
|
22 |
print('=' * 70)
|
23 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
24 |
start_time = reqtime.time()
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
print('=' * 70)
|
27 |
|
28 |
fn = os.path.basename(input_midi.name)
|
|
|
322 |
|
323 |
app = gr.Blocks()
|
324 |
with app:
|
325 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Melody</h1>")
|
326 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Add a unique melody to any MIDI</h1>")
|
327 |
gr.Markdown(
|
328 |
+
"![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Melody&style=flat)\n\n"
|
329 |
+
"This is a demo for TMIDIX Python module from tegridy-tools\n\n"
|
330 |
+
"Check out [tegridy-tools](https://github.com/asigalov61/tegridy-tools) on GitHub!\n\n"
|
|
|
|
|
|
|
331 |
)
|
332 |
gr.Markdown("## Upload your MIDI or select a sample example MIDI")
|
333 |
|