Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -8,14 +8,14 @@ import datetime
|
|
8 |
from pytz import timezone
|
9 |
|
10 |
import torch
|
11 |
-
import torch.nn.functional as F
|
12 |
|
13 |
import gradio as gr
|
14 |
|
15 |
-
from
|
|
|
16 |
import tqdm
|
17 |
|
18 |
-
|
19 |
import TMIDIX
|
20 |
|
21 |
import matplotlib.pyplot as plt
|
@@ -154,23 +154,22 @@ if __name__ == "__main__":
|
|
154 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
155 |
opt = parser.parse_args()
|
156 |
|
157 |
-
|
158 |
|
159 |
print('Loading model...')
|
160 |
|
161 |
-
SEQ_LEN =
|
|
|
162 |
|
163 |
# instantiate the model
|
164 |
|
165 |
model = TransformerWrapper(
|
166 |
-
num_tokens=
|
167 |
-
max_seq_len=SEQ_LEN,
|
168 |
-
attn_layers=Decoder(dim=1024, depth=
|
169 |
-
|
170 |
-
|
171 |
-
model = AutoregressiveWrapper(model)
|
172 |
-
|
173 |
-
model = torch.nn.DataParallel(model)
|
174 |
|
175 |
model.cpu()
|
176 |
print('=' * 70)
|
@@ -190,14 +189,14 @@ if __name__ == "__main__":
|
|
190 |
load_javascript()
|
191 |
app = gr.Blocks()
|
192 |
with app:
|
193 |
-
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>
|
|
|
194 |
gr.Markdown(
|
195 |
-
"![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.
|
196 |
-
"
|
197 |
-
"Check out [
|
198 |
-
"Special thanks go out to [SkyTNT](https://github.com/SkyTNT/midi-model) for fantastic FluidSynth Synthesizer and MIDI Visualizer code\n\n"
|
199 |
"[Open In Colab]"
|
200 |
-
"(https://colab.research.google.com/github/asigalov61/
|
201 |
" for faster execution and endless generation"
|
202 |
)
|
203 |
gr.Markdown("## Upload your MIDI")
|
@@ -206,7 +205,8 @@ if __name__ == "__main__":
|
|
206 |
input_num_tokens = gr.Slider(16, 512, value=256, label="Number of Tokens", info="Number of tokens to generate")
|
207 |
|
208 |
run_btn = gr.Button("generate", variant="primary")
|
209 |
-
|
|
|
210 |
|
211 |
output_midi_title = gr.Textbox(label="Output MIDI title")
|
212 |
output_midi_summary = gr.Textbox(label="Output MIDI summary")
|
|
|
8 |
from pytz import timezone
|
9 |
|
10 |
import torch
|
|
|
11 |
|
12 |
import gradio as gr
|
13 |
|
14 |
+
from x_transformer_1_23_2 import *
|
15 |
+
import random
|
16 |
import tqdm
|
17 |
|
18 |
+
inport midi_to_colab_audio
|
19 |
import TMIDIX
|
20 |
|
21 |
import matplotlib.pyplot as plt
|
|
|
154 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
155 |
opt = parser.parse_args()
|
156 |
|
157 |
+
soundfont = ["SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"]
|
158 |
|
159 |
print('Loading model...')
|
160 |
|
161 |
+
SEQ_LEN = 8192 # Models seq len
|
162 |
+
PAD_IDX = 385 # Models pad index
|
163 |
|
164 |
# instantiate the model
|
165 |
|
166 |
model = TransformerWrapper(
|
167 |
+
num_tokens = PAD_IDX+1,
|
168 |
+
max_seq_len = SEQ_LEN,
|
169 |
+
attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
|
170 |
+
)
|
171 |
+
|
172 |
+
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
|
|
|
|
173 |
|
174 |
model.cpu()
|
175 |
print('=' * 70)
|
|
|
189 |
load_javascript()
|
190 |
app = gr.Blocks()
|
191 |
with app:
|
192 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Ultimate Drums Transformer</h1>")
|
193 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate unique drums track for any MIDI</h1>")
|
194 |
gr.Markdown(
|
195 |
+
"![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Ultimate-Drums-Transformer&style=flat)\n\n"
|
196 |
+
"SOTA pure drums transformer which is capable of drums track generation for any source composition\n\n"
|
197 |
+
"Check out [Ultimate Drums Transformer](https://github.com/asigalov61/Ultimate-Drums-Transformer) on GitHub!\n\n"
|
|
|
198 |
"[Open In Colab]"
|
199 |
+
"(https://colab.research.google.com/github/asigalov61/Ultimate-Drums-Transformer/blob/main/Ultimate_Drums_Transformer.ipynb)"
|
200 |
" for faster execution and endless generation"
|
201 |
)
|
202 |
gr.Markdown("## Upload your MIDI")
|
|
|
205 |
input_num_tokens = gr.Slider(16, 512, value=256, label="Number of Tokens", info="Number of tokens to generate")
|
206 |
|
207 |
run_btn = gr.Button("generate", variant="primary")
|
208 |
+
|
209 |
+
gr.Markdown("## Generation results")
|
210 |
|
211 |
output_midi_title = gr.Textbox(label="Output MIDI title")
|
212 |
output_midi_summary = gr.Textbox(label="Output MIDI summary")
|