import os
import re

import requests
import torch
from tqdm import tqdm
from unidecode import unidecode
from transformers import AutoModel, AutoConfig, BertModel, PreTrainedModel

PATCH_LENGTH = 64
PATCH_FEATURES = 98


class MusicPatchilizer:
    """
    Class for converting music data to patches and vice versa.

    Attributes:
        delimiters (tuple): A tuple of strings containing the delimiters used for splitting bars.
        regexPattern (str): A regular expression pattern for splitting bars.
        pad_id (int): The id of the padding token.
        mask_id (int): The id of the mask token.
        eos_id (int): The id of the end-of-sequence token.

    Methods:
        split_bars(body): Splits a body of music into individual bars using the delimiters specified in `self.delimiters`.
        bar2patch(bar, patch_length): Encodes a single bar as a patch of the specified length.
        patch2bar(patch): Converts a patch back to a bar string.
        encode(music, music_length, patch_length=PATCH_LENGTH, add_eos_patch=False): Encodes the input music string as a list of patches.
        decode(patches): Decodes a sequence of patches into a music score.
    """
    def __init__(self):
        self.delimiters = ("|:", "::", ":|", "[|", "||", "|]", "|")
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        # Ids 1-95 encode the printable ASCII characters 32-126 (shifted down by 31);
        # 0, 96, and 97 are reserved for pad, mask, and eos, giving PATCH_FEATURES = 98.
        self.pad_id = 0
        self.mask_id = 96
        self.eos_id = 97

    def split_bars(self, body):
        """
        Splits a body of music into individual bars using the delimiters specified in `self.delimiters`.

        Args:
            body (str): A string containing the body of music to be split into bars.

        Returns:
            list: A list of strings containing the individual bars.
        """
        body = "".join(body)
        # re.split with a capturing group alternates bar content and delimiters;
        # drop the empty strings it produces around leading/trailing delimiters.
        bars = re.split(self.regexPattern, body)
        bars = [bar for bar in bars if bar != ""]
        if bars[0] in self.delimiters:
            # A leading delimiter belongs to the first bar rather than standing alone.
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        # Rejoin each (content, delimiter) pair into a single bar string.
        bars = [bars[i*2] + bars[i*2+1] for i in range(len(bars)//2)]
        return bars
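
    # A minimal sketch of the splitting behaviour on a hypothetical ABC body:
    #
    #     >>> MusicPatchilizer().split_bars("|:CDEF|GABc:|")
    #     ['|:CDEF|', 'GABc:|']
    #
    # The leading "|:" is merged into the first bar, and each bar keeps its
    # closing delimiter.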

    def bar2patch(self, bar, patch_length):
        """
        Encodes a single bar as a patch of the specified length.

        Args:
            bar (str): A string containing the bar to be encoded.
            patch_length (int): An integer indicating the length of the patch to be returned.

        Returns:
            list: A list of integer-encoded musical tokens.
        """
        patch = [self.pad_id] * patch_length
        length = min(patch_length, len(bar))

        for i in range(length):
            idx = ord(bar[i])
            # Only printable ASCII characters (32-126) are encoded; the shift by 31
            # keeps 0 free for the pad token.
            if 32 <= idx < 127:
                patch[i] = idx - 31

        # Terminate the patch with an eos token if there is room left. Computing the
        # position from `length` also keeps empty bars from raising a NameError.
        if length < patch_length:
            patch[length] = self.eos_id

        return patch
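
    # For example (hypothetical bar, patch_length=6): "|C2" encodes to
    # [93, 36, 19, 97, 0, 0] -- ord('|')-31=93, ord('C')-31=36, ord('2')-31=19,
    # followed by the eos token (97) and pad tokens (0).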

    def patch2bar(self, patch):
        """
        Converts a patch back to a bar string.

        Args:
            patch (list): A list of integer-encoded musical tokens.

        Returns:
            str: A string containing the decoded bar.
        """
        bar = ""
        for idx in patch:
            # Ids 1-95 map back to printable ASCII; pad (0), mask (96), and eos (97)
            # all terminate the bar.
            if 0 < idx < 96:
                bar += chr(idx + 31)
            else:
                break
        return bar
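
    # patch2bar inverts bar2patch for printable input, e.g.
    # patch2bar([93, 36, 19, 97, 0, 0]) returns "|C2": decoding stops at the
    # first non-character token, so trailing eos/pad ids are dropped.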

    def encode(self, music, music_length, patch_length=PATCH_LENGTH, add_eos_patch=False):
        """
        Encodes the input music string as a list of patches.

        Args:
            music (str): A string containing the music to be encoded.
            music_length (int): An integer indicating the maximum number of patches to be returned.
            patch_length (int): An integer indicating the length of each patch.
            add_eos_patch (bool): Whether to append an extra patch consisting entirely of eos tokens to the encoded music.

        Returns:
            list: A list of integer-encoded patches.
        """
        music = unidecode(music)
        lines = music.split('\n')
        try:
            lines.remove('')
        except ValueError:
            pass

        body = ""
        patches = []

        for line in lines:
            # Header lines such as "X:1" or "T:Title" and "%%score" directives are
            # encoded as standalone patches; all other lines accumulate into the body.
            if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')):
                if body != "":
                    for bar in self.split_bars(body):
                        patches.append(self.bar2patch(bar, patch_length))
                    body = ""
                patches.append(self.bar2patch(line, patch_length))
            else:
                body += line

        # Flush any remaining body after the last header line.
        if body != "":
            for bar in self.split_bars(body):
                patches.append(self.bar2patch(bar, patch_length))

        if add_eos_patch:
            patches.append([self.eos_id] * patch_length)

        return patches[:music_length]
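
    # A minimal sketch on a hypothetical ABC tune:
    #
    #     >>> p = MusicPatchilizer()
    #     >>> len(p.encode("X:1\nL:1/8\n|:CDEF|GABc:|", music_length=64))
    #     4
    #
    # One patch each for the "X:1" and "L:1/8" headers, plus one per bar
    # ("|:CDEF|" and "GABc:|") of the tune body.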

    def decode(self, patches):
        """
        Decodes a sequence of patches into a music score.

        Args:
            patches (list): A list of integer-encoded patches.

        Returns:
            str: A string containing the decoded music score.
        """
        music = ""
        for patch in patches:
            music += self.patch2bar(patch) + '\n'
        return music
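
# Note that decode is a lossy inverse of encode: every patch (header line or
# bar) comes back on its own line, so the original line layout of the bars is
# not preserved.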


class MusicEncoder(PreTrainedModel):
    """
    MusicEncoder model for encoding music patches into a sequence of hidden states.

    Args:
        config (:obj:`BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.

    Attributes:
        patch_embedding (:obj:`torch.nn.Linear`): A linear layer that maps the one-hot encoded patches to the hidden size of the model.
        enc (:obj:`BertModel`): The BERT model used to encode the patches.
    """
    def __init__(self, config):
        super(MusicEncoder, self).__init__(config)
        self.patch_embedding = torch.nn.Linear(PATCH_LENGTH*PATCH_FEATURES, config.hidden_size)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.enc = BertModel(config=config)

    def forward(self, input_musics, music_masks):
        """
        Args:
            input_musics (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length, patch_length)`):
                Tensor containing the integer-encoded music patches.
            music_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length)`):
                Tensor containing the attention masks for the music patches.

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, music_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
        """
        # One-hot encode each token id:
        # (batch, music_length, patch_length) -> (batch, music_length, patch_length, PATCH_FEATURES).
        input_musics = torch.nn.functional.one_hot(input_musics, num_classes=PATCH_FEATURES)

        # Flatten each patch into a single float vector:
        # -> (batch, music_length, PATCH_LENGTH*PATCH_FEATURES).
        input_musics = input_musics.reshape(len(input_musics), -1, PATCH_LENGTH*PATCH_FEATURES).float()

        # Project each flattened patch to the model's hidden size, then run the
        # patch embeddings through BERT via `inputs_embeds`, attending over patches.
        input_musics = self.patch_embedding(input_musics.to(self.device))
        output = self.enc(inputs_embeds=input_musics, attention_mask=music_masks.to(self.device))

        return output
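
# Shape walk-through for a hypothetical batch with hidden_size=768: input_musics
# of shape (2, 8, 64) is one-hot encoded to (2, 8, 64, 98), flattened to
# (2, 8, 6272), projected to (2, 8, 768), and encoded by BERT into a
# last_hidden_state of the same (2, 8, 768) shape.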


class CLaMP(PreTrainedModel):
    """
    CLaMP model for joint text and music encoding.

    Args:
        config (:obj:`BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
        text_model_name (:obj:`str`, `optional`, defaults to :obj:`"distilroberta-base"`):
            The name of the pre-trained text model to be used for text encoding.

    Attributes:
        text_enc (:obj:`AutoModel`): The pre-trained text model used for text encoding.
        text_proj (:obj:`torch.nn.Linear`): A linear layer to project the text encoding to the hidden size of the model.
        music_enc (:obj:`MusicEncoder`): The music encoder model used for music encoding.
        music_proj (:obj:`torch.nn.Linear`): A linear layer to project the music encoding to the hidden size of the model.
    """
    def __init__(self, config, text_model_name="distilroberta-base"):
        super(CLaMP, self).__init__(config)
        self.text_enc = AutoModel.from_pretrained(text_model_name)
        self.text_proj = torch.nn.Linear(config.hidden_size, config.hidden_size)
        torch.nn.init.normal_(self.text_proj.weight, std=0.02)

        self.music_enc = MusicEncoder(config=config)
        self.music_proj = torch.nn.Linear(config.hidden_size, config.hidden_size)
        torch.nn.init.normal_(self.music_proj.weight, std=0.02)

    def forward(self, input_texts, text_masks, input_musics, music_masks):
        """
        Args:
            input_texts (:obj:`torch.LongTensor` of shape :obj:`(batch_size, text_length)`):
                Tensor containing the integer-encoded text.
            text_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, text_length)`):
                Tensor containing the attention masks for the text.
            input_musics (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length, patch_length)`):
                Tensor containing the integer-encoded music patches.
            music_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length)`):
                Tensor containing the attention masks for the music patches.

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            music_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
                The music features extracted from the music encoder.
            text_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
                The text features extracted from the text encoder.
        """
        # Encode the text, average-pool over unmasked tokens, and project.
        text_features = self.text_enc(input_texts.to(self.device), attention_mask=text_masks.to(self.device))['last_hidden_state']
        text_features = self.avg_pooling(text_features, text_masks)
        text_features = self.text_proj(text_features)

        # Encode the music the same way: pool over unmasked patches, then project.
        music_features = self.music_enc(input_musics, music_masks)['last_hidden_state']
        music_features = self.avg_pooling(music_features, music_masks)
        music_features = self.music_proj(music_features)

        return music_features, text_features
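
    # The returned features are unnormalized. A typical (hypothetical) CLIP-style
    # retrieval score would L2-normalize both vectors and take their dot product:
    #
    #     music_features = music_features / music_features.norm(dim=-1, keepdim=True)
    #     text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    #     similarity = text_features @ music_features.T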

    def avg_pooling(self, input_features, input_masks):
        """
        Applies average pooling to the input features.

        Args:
            input_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_length, hidden_size)`):
                Tensor containing the input features.
            input_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, seq_length)`):
                Tensor containing the attention masks for the input features.

        Returns:
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`:
                The pooled features.
        """
        # Zero out masked positions, then divide the per-sequence sum by the
        # number of unmasked positions (a masked mean).
        input_masks = input_masks.unsqueeze(-1).to(self.device)
        input_features = input_features * input_masks
        avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1)

        return avg_pool
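
    # Worked example (hypothetical numbers): features [[1., 3.], [5., 7.]] for a
    # single item with mask [1, 0] pool to [1., 3.] -- the masked second position
    # contributes nothing to the sum, and the divisor is the mask total (1 here).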

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Instantiate a CLaMP model from a pre-trained model configuration.

        Args:
            pretrained_model_name_or_path (:obj:`str`):
                This can be either:
                - "clamp-small-512" for the small CLaMP model with a 512 max sequence length.
                - "clamp-small-1024" for the small CLaMP model with a 1024 max sequence length.

        Returns:
            :class:`CLaMP`: The CLaMP model.
        """
        model_dir = pretrained_model_name_or_path

        # Download the config and weights from the Hugging Face Hub if they are
        # not already cached in a local directory of the same name.
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
            config_url = f"https://huggingface.co/{pretrained_model_name_or_path}/raw/main/config.json"
            model_url = f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/pytorch_model.bin"
            chunk_size = 1024 * 1024  # stream in 1 MiB chunks

            with requests.get(config_url, stream=True) as r:
                r.raise_for_status()
                total_size = int(r.headers.get('content-length', 0))
                with open(os.path.join(model_dir, "config.json"), 'wb') as f:
                    with tqdm(total=total_size, unit='B', unit_scale=True, desc='Downloading config') as pbar:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            f.write(chunk)
                            pbar.update(len(chunk))

            with requests.get(model_url, stream=True) as r:
                r.raise_for_status()
                total_size = int(r.headers.get('content-length', 0))
                with open(os.path.join(model_dir, "pytorch_model.bin"), 'wb') as f:
                    with tqdm(total=total_size, unit='B', unit_scale=True, desc='Downloading model') as pbar:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            f.write(chunk)
                            pbar.update(len(chunk))

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model = cls(config)
        # map_location keeps loading robust on CPU-only machines even if the
        # checkpoint was saved from a GPU.
        model.load_state_dict(torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location='cpu'))

        return model
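

# A minimal end-to-end usage sketch, assuming the checkpoint name below resolves
# on the Hugging Face Hub as described in the from_pretrained docstring, and
# using only the music branch of the model:
#
#     model = CLaMP.from_pretrained("clamp-small-512")
#     patchilizer = MusicPatchilizer()
#     patches = patchilizer.encode("X:1\nL:1/8\n|:CDEF|GABc:|",
#                                  music_length=512, add_eos_patch=True)
#     input_musics = torch.tensor([patches])            # (1, n_patches, 64)
#     music_masks = torch.ones(input_musics.shape[:2])  # (1, n_patches)
#     hidden = model.music_enc(input_musics, music_masks)['last_hidden_state']
#     music_features = model.music_proj(model.avg_pooling(hidden, music_masks))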