asigalov61 commited on
Commit
0e630bf
·
verified ·
1 Parent(s): 904cd6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -95,11 +95,13 @@ def GenerateMIDI(num_tok, idrums, iinstr):
95
 
96
  inp = torch.LongTensor([outy]).cpu()
97
 
98
- out = model.module.generate(inp,
99
- 1,
100
- temperature=0.9,
101
- return_prime=False,
102
- verbose=False)
 
 
103
 
104
  out0 = out[0].tolist()
105
  outy.extend(out0)
@@ -177,12 +179,14 @@ if __name__ == "__main__":
177
  print('Loading model checkpoint...')
178
 
179
  model.load_state_dict(
180
- torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
181
  map_location='cpu'))
182
  print('=' * 70)
183
 
184
  model.eval()
185
 
 
 
186
  print('Done!')
187
  print('=' * 70)
188
 
 
95
 
96
  inp = torch.LongTensor([outy]).cpu()
97
 
98
+ with ctx:
99
+
100
+ out = model.module.generate(inp,
101
+ 1,
102
+ temperature=0.9,
103
+ return_prime=False,
104
+ verbose=False)
105
 
106
  out0 = out[0].tolist()
107
  outy.extend(out0)
 
179
  print('Loading model checkpoint...')
180
 
181
  model.load_state_dict(
182
+ torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_8134_steps_0.3745_loss_0.8736_acc.pth',
183
  map_location='cpu'))
184
  print('=' * 70)
185
 
186
  model.eval()
187
 
188
+ ctx = torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16)
189
+
190
  print('Done!')
191
  print('=' * 70)
192