asigalov61 committed on
Commit
efd6b77
1 Parent(s): 8ac01e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -40
app.py CHANGED
@@ -25,12 +25,13 @@ in_space = os.getenv("SYSTEM") == "spaces"
25
 
26
  # =================================================================================================
27
 
 
28
  def generate_drums(notes_times,
29
  max_drums_limit = 8,
30
  num_memory_tokens = 4096,
31
  temperature=0.9):
32
 
33
- x = torch.tensor([notes_times] * 1, dtype=torch.long, device=DEVICE)
34
 
35
  o = 128
36
 
@@ -62,6 +63,44 @@ def GenerateDrums(input_midi, input_num_tokens):
62
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
63
  start_time = reqtime.time()
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  fn = os.path.basename(input_midi.name)
66
  fn1 = fn.split('.')[0]
67
 
@@ -239,45 +278,7 @@ if __name__ == "__main__":
239
  opt = parser.parse_args()
240
 
241
  soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
242
-
243
- print('Loading model...')
244
-
245
- SEQ_LEN = 8192 # Models seq len
246
- PAD_IDX = 393 # Models pad index
247
- DEVICE = 'cuda' # 'cuda'
248
-
249
- # instantiate the model
250
-
251
- model = TransformerWrapper(
252
- num_tokens = PAD_IDX+1,
253
- max_seq_len = SEQ_LEN,
254
- attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
255
- )
256
-
257
- model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
258
-
259
- model.to(DEVICE)
260
- print('=' * 70)
261
-
262
- print('Loading model checkpoint...')
263
-
264
- model.load_state_dict(
265
- torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
266
- map_location=DEVICE))
267
- print('=' * 70)
268
-
269
- model.eval()
270
-
271
- if DEVICE == 'cpu':
272
- dtype = torch.bfloat16
273
- else:
274
- dtype = torch.float16
275
-
276
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
277
-
278
- print('Done!')
279
- print('=' * 70)
280
-
281
  app = gr.Blocks()
282
  with app:
283
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Ultimate Drums Transformer</h1>")
 
25
 
26
  # =================================================================================================
27
 
28
+ @spaces.GPU
29
  def generate_drums(notes_times,
30
  max_drums_limit = 8,
31
  num_memory_tokens = 4096,
32
  temperature=0.9):
33
 
34
+ x = torch.tensor([notes_times] * 1, dtype=torch.long, device='cuda')
35
 
36
  o = 128
37
 
 
63
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
64
  start_time = reqtime.time()
65
 
66
+ print('Loading model...')
67
+
68
+ SEQ_LEN = 8192 # Models seq len
69
+ PAD_IDX = 393 # Models pad index
70
+ DEVICE = 'cuda' # 'cuda'
71
+
72
+ # instantiate the model
73
+
74
+ model = TransformerWrapper(
75
+ num_tokens = PAD_IDX+1,
76
+ max_seq_len = SEQ_LEN,
77
+ attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
78
+ )
79
+
80
+ model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
81
+
82
+ model.to(DEVICE)
83
+ print('=' * 70)
84
+
85
+ print('Loading model checkpoint...')
86
+
87
+ model.load_state_dict(
88
+ torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
89
+ map_location=DEVICE))
90
+ print('=' * 70)
91
+
92
+ model.eval()
93
+
94
+ if DEVICE == 'cpu':
95
+ dtype = torch.bfloat16
96
+ else:
97
+ dtype = torch.float16
98
+
99
+ ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
100
+
101
+ print('Done!')
102
+ print('=' * 70)
103
+
104
  fn = os.path.basename(input_midi.name)
105
  fn1 = fn.split('.')[0]
106
 
 
278
  opt = parser.parse_args()
279
 
280
  soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
281
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  app = gr.Blocks()
283
  with app:
284
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Ultimate Drums Transformer</h1>")