asigalov61
committed on
Commit
•
efd6b77
1
Parent(s):
8ac01e3
Update app.py
Browse files
app.py
CHANGED
@@ -25,12 +25,13 @@ in_space = os.getenv("SYSTEM") == "spaces"
|
|
25 |
|
26 |
# =================================================================================================
|
27 |
|
|
|
28 |
def generate_drums(notes_times,
|
29 |
max_drums_limit = 8,
|
30 |
num_memory_tokens = 4096,
|
31 |
temperature=0.9):
|
32 |
|
33 |
-
x = torch.tensor([notes_times] * 1, dtype=torch.long, device=
|
34 |
|
35 |
o = 128
|
36 |
|
@@ -62,6 +63,44 @@ def GenerateDrums(input_midi, input_num_tokens):
|
|
62 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
63 |
start_time = reqtime.time()
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
fn = os.path.basename(input_midi.name)
|
66 |
fn1 = fn.split('.')[0]
|
67 |
|
@@ -239,45 +278,7 @@ if __name__ == "__main__":
|
|
239 |
opt = parser.parse_args()
|
240 |
|
241 |
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
|
242 |
-
|
243 |
-
print('Loading model...')
|
244 |
-
|
245 |
-
SEQ_LEN = 8192 # Models seq len
|
246 |
-
PAD_IDX = 393 # Models pad index
|
247 |
-
DEVICE = 'cuda' # 'cuda'
|
248 |
-
|
249 |
-
# instantiate the model
|
250 |
-
|
251 |
-
model = TransformerWrapper(
|
252 |
-
num_tokens = PAD_IDX+1,
|
253 |
-
max_seq_len = SEQ_LEN,
|
254 |
-
attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
|
255 |
-
)
|
256 |
-
|
257 |
-
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
258 |
-
|
259 |
-
model.to(DEVICE)
|
260 |
-
print('=' * 70)
|
261 |
-
|
262 |
-
print('Loading model checkpoint...')
|
263 |
-
|
264 |
-
model.load_state_dict(
|
265 |
-
torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
|
266 |
-
map_location=DEVICE))
|
267 |
-
print('=' * 70)
|
268 |
-
|
269 |
-
model.eval()
|
270 |
-
|
271 |
-
if DEVICE == 'cpu':
|
272 |
-
dtype = torch.bfloat16
|
273 |
-
else:
|
274 |
-
dtype = torch.float16
|
275 |
-
|
276 |
-
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
277 |
-
|
278 |
-
print('Done!')
|
279 |
-
print('=' * 70)
|
280 |
-
|
281 |
app = gr.Blocks()
|
282 |
with app:
|
283 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Ultimate Drums Transformer</h1>")
|
|
|
25 |
|
26 |
# =================================================================================================
|
27 |
|
28 |
+
@spaces.GPU
|
29 |
def generate_drums(notes_times,
|
30 |
max_drums_limit = 8,
|
31 |
num_memory_tokens = 4096,
|
32 |
temperature=0.9):
|
33 |
|
34 |
+
x = torch.tensor([notes_times] * 1, dtype=torch.long, device='cuda')
|
35 |
|
36 |
o = 128
|
37 |
|
|
|
63 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
64 |
start_time = reqtime.time()
|
65 |
|
66 |
+
print('Loading model...')
|
67 |
+
|
68 |
+
SEQ_LEN = 8192 # Models seq len
|
69 |
+
PAD_IDX = 393 # Models pad index
|
70 |
+
DEVICE = 'cuda' # 'cuda'
|
71 |
+
|
72 |
+
# instantiate the model
|
73 |
+
|
74 |
+
model = TransformerWrapper(
|
75 |
+
num_tokens = PAD_IDX+1,
|
76 |
+
max_seq_len = SEQ_LEN,
|
77 |
+
attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
|
78 |
+
)
|
79 |
+
|
80 |
+
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
81 |
+
|
82 |
+
model.to(DEVICE)
|
83 |
+
print('=' * 70)
|
84 |
+
|
85 |
+
print('Loading model checkpoint...')
|
86 |
+
|
87 |
+
model.load_state_dict(
|
88 |
+
torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
|
89 |
+
map_location=DEVICE))
|
90 |
+
print('=' * 70)
|
91 |
+
|
92 |
+
model.eval()
|
93 |
+
|
94 |
+
if DEVICE == 'cpu':
|
95 |
+
dtype = torch.bfloat16
|
96 |
+
else:
|
97 |
+
dtype = torch.float16
|
98 |
+
|
99 |
+
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
100 |
+
|
101 |
+
print('Done!')
|
102 |
+
print('=' * 70)
|
103 |
+
|
104 |
fn = os.path.basename(input_midi.name)
|
105 |
fn1 = fn.split('.')[0]
|
106 |
|
|
|
278 |
opt = parser.parse_args()
|
279 |
|
280 |
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
|
281 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
app = gr.Blocks()
|
283 |
with app:
|
284 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Ultimate Drums Transformer</h1>")
|