asigalov61
commited on
Commit
•
3fea55e
1
Parent(s):
0731944
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,10 @@ import glob
|
|
3 |
import json
|
4 |
import os.path
|
5 |
|
|
|
|
|
|
|
|
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
8 |
|
@@ -16,13 +20,8 @@ import TMIDIX
|
|
16 |
|
17 |
import matplotlib.pyplot as plt
|
18 |
|
19 |
-
import time
|
20 |
-
import datetime
|
21 |
-
from pytz import timezone
|
22 |
-
|
23 |
in_space = os.getenv("SYSTEM") == "spaces"
|
24 |
|
25 |
-
|
26 |
# =================================================================================================
|
27 |
|
28 |
@torch.no_grad()
|
@@ -85,13 +84,15 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
85 |
yield output, None, None, [create_msg("visualizer_clear", None)]
|
86 |
|
87 |
outy = start_tokens
|
|
|
88 |
ctime = 0
|
89 |
dur = 0
|
90 |
-
vel =
|
91 |
pitch = 0
|
92 |
channel = 0
|
93 |
|
94 |
for i in range(max(1, min(512, num_tok))):
|
|
|
95 |
inp = torch.LongTensor([outy]).cpu()
|
96 |
|
97 |
out = model.module.generate(inp,
|
@@ -102,7 +103,8 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
102 |
|
103 |
out0 = out[0].tolist()
|
104 |
outy.extend(out0)
|
105 |
-
|
|
|
106 |
|
107 |
if 0 < ss1 < 256:
|
108 |
ctime += ss1 * 8
|
|
|
3 |
import json
|
4 |
import os.path
|
5 |
|
6 |
+
import time
|
7 |
+
import datetime
|
8 |
+
from pytz import timezone
|
9 |
+
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
12 |
|
|
|
20 |
|
21 |
import matplotlib.pyplot as plt
|
22 |
|
|
|
|
|
|
|
|
|
23 |
in_space = os.getenv("SYSTEM") == "spaces"
|
24 |
|
|
|
25 |
# =================================================================================================
|
26 |
|
27 |
@torch.no_grad()
|
|
|
84 |
yield output, None, None, [create_msg("visualizer_clear", None)]
|
85 |
|
86 |
outy = start_tokens
|
87 |
+
|
88 |
ctime = 0
|
89 |
dur = 0
|
90 |
+
vel = 90
|
91 |
pitch = 0
|
92 |
channel = 0
|
93 |
|
94 |
for i in range(max(1, min(512, num_tok))):
|
95 |
+
|
96 |
inp = torch.LongTensor([outy]).cpu()
|
97 |
|
98 |
out = model.module.generate(inp,
|
|
|
103 |
|
104 |
out0 = out[0].tolist()
|
105 |
outy.extend(out0)
|
106 |
+
|
107 |
+
ss1 = out0[0]
|
108 |
|
109 |
if 0 < ss1 < 256:
|
110 |
ctime += ss1 * 8
|