Xixeo Kiryanovkd commited on
Commit
91029e3
0 Parent(s):

Duplicate from Mubert/Text-to-Music

Browse files

Co-authored-by: Kirill Kiryanov <Kiryanovkd@users.noreply.huggingface.co>

Files changed (6) hide show
  1. .gitattributes +33 -0
  2. README.md +14 -0
  3. app.py +95 -0
  4. constants.py +7 -0
  5. requirements.txt +2 -0
  6. utils.py +50 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pb filter=lfs diff=lfs merge=lfs -text
19
+ *.pickle filter=lfs diff=lfs merge=lfs -text
20
+ *.pkl filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tflite filter=lfs diff=lfs merge=lfs -text
28
+ *.tgz filter=lfs diff=lfs merge=lfs -text
29
+ *.wasm filter=lfs diff=lfs merge=lfs -text
30
+ *.xz filter=lfs diff=lfs merge=lfs -text
31
+ *.zip filter=lfs diff=lfs merge=lfs -text
32
+ *.zst filter=lfs diff=lfs merge=lfs -text
33
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Text To Music
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.6
8
+ app_file: app.py
9
+ pinned: false
10
+ license: unknown
11
+ duplicated_from: Mubert/Text-to-Music
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ from sentence_transformers import SentenceTransformer
5
+
6
+ import httpx
7
+ import json
8
+
9
+ from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
10
+
11
+ minilm = SentenceTransformer('all-MiniLM-L6-v2')
12
+ mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
13
+
14
+
15
+ def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
16
+ if loop:
17
+ mode = "loop"
18
+ else:
19
+ mode = "track"
20
+ r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
21
+ json={
22
+ "method": "RecordTrackTTM",
23
+ "params": {
24
+ "pat": pat,
25
+ "duration": duration,
26
+ "tags": tags,
27
+ "mode": mode
28
+ }
29
+ })
30
+
31
+ rdata = json.loads(r.text)
32
+ assert rdata['status'] == 1, rdata['error']['text']
33
+ trackurl = rdata['data']['tasks'][0]['download_link']
34
+
35
+ print('Generating track ', end='')
36
+ for i in range(maxit):
37
+ r = httpx.get(trackurl)
38
+ if r.status_code == 200:
39
+ return trackurl
40
+ time.sleep(1)
41
+
42
+
43
+ def generate_track_by_prompt(email, prompt, duration, loop=False):
44
+ try:
45
+ pat = get_pat(email)
46
+ _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
47
+ return get_track_by_tags(tags, pat, int(duration), loop=loop), "Success", ",".join(tags)
48
+ except Exception as e:
49
+ return None, str(e), ""
50
+
51
+
52
+ block = gr.Blocks()
53
+
54
+ with block:
55
+ gr.HTML(
56
+ """
57
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
58
+ <div
59
+ style="
60
+ display: inline-flex;
61
+ align-items: center;
62
+ gap: 0.8rem;
63
+ font-size: 1.75rem;
64
+ "
65
+ >
66
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
67
+ Mubert
68
+ </h1>
69
+ </div>
70
+ <p style="margin-bottom: 10px; font-size: 94%">
71
+ All music is generated by Mubert API – <a href="https://mubert.com" style="text-decoration: underline;" target="_blank">www.mubert.com</a>
72
+ </p>
73
+ </div>
74
+ """
75
+ )
76
+ with gr.Group():
77
+ with gr.Box():
78
+ email = gr.Textbox(label="email")
79
+ prompt = gr.Textbox(label="prompt")
80
+ duration = gr.Slider(label="duration (seconds)", value=30)
81
+ is_loop = gr.Checkbox(label="Generate loop")
82
+ out = gr.Audio()
83
+ result_msg = gr.Text(label="Result message")
84
+ tags = gr.Text(label="Tags")
85
+ btn = gr.Button("Submit").style(full_width=True)
86
+
87
+ btn.click(fn=generate_track_by_prompt, inputs=[email, prompt, duration, is_loop], outputs=[out, result_msg, tags])
88
+ gr.HTML('''
89
+ <div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
90
+ <p>Demo by <a href="https://huggingface.co/Mubert" style="text-decoration: underline;" target="_blank">Mubert</a>
91
+ </p>
92
+ </div>
93
+ ''')
94
+
95
+ block.launch()
constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ MUBERT_TAGS_STRING = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic'
4
+ MUBERT_TAGS = np.array(MUBERT_TAGS_STRING.split(','))
5
+ MUBERT_LICENSE = "ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5"
6
+ MUBERT_MODE = "loop"
7
+ MUBERT_TOKEN = "4951f6428e83172a4f39de05d5b3ab10d58560b8"
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ httpx
2
+ sentence-transformers
utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import httpx
4
+
5
+ from constants import MUBERT_TAGS, MUBERT_LICENSE, MUBERT_MODE, MUBERT_TOKEN
6
+
7
+
8
+ def get_mubert_tags_embeddings(w2v_model):
9
+ return w2v_model.encode(MUBERT_TAGS)
10
+
11
+
12
+ def get_pat(email: str):
13
+ r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess',
14
+ json={
15
+ "method": "GetServiceAccess",
16
+ "params": {
17
+ "email": email,
18
+ "license": MUBERT_LICENSE,
19
+ "token": MUBERT_TOKEN,
20
+ "mode": MUBERT_MODE,
21
+ }
22
+ })
23
+
24
+ rdata = json.loads(r.text)
25
+ assert rdata['status'] == 1, "probably incorrect e-mail"
26
+ pat = rdata['data']['pat']
27
+ return pat
28
+
29
+
30
+ def find_similar(em, embeddings, method='cosine'):
31
+ scores = []
32
+ for ref in embeddings:
33
+ if method == 'cosine':
34
+ scores.append(1 - np.dot(ref, em) / (np.linalg.norm(ref) * np.linalg.norm(em)))
35
+ if method == 'norm':
36
+ scores.append(np.linalg.norm(ref - em))
37
+ return np.array(scores), np.argsort(scores)
38
+
39
+
40
+ def get_tags_for_prompts(w2v_model, mubert_tags_embeddings, prompts, top_n=3, debug=False):
41
+ prompts_embeddings = w2v_model.encode(prompts)
42
+ ret = []
43
+ for i, pe in enumerate(prompts_embeddings):
44
+ scores, idxs = find_similar(pe, mubert_tags_embeddings)
45
+ top_tags = MUBERT_TAGS[idxs[:top_n]]
46
+ top_prob = 1 - scores[idxs[:top_n]]
47
+ if debug:
48
+ print(f"Prompt: {prompts[i]}\nTags: {', '.join(top_tags)}\nScores: {top_prob}\n\n\n")
49
+ ret.append((prompts[i], list(top_tags)))
50
+ return ret