Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,083 Bytes
133ccd4 96007f4 8453f63 133ccd4 de46ee3 133ccd4 df03c6b 8453f63 df03c6b 23c274d df03c6b de46ee3 8453f63 de46ee3 23c274d de46ee3 23c274d de46ee3 23c274d de46ee3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import argparse
import glob
import os.path
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
import onnxruntime as rt
import tqdm
import json
from midi_synthesizer import synthesis
import TMIDIX
# True only when the SYSTEM environment variable is exactly "spaces"
# (presumably set by the hosting Spaces runtime -- confirm).
in_space = os.environ.get("SYSTEM") == "spaces"

# Execution providers for the onnxruntime session: CPU only.
providers = ["CPUExecutionProvider"]
#=================================================================================================
def generate(
    start_tokens,
    seq_len,
    max_seq_len = 2048,
    temperature = 0.9,
    verbose=True,
    return_prime=False,
):
    """Sample ``seq_len`` new tokens, one at a time, after ``start_tokens``.

    Each step feeds the most recent ``max_seq_len`` tokens to the
    module-level ``session`` (an object exposing ``run``; presumably an
    onnxruntime ``InferenceSession`` -- confirm), takes the output at the
    last position, divides it by ``temperature``, applies softmax and draws
    one token from the resulting distribution.

    Args:
        start_tokens: Sequence of integer token ids used as the prompt.
        seq_len: Number of tokens to sample.
        max_seq_len: Largest number of trailing tokens passed to the model.
        temperature: Divisor applied to the logits before softmax.
        verbose: When true, print a header and a progress line every
            32 steps.
        return_prime: When true, the prompt is included in the result.

    Returns:
        A ``torch.LongTensor`` with one row: prompt plus sampled tokens when
        ``return_prime`` is true, otherwise only the sampled tokens.
    """
    tokens = torch.LongTensor([start_tokens])
    prime_len = len(start_tokens)

    if verbose:
        print("Generating sequence of max length:", seq_len)

    for step in range(seq_len):
        # Only the trailing window is fed to the model.
        window = tokens[:, -max_seq_len:]
        model_out = session.run(None, {'input': [window.tolist()[0]]})[0]

        # Keep the prediction for the last position only.
        last_logits = torch.FloatTensor(model_out)[:, -1]
        probs = F.softmax(last_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1)
        tokens = torch.cat((tokens, next_token), dim=-1)

        if verbose and step % 32 == 0:
            print(step, '/', seq_len)

    return tokens[:, :] if return_prime else tokens[:, prime_len:]
#=================================================================================================
def load_javascript(dir="javascript"):
    """Inject every ``*.js`` file found in ``dir`` into pages served by Gradio.

    The scripts are read once, each wrapped in a ``<script>`` tag preceded by
    an HTML comment naming its path. ``gr.routes.templates.TemplateResponse``
    is then replaced by a wrapper that inserts those tags just before
    ``</head>`` in every response body it produces.

    Args:
        dir: Directory that is globbed for ``*.js`` files.
    """
    script_tags = []
    for path in glob.glob(f"{dir}/*.js"):
        with open(path, "r", encoding="utf8") as jsfile:
            script_tags.append(
                f"\n<!-- {path} --><script>{jsfile.read()}</script>")
    javascript = "".join(script_tags)

    # Bytes that replace the closing head tag in each rendered template.
    injected_head = f'{javascript}</head>'.encode("utf8")
    template_response_ori = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        res = template_response_ori(*args, **kwargs)
        res.body = res.body.replace(b'</head>', injected_head)
        # The body changed, so the response headers are rebuilt.
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response
class JSMsgReceiver(gr.HTML):
    """Hidden HTML component whose value is a JSON payload inside ``<p>``."""

    def __init__(self, **kwargs):
        """Create the component hidden, with element id ``msg_receiver``."""
        super().__init__(elem_id="msg_receiver", visible=False, **kwargs)

    def postprocess(self, y):
        """Wrap a truthy ``y`` as JSON in a ``<p>`` tag; pass falsy values on."""
        wrapped = f"<p>{json.dumps(y)}</p>" if y else y
        return super().postprocess(wrapped)

    def get_block_name(self) -> str:
        """Report the block name ``"html"``."""
        return "html"
#=================================================================================================
if __name__ == "__main__":
    # Command-line options for the Gradio server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    opt = parser.parse_args()

    providers = ['CPUExecutionProvider']
    # NOTE: the inference session is currently disabled; generate() reads a
    # global named `session`, so it cannot run until this is restored.
    # session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=providers)

    # Static page text, kept separate from the layout code below.
    TITLE_HTML = "<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>"
    DESCRIPTION_MD = (
        "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
        "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
        "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
        "[Open In Colab]"
        "(https://colab.research.google.com/github/asigalov61/Allegro-Music-Transformer/blob/main/Allegro_Music_Transformer_Composer.ipynb)"
        " for faster execution and endless generation"
    )

    app = gr.Blocks()
    with app:
        gr.Markdown(TITLE_HTML)
        gr.Markdown(DESCRIPTION_MD)
        js_msg = JSMsgReceiver()
        tab_select = gr.Variable(value=0)

    # Enable the request queue (argument 2), then start the server.
    app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)