Spaces:
Running
on
Zero
Running
on
Zero
tf32
Browse files
app.py
CHANGED
@@ -189,24 +189,22 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
|
|
189 |
init_msgs += [create_msg("visualizer_clear", tokenizer.version),
|
190 |
create_msg("visualizer_append", events)]
|
191 |
yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
t = ct
|
209 |
-
events = []
|
210 |
|
211 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
212 |
mid = tokenizer.detokenize(mid_seq)
|
@@ -307,6 +305,10 @@ if __name__ == "__main__":
|
|
307 |
}
|
308 |
models = {}
|
309 |
if opt.device == "cuda":
|
|
|
|
|
|
|
|
|
310 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
311 |
torch.backends.cuda.enable_flash_sdp(True)
|
312 |
for name, (repo_id, path, config) in models_info.items():
|
@@ -315,7 +317,7 @@ if __name__ == "__main__":
|
|
315 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
316 |
state_dict = ckpt.get("state_dict", ckpt)
|
317 |
model.load_state_dict(state_dict, strict=False)
|
318 |
-
model.to(device="cpu", dtype=torch.
|
319 |
models[name] = model
|
320 |
|
321 |
load_javascript()
|
|
|
189 |
init_msgs += [create_msg("visualizer_clear", tokenizer.version),
|
190 |
create_msg("visualizer_append", events)]
|
191 |
yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
|
192 |
+
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
193 |
+
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
194 |
+
disable_channels=disable_channels, generator=generator)
|
195 |
+
events = []
|
196 |
+
t = time.time() + 1
|
197 |
+
for i, token_seq in enumerate(midi_generator):
|
198 |
+
token_seq = token_seq.tolist()
|
199 |
+
mid_seq.append(token_seq)
|
200 |
+
events.append(tokenizer.tokens2event(token_seq))
|
201 |
+
ct = time.time()
|
202 |
+
if ct - t > 0.5:
|
203 |
+
yield (mid_seq, continuation_state, None, None, seed,
|
204 |
+
send_msgs([create_msg("visualizer_append", events),
|
205 |
+
create_msg("progress", [i + 1, gen_events])]))
|
206 |
+
t = ct
|
207 |
+
events = []
|
|
|
|
|
208 |
|
209 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
210 |
mid = tokenizer.detokenize(mid_seq)
|
|
|
305 |
}
|
306 |
models = {}
|
307 |
if opt.device == "cuda":
|
308 |
+
torch.backends.cudnn.deterministic = True
|
309 |
+
torch.backends.cudnn.benchmark = False
|
310 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
311 |
+
torch.backends.cudnn.allow_tf32 = True
|
312 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
313 |
torch.backends.cuda.enable_flash_sdp(True)
|
314 |
for name, (repo_id, path, config) in models_info.items():
|
|
|
317 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
318 |
state_dict = ckpt.get("state_dict", ckpt)
|
319 |
model.load_state_dict(state_dict, strict=False)
|
320 |
+
model.to(device="cpu", dtype=torch.float32)
|
321 |
models[name] = model
|
322 |
|
323 |
load_javascript()
|