asigalov61
commited on
Commit
•
29991e1
1
Parent(s):
0e630bf
Update app.py
Browse files
app.py
CHANGED
@@ -24,28 +24,82 @@ in_space = os.getenv("SYSTEM") == "spaces"
|
|
24 |
|
25 |
# =================================================================================================
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
@torch.no_grad()
|
28 |
-
def GenerateMIDI(
|
29 |
print('=' * 70)
|
30 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
31 |
start_time = time.time()
|
32 |
|
|
|
|
|
|
|
33 |
print('-' * 70)
|
34 |
-
print('
|
35 |
-
print('Req
|
36 |
-
print('Drums:', idrums)
|
37 |
print('-' * 70)
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
|
50 |
print('Selected Improv sequence:')
|
51 |
print(start_tokens)
|
@@ -219,7 +273,7 @@ if __name__ == "__main__":
|
|
219 |
output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
|
220 |
|
221 |
|
222 |
-
run_event = run_btn.click(
|
223 |
[output_midi_title, output_midi_summary, output_audio, output_plot, output_midi])
|
224 |
|
225 |
app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
24 |
|
25 |
# =================================================================================================
|
26 |
|
27 |
+
def generate_drums(notes_times,
|
28 |
+
max_drums_limit = 8,
|
29 |
+
num_memory_tokens = 4096,
|
30 |
+
temperature=0.9):
|
31 |
+
|
32 |
+
x = torch.tensor([notes_times] * 1, dtype=torch.long, device='cpu')
|
33 |
+
|
34 |
+
o = 128
|
35 |
+
|
36 |
+
ncount = 0
|
37 |
+
|
38 |
+
while o > 127 and ncount < max_drums_limit:
|
39 |
+
with ctx:
|
40 |
+
out = model.generate(x[-num_memory_tokens:],
|
41 |
+
1,
|
42 |
+
temperature=temperature,
|
43 |
+
return_prime=False,
|
44 |
+
verbose=False)
|
45 |
+
|
46 |
+
o = out.tolist()[0][0]
|
47 |
+
|
48 |
+
if 256 <= o < 384:
|
49 |
+
ncount += 1
|
50 |
+
|
51 |
+
if o > 127:
|
52 |
+
x = torch.cat((x, out), 1)
|
53 |
+
|
54 |
+
return x.tolist()[0][len(notes_times):]
|
55 |
+
|
56 |
+
# =================================================================================================
|
57 |
+
|
58 |
@torch.no_grad()
|
59 |
+
def GenerateMIDI(input_midi, input_num_tokens):
|
60 |
print('=' * 70)
|
61 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
62 |
start_time = time.time()
|
63 |
|
64 |
+
fn = os.path.basename(input_midi)
|
65 |
+
fn1 = fn.split('.')[0]
|
66 |
+
|
67 |
print('-' * 70)
|
68 |
+
print('Input file name:', fn)
|
69 |
+
print('Req num tok:', input_num_tokens)
|
|
|
70 |
print('-' * 70)
|
71 |
|
72 |
+
#===============================================================================
|
73 |
+
# Raw single-track ms score
|
74 |
+
|
75 |
+
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
76 |
+
|
77 |
+
#===============================================================================
|
78 |
+
# Enhanced score notes
|
79 |
+
|
80 |
+
escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
81 |
+
|
82 |
+
#=======================================================
|
83 |
+
# PRE-PROCESSING
|
84 |
+
|
85 |
+
#===============================================================================
|
86 |
+
# Augmented enhanced score notes
|
87 |
+
|
88 |
+
escore_notes = [e for e in escore_notes if e[3] != 9]
|
89 |
+
|
90 |
+
escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes)
|
91 |
+
|
92 |
+
patches = TMIDIX.patch_list_from_enhanced_score_notes(escore_notes)
|
93 |
+
|
94 |
+
dscore = TMIDIX.delta_score_notes(escore_notes, compress_timings=True, even_timings=True)
|
95 |
+
|
96 |
+
cscore = TMIDIX.chordify_score([d[1:] for d in dscore])
|
97 |
+
|
98 |
+
cscore_melody = [c[0] for c in cscore]
|
99 |
+
|
100 |
+
comp_times = [0] + [t[1] for t in dscore if t[1] != 0]
|
101 |
|
102 |
+
#===============================================================================
|
103 |
|
104 |
print('Selected Improv sequence:')
|
105 |
print(start_tokens)
|
|
|
273 |
output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
|
274 |
|
275 |
|
276 |
+
run_event = run_btn.click(GenerateDrums, [input_midi, input_num_tokens],
|
277 |
[output_midi_title, output_midi_summary, output_audio, output_plot, output_midi])
|
278 |
|
279 |
app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
|