|
import subprocess |
|
import os |
|
import gradio as gr |
|
import json |
|
from utils import * |
|
from unidecode import unidecode |
|
from transformers import AutoTokenizer |
|
|
|
# Markdown/HTML blurb handed to ``gr.Interface(description=...)`` at the bottom
# of this file: badge links, usage steps, caveats and a short explanation of
# the retrieval task.
description = """
<div>
<a style="display:inline-block" href='https://github.com/microsoft/muzic/tree/main/clamp'><img src='https://img.shields.io/github/stars/microsoft/muzic?style=social' /></a>
<a style='display:inline-block' href='https://ai-muzic.github.io/clamp/'><img src='https://img.shields.io/badge/website-CLaMP-ff69b4.svg' /></a>
<a style="display:inline-block" href="https://huggingface.co/datasets/sander-wood/wikimusictext"><img src="https://img.shields.io/badge/huggingface-dataset-ffcc66.svg"></a>
<a style="display:inline-block" href="https://arxiv.org/pdf/2304.11029.pdf"><img src="https://img.shields.io/badge/arXiv-2304.11029-b31b1b.svg"></a>
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/sander-wood/clamp_similar_music_recommendation?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg" alt="Duplicate Space"></a>
</div>

## ℹ️ How to use this demo?
1. Select a music file in MusicXML (.mxl) format.
2. Click "Submit" and wait for the result.
3. It will return the most similar music score from the WikiMusicText dataset (1010 scores in total).

## ❕Notice
- The demo only supports MusicXML (.mxl) files.
- The returned results include the title, artist, genre, description, and the score in ABC notation.
- The genre and description may not be accurate, as they are collected from the web.
- The demo is based on CLaMP-S/512, a CLaMP model with 6-layer Transformer text/music encoders and a sequence length of 512.

## 🎵👉🎵 Similar Music Recommendation
A surprising capability of CLaMP is that it can also recommend similar music given a piece of music, even though it is not trained on this task. This is because CLaMP is trained to encode the semantic meaning of music, and thus it can capture the similarity between music pieces. We only use the music encoder to extract the music feature from the music query, and then calculate the similarity between the query and all the pieces of music in the library.

"""
|
|
|
# --- Retrieval configuration -------------------------------------------------
CLAMP_MODEL_NAME = 'sander-wood/clamp-small-512'  # checkpoint given to CLaMP.from_pretrained
TEXT_MODEL_NAME = 'distilroberta-base'            # checkpoint given to AutoTokenizer.from_pretrained
QUERY_MODAL = 'music'   # modality used to encode the uploaded query
KEY_MODAL = 'music'     # modality prefix of the key-cache file name
TOP_N = 1               # not referenced by the code visible in this file
TEXT_LENGTH = 128       # max_length handed to the tokenizer for text inputs
device = torch.device("cpu")


# --- Model (loaded once at import time, switched to eval mode) ---------------
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME).to(device)
music_length = model.config.max_length  # also part of the key-cache file name
model.eval()


# --- Shared pre/post-processing helpers --------------------------------------
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
softmax = torch.nn.Softmax(dim=1)
|
|
|
def compute_values(Q_e, K_e, t=1):
    """
    Compute softmax-normalised similarity scores between queries and keys

    Both embedding sets are L2-normalised along dim 1, their pairwise dot
    products (cosine similarities) are multiplied by exp(t), and a softmax
    is taken over the key axis.

    Args:
        Q_e (torch.Tensor): Query embeddings, shape (num_queries, hidden_size)
        K_e (torch.Tensor): Key embeddings, shape (num_keys, hidden_size)
        t (float): Log-temperature; similarities are scaled by exp(t)

    Returns:
        values (torch.Tensor): Scores of shape (num_keys,) when there is a
            single query, otherwise (num_queries, num_keys)
    """
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)

    scale = torch.exp(torch.as_tensor(t, dtype=Q_e.dtype))
    logits = torch.mm(Q_e, K_e.T) * scale
    values = torch.softmax(logits, dim=1)

    # Only drop the query axis. A bare squeeze() would also collapse the key
    # axis when the library holds a single entry, producing a 0-dim tensor
    # that can no longer be indexed after argsort.
    return values.squeeze(0)
|
|
|
|
|
def encoding_data(data, modal):
    """
    Turn raw strings into 1-D id tensors for the requested modality

    Args:
        data (list): List of strings (ABC music or plain text)
        modal (str): "music" selects the patchilizer; any other value
            selects the text tokenizer

    Returns:
        ids_list (list): One flat tensor of ids per input string
    """
    if modal == "music":
        # Each piece becomes a sequence of patches, flattened into one vector.
        return [
            torch.tensor(
                patchilizer.encode(piece, music_length=music_length, add_eos_patch=True)
            ).reshape(-1)
            for piece in data
        ]

    # Text path: tokenize with truncation and drop the batch axis.
    return [
        tokenizer(
            sentence,
            return_tensors='pt',
            truncation=True,
            max_length=TEXT_LENGTH,
        )['input_ids'].squeeze(0)
        for sentence in data
    ]
|
|
|
|
|
def abc_filter(lines):
    """
    Filter out the metadata from the abc file

    A line is dropped when it starts with one of the information fields
    listed below, is exactly a newline, or starts with '%' without being a
    '%%score' directive. On every kept line other than '%%score', a
    trailing '%' remark is cut off together with at most one whitespace
    character in front of it.

    Args:
        lines (list): List of lines in the abc file

    Returns:
        music (str): Music string, each kept line terminated by a newline
    """
    # Two-character prefixes of the information fields that are removed.
    # 'G:' was previously misspelled 'G', which never matched a real field
    # and instead dropped a line consisting solely of the note G.
    dropped_fields = ('A:', 'B:', 'C:', 'D:', 'F:', 'G:', 'H:', 'N:', 'O:',
                      'R:', 'r:', 'S:', 'T:', 'W:', 'w:', 'X:', 'Z:')

    kept = []
    for line in lines:
        is_score = line.startswith('%%score')
        if line[:2] in dropped_fields \
                or line == '\n' \
                or (line.startswith('%') and not is_score):
            continue

        if '%' in line and not is_score:
            # Remove the last '%...' remark. Only strip the character before
            # it when that character is whitespace, so music is never eaten.
            line = line.rsplit('%', 1)[0]
            if line[-1:].isspace():
                line = line[:-1]

        kept.append(line + '\n')

    return ''.join(kept)
|
|
|
|
|
def load_music(filename):
    """
    Convert a MusicXML file to ABC notation and strip its metadata

    Runs the ``xml2abc.py`` script that sits next to this file in a child
    Python process, transliterates its output to ASCII with ``unidecode``
    and passes the resulting lines through ``abc_filter``.

    Args:
        filename (str): Path of the MusicXML file to convert

    Returns:
        music (str): Filtered ABC notation of the piece
    """
    import sys  # local import: only needed to locate the running interpreter

    script_dir = os.path.dirname(os.path.abspath(__file__))
    xml2abc_path = os.path.join(script_dir, 'xml2abc.py')

    # Use the interpreter that is running this app rather than whatever
    # 'python' happens to resolve to on PATH. Fall back to the old behaviour
    # if the interpreter path is unavailable.
    interpreter = sys.executable or 'python'
    completed = subprocess.run(
        [interpreter, xml2abc_path, '-m', '2', '-c', '6', '-x', filename],
        stdout=subprocess.PIPE,
    )

    # Normalise line endings before splitting into lines.
    output = completed.stdout.decode('utf-8').replace('\r', '')
    return abc_filter(unidecode(output).split('\n'))
|
|
|
|
|
def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model

    Each id tensor is encoded on its own (batch of one) under
    ``torch.no_grad()``: encoder -> average pooling -> projection head.

    Args:
        ids_list (list): List of 1-D id tensors, as produced by encoding_data
        modal (str): "text" uses the text encoder; any other value uses the
            music encoder

    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of (batch_size, hidden_size)
    """
    features_list = []
    print("Extracting "+modal+" features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            # Inputs and masks are moved to `device` in both branches so the
            # function keeps working when `device` is not the CPU.
            ids = ids.unsqueeze(0).to(device)
            if modal == "text":
                # One mask entry per token.
                masks = torch.ones((1, ids.shape[1]), dtype=torch.long, device=device)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # One mask entry per patch of PATCH_LENGTH ids.
                masks = torch.ones((1, ids.shape[1] // PATCH_LENGTH), dtype=torch.long, device=device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)

            features_list.append(features[0])

    return torch.stack(features_list).to(device)
|
|
|
|
|
def similar_music_recommendation(file):
    """
    Recommend similar music

    Converts the uploaded score to ABC, embeds it with the music encoder,
    picks the highest-scoring entry of the cached key features and looks up
    its metadata in ``wikimusictext.json``.

    Args:
        file: Uploaded file object; only its ``name`` attribute (the path of
            the MusicXML file) is used

    Returns:
        tuple: (title, artist, genre, description, ABC notation) of the
            best-matching piece

    Raises:
        LookupError: If the best-matching cache entry has no metadata record
            with the same title
    """
    query = load_music(file.name)
    print("\nQuery:\n"+ query)

    # NOTE: torch.load unpickles the file, so only load a trusted cache.
    # map_location keeps the cached tensors on the device used for inference.
    with open(KEY_MODAL+"_key_cache_"+str(music_length)+".pth", 'rb') as f:
        key_cache = torch.load(f, map_location=device)

    query_ids = encoding_data([query], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]

    values = compute_values(query_feature, key_features)
    # Index of the best match; reshape(-1) also covers a single-key library.
    idx = int(torch.argmax(values.reshape(-1)))
    # Keep the base name and drop its 4-character extension.
    filename = key_filenames[idx].split('/')[-1][:-4]

    # Explicit encoding so the read does not depend on the platform default.
    with open("wikimusictext.json", 'r', encoding='utf-8') as f:
        wikimusictext = json.load(f)

    for item in wikimusictext:
        if item['title'] == filename:
            print(f"Title: {item['title']}")
            print(f"Artist: {item['artist']}")
            print(f"Genre: {item['genre']}")
            print(f"Description: {item['text']}")
            print(f"ABC notation:\n{item['music']}")
            return item["title"], item["artist"], item["genre"], item["text"], item["music"]

    # Previously the function fell through and returned None, which cannot
    # populate the five output fields; fail with an explicit message instead.
    raise LookupError(f"No metadata entry titled {filename!r} in wikimusictext.json")
|
|
|
# Gradio front-end: one file upload in, five text boxes out.
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces;
# confirm against the pinned gradio version before migrating them.
input_file = gr.inputs.File(label="Upload MusicXML file")

# The text boxes are created in the order the recommendation function
# returns its values.
output_title, output_artist, output_genre, output_description, output_abc = (
    gr.outputs.Textbox(label=caption)
    for caption in ("Title", "Artist", "Genre", "Description", "ABC notation")
)

demo = gr.Interface(
    similar_music_recommendation,
    inputs=input_file,
    outputs=[output_title, output_artist, output_genre, output_description, output_abc],
    title="🗜️ CLaMP: Similar Music Recommendation",
    description=description,
)
demo.launch()
|
|