asigalov61 commited on
Commit
29991e1
1 Parent(s): 0e630bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -14
app.py CHANGED
@@ -24,28 +24,82 @@ in_space = os.getenv("SYSTEM") == "spaces"
24
 
25
  # =================================================================================================
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  @torch.no_grad()
28
- def GenerateMIDI(num_tok, idrums, iinstr):
29
  print('=' * 70)
30
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
31
  start_time = time.time()
32
 
 
 
 
33
  print('-' * 70)
34
- print('Req num tok:', num_tok)
35
- print('Req instr:', iinstr)
36
- print('Drums:', idrums)
37
  print('-' * 70)
38
 
39
- if idrums:
40
- drums = 3074
41
- else:
42
- drums = 3073
43
-
44
- instruments_list = ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", 'Drums',
45
- "Choir", "Organ"]
46
- first_note_instrument_number = instruments_list.index(iinstr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- start_tokens = [3087, drums, 3075 + first_note_instrument_number]
49
 
50
  print('Selected Improv sequence:')
51
  print(start_tokens)
@@ -219,7 +273,7 @@ if __name__ == "__main__":
219
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
220
 
221
 
222
- run_event = run_btn.click(GenerateMIDI, [input_midi, input_num_tokens],
223
  [output_midi_title, output_midi_summary, output_audio, output_plot, output_midi])
224
 
225
  app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
24
 
25
  # =================================================================================================
26
 
27
+ def generate_drums(notes_times,
28
+ max_drums_limit = 8,
29
+ num_memory_tokens = 4096,
30
+ temperature=0.9):
31
+
32
+ x = torch.tensor([notes_times] * 1, dtype=torch.long, device='cpu')
33
+
34
+ o = 128
35
+
36
+ ncount = 0
37
+
38
+ while o > 127 and ncount < max_drums_limit:
39
+ with ctx:
40
+ out = model.generate(x[-num_memory_tokens:],
41
+ 1,
42
+ temperature=temperature,
43
+ return_prime=False,
44
+ verbose=False)
45
+
46
+ o = out.tolist()[0][0]
47
+
48
+ if 256 <= o < 384:
49
+ ncount += 1
50
+
51
+ if o > 127:
52
+ x = torch.cat((x, out), 1)
53
+
54
+ return x.tolist()[0][len(notes_times):]
55
+
56
+ # =================================================================================================
57
+
58
  @torch.no_grad()
59
+ def GenerateMIDI(input_midi, input_num_tokens):
60
  print('=' * 70)
61
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
62
  start_time = time.time()
63
 
64
+ fn = os.path.basename(input_midi)
65
+ fn1 = fn.split('.')[0]
66
+
67
  print('-' * 70)
68
+ print('Input file name:', fn)
69
+ print('Req num tok:', input_num_tokens)
 
70
  print('-' * 70)
71
 
72
+ #===============================================================================
73
+ # Raw single-track ms score
74
+
75
+ raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
76
+
77
+ #===============================================================================
78
+ # Enhanced score notes
79
+
80
+ escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
81
+
82
+ #=======================================================
83
+ # PRE-PROCESSING
84
+
85
+ #===============================================================================
86
+ # Augmented enhanced score notes
87
+
88
+ escore_notes = [e for e in escore_notes if e[3] != 9]
89
+
90
+ escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes)
91
+
92
+ patches = TMIDIX.patch_list_from_enhanced_score_notes(escore_notes)
93
+
94
+ dscore = TMIDIX.delta_score_notes(escore_notes, compress_timings=True, even_timings=True)
95
+
96
+ cscore = TMIDIX.chordify_score([d[1:] for d in dscore])
97
+
98
+ cscore_melody = [c[0] for c in cscore]
99
+
100
+ comp_times = [0] + [t[1] for t in dscore if t[1] != 0]
101
 
102
+ #===============================================================================
103
 
104
  print('Selected Improv sequence:')
105
  print(start_tokens)
 
273
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
274
 
275
 
276
+ run_event = run_btn.click(GenerateDrums, [input_midi, input_num_tokens],
277
  [output_midi_title, output_midi_summary, output_audio, output_plot, output_midi])
278
 
279
  app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)