import argparse
import glob
import json
import os.path
import time
import datetime
from pytz import timezone

import torch

import gradio as gr

from x_transformer_1_23_2 import *
import random
import tqdm

import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

in_space = os.getenv("SYSTEM") == "spaces"

# =================================================================================================

def generate_drums(notes_times,
                   max_drums_limit=8,
                   num_memory_tokens=4096,
                   temperature=0.9):

    # Seed the model with the supplied notes/times tokens
    x = torch.tensor([notes_times], dtype=torch.long, device='cpu')

    o = 128
    ncount = 0

    while o > 127 and ncount < max_drums_limit:

        with ctx:
            # Condition on at most the last num_memory_tokens tokens
            out = model.generate(x[:, -num_memory_tokens:],
                                 1,
                                 temperature=temperature,
                                 return_prime=False,
                                 verbose=False)

        o = out.tolist()[0][0]

        if 256 <= o < 384:
            ncount += 1

        if o > 127:
            x = torch.cat((x, out), 1)

    # Return only the newly generated tokens
    return x.tolist()[0][len(notes_times):]

# =================================================================================================

@torch.no_grad()
def GenerateDrums(input_midi, input_num_tokens):

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = time.time()

    fn = os.path.basename(input_midi)
    fn1 = fn.split('.')[0]

    print('-' * 70)
    print('Input file name:', fn)
    print('Req num tok:', input_num_tokens)
    print('-' * 70)

    #===============================================================================
    # Raw single-track ms score

    raw_score = TMIDIX.midi2single_track_ms_score(input_midi)

    #===============================================================================
    # Enhanced score notes

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    #===============================================================================
    # PRE-PROCESSING
    #===============================================================================
    # Augmented enhanced score notes (existing drums channel is removed first)

    escore_notes = [e for e in escore_notes if e[3] != 9]
    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes)

    patches = TMIDIX.patch_list_from_enhanced_score_notes(escore_notes)

    dscore = TMIDIX.delta_score_notes(escore_notes, compress_timings=True, even_timings=True)

    cscore = TMIDIX.chordify_score([d[1:] for d in dscore])
    cscore_melody = [c[0] for c in cscore]

    comp_times = [0] + [t[1] for t in dscore if t[1] != 0]

    #===============================================================================

    # Seed sequence for generation (assumption: the original's undefined
    # 'start_tokens' was the source composition's delta-time token sequence)
    start_tokens = comp_times

    print('Selected seed sequence:')
    print(start_tokens[:16])
    print('-' * 70)

    output_signature = 'Ultimate Drums Transformer'
    output_file_name = 'Ultimate-Drums-Transformer-Music-Composition'
    track_name = 'Project Los Angeles'
    list_of_MIDI_patches = [0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0]
    number_of_ticks_per_quarter = 500
    text_encoding = 'ISO-8859-1'

    output_header = [number_of_ticks_per_quarter,
                     [['track_name', 0, bytes(output_signature, text_encoding)]]]

    # One patch_change event per MIDI channel, followed by the track name
    patch_list = [['patch_change', 0, c, list_of_MIDI_patches[c]] for c in range(16)]
    patch_list.append(['track_name', 0, bytes(track_name, text_encoding)])
    output = output_header + [patch_list]

    outy = start_tokens

    ctime = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    num_tok = max(1, min(512, int(input_num_tokens)))

    for i in range(num_tok):

        inp = torch.LongTensor([outy]).cpu()

        with ctx:
            out = model.generate(inp,
                                 1,
                                 temperature=0.9,
                                 return_prime=False,
                                 verbose=False)

        out0 = out[0].tolist()
        outy.extend(out0)

        ss1 = out0[0]

        # Decode the generated token into a MIDI event
        if 0 < ss1 < 256:
            ctime += ss1 * 8

        if 256 <= ss1 < 1280:
            dur = ((ss1 - 256) // 8) * 32
            vel = (((ss1 - 256) % 8) + 1) * 15

        if 1280 <= ss1 < 2816:
            channel = (ss1 - 1280) // 128
            pitch = (ss1 - 1280) % 128

            event = ['note', ctime, dur, channel, pitch, vel]
            output[-1].append(event)

    midi_data = TMIDIX.score2midi(output, text_encoding)

    with open(output_file_name + '.mid', 'wb') as f:
        f.write(midi_data)

    # Render audio for the Gradio player (assumption: the original's undefined
    # 'synthesis' helper was the imported midi_to_colab_audio module)
    audio = midi_to_colab_audio.midi_to_colab_audio(output_file_name + '.mid',
                                                    soundfont_path=soundfont,
                                                    sample_rate=16000,
                                                    volume_scale=10,
                                                    output_for_gradio=True)

    # Simple scatter plot of the generated notes (time in seconds vs. pitch)
    plt.close('all')
    fig = plt.figure(figsize=(14, 5))
    notes = [e for e in output[-1] if e[0] == 'note']
    plt.scatter([e[1] / 1000 for e in notes], [e[4] for e in notes], s=10)
    plt.xlabel('Time, sec')
    plt.ylabel('MIDI pitch')
    plt.title(output_file_name)

    print('Sample INTs', outy[:16])
    print('-' * 70)
    print('Last generated MIDI event', output[2][-1])
    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')

    output_midi_title = str(fn1)
    output_midi_summary = str(output[-1][:3])

    return (output_midi_title,
            output_midi_summary,
            (16000, audio),
            fig,
            output_file_name + '.mid')

# =================================================================================================

if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    opt = parser.parse_args()

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('Loading model...')

    SEQ_LEN = 8192  # Model's seq len
    PAD_IDX = 385   # Model's pad index

    # instantiate the model
    model = TransformerWrapper(
        num_tokens=PAD_IDX + 1,
        max_seq_len=SEQ_LEN,
        attn_layers=Decoder(dim=1024, depth=4, heads=8, attn_flash=True)
    )

    model = AutoregressiveWrapper(model, ignore_index=PAD_IDX)

    model.cpu()

    print('=' * 70)
    print('Loading model checkpoint...')

    model.load_state_dict(
        torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_8134_steps_0.3745_loss_0.8736_acc.pth',
                   map_location='cpu'))

    print('=' * 70)

    model.eval()

    ctx = torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16)

    print('Done!')
    print('=' * 70)

    app = gr.Blocks()

    with app:

        gr.Markdown("# Ultimate Drums Transformer")
        gr.Markdown("## Generate a unique drums track for any MIDI")

        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Ultimate-Drums-Transformer&style=flat)\n\n"
            "A SOTA pure drums transformer capable of generating a drums track for any source composition\n\n"
            "Check out [Ultimate Drums Transformer](https://github.com/asigalov61/Ultimate-Drums-Transformer) on GitHub!\n\n"
            "[Open In Colab]"
            "(https://colab.research.google.com/github/asigalov61/Ultimate-Drums-Transformer/blob/main/Ultimate_Drums_Transformer.ipynb)"
            " for faster execution and endless generation"
        )

        gr.Markdown("## Upload your MIDI")

        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"], type="filepath")
        input_num_tokens = gr.Slider(16, 512, value=256, label="Number of Tokens",
                                     info="Number of tokens to generate")

        run_btn = gr.Button("Generate", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(GenerateDrums,
                                  [input_midi, input_num_tokens],
                                  [output_midi_title, output_midi_summary,
                                   output_audio, output_plot, output_midi])

    app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)