skytnt commited on
Commit
5c45beb
1 Parent(s): 38cc15f
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -189,24 +189,22 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
189
  init_msgs += [create_msg("visualizer_clear", tokenizer.version),
190
  create_msg("visualizer_append", events)]
191
  yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
192
- ctx = torch.amp.autocast(device_type=opt.device, dtype=torch.bfloat16, enabled=opt.device != "cpu")
193
- with ctx:
194
- midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
195
- disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
196
- disable_channels=disable_channels, generator=generator)
197
- events = []
198
- t = time.time() + 1
199
- for i, token_seq in enumerate(midi_generator):
200
- token_seq = token_seq.tolist()
201
- mid_seq.append(token_seq)
202
- events.append(tokenizer.tokens2event(token_seq))
203
- ct = time.time()
204
- if ct - t > 0.5:
205
- yield (mid_seq, continuation_state, None, None, seed,
206
- send_msgs([create_msg("visualizer_append", events),
207
- create_msg("progress", [i + 1, gen_events])]))
208
- t = ct
209
- events = []
210
 
211
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
212
  mid = tokenizer.detokenize(mid_seq)
@@ -307,6 +305,10 @@ if __name__ == "__main__":
307
  }
308
  models = {}
309
  if opt.device == "cuda":
 
 
 
 
310
  torch.backends.cuda.enable_mem_efficient_sdp(True)
311
  torch.backends.cuda.enable_flash_sdp(True)
312
  for name, (repo_id, path, config) in models_info.items():
@@ -315,7 +317,7 @@ if __name__ == "__main__":
315
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
316
  state_dict = ckpt.get("state_dict", ckpt)
317
  model.load_state_dict(state_dict, strict=False)
318
- model.to(device="cpu", dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32)
319
  models[name] = model
320
 
321
  load_javascript()
 
189
  init_msgs += [create_msg("visualizer_clear", tokenizer.version),
190
  create_msg("visualizer_append", events)]
191
  yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
192
+ midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
193
+ disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
194
+ disable_channels=disable_channels, generator=generator)
195
+ events = []
196
+ t = time.time() + 1
197
+ for i, token_seq in enumerate(midi_generator):
198
+ token_seq = token_seq.tolist()
199
+ mid_seq.append(token_seq)
200
+ events.append(tokenizer.tokens2event(token_seq))
201
+ ct = time.time()
202
+ if ct - t > 0.5:
203
+ yield (mid_seq, continuation_state, None, None, seed,
204
+ send_msgs([create_msg("visualizer_append", events),
205
+ create_msg("progress", [i + 1, gen_events])]))
206
+ t = ct
207
+ events = []
 
 
208
 
209
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
210
  mid = tokenizer.detokenize(mid_seq)
 
305
  }
306
  models = {}
307
  if opt.device == "cuda":
308
+ torch.backends.cudnn.deterministic = True
309
+ torch.backends.cudnn.benchmark = False
310
+ torch.backends.cuda.matmul.allow_tf32 = True
311
+ torch.backends.cudnn.allow_tf32 = True
312
  torch.backends.cuda.enable_mem_efficient_sdp(True)
313
  torch.backends.cuda.enable_flash_sdp(True)
314
  for name, (repo_id, path, config) in models_info.items():
 
317
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
318
  state_dict = ckpt.get("state_dict", ckpt)
319
  model.load_state_dict(state_dict, strict=False)
320
+ model.to(device="cpu", dtype=torch.float32)
321
  models[name] = model
322
 
323
  load_javascript()