sunsetsobserver
committed on
Commit
·
ae3131d
1
Parent(s):
2171a21
Add generate only continuation
Browse files- gen_res/0.json +0 -1
- gen_res/0.mid +0 -0
- generate_on_one_track.py +105 -0
gen_res/0.json
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"ids": [[155, 32, 112, 122, 158, 37, 112, 121, 161, 41, 597, 163, 25, 469, 166, 36, 113, 122, 184, 25, 113, 119, 141, 25, 441, 145, 25, 113, 120, 147, 25, 111, 126, 29, 111, 130, 149, 27, 597, 154, 30, 221, 157, 25, 111, 124, 161, 25, 112, 121, 166, 29, 822, 308, 32, 112, 131, 149, 30, 112, 125, 159, 32, 180, 160, 29, 441, 166, 29, 322, 190, 32, 111, 123, 143, 25, 110, 130, 152, 37, 221, 156, 25, 113, 122, 158, 25, 113, 119, 161, 13, 322, 166, 30, 112, 121, 168, 25, 112, 122, 184, 37, 111, 130, 37, 112, 129, 25, 822, 143, 25, 192, 144, 25, 180, 148, 25, 378, 150, 25, 378, 153, 25, 192, 156, 25, 112, 123, 25, 279, 161, 25, 378, 165, 25, 112, 125, 184, 25, 112, 122, 25, 112, 124, 25, 113, 119, 190, 25, 226, 139, 25, 180, 140, 25, 180, 144, 25, 226, 25, 192, 151, 25, 192, 25, 220, 152, 25, 220, 13, 180, 157, 25, 220, 32, 378, 13, 178, 158, 25, 220, 163, 25, 112, 122, 25, 220, 25, 196, 164, 25, 185, 167, 25, 220, 25, 114, 117, 25, 180, 184, 25, 180, 37, 378, 141, 25, 378, 143, 25, 220, 25, 180, 149, 25, 192, 25, 220, 149, 25, 220, 49, 220, 151, 25, 220, 37, 180, 154, 25, 192, 25, 192, 155, 25, 192, 158, 25, 113, 119, 25, 378, 13, 175, 161, 25, 192, 25, 178, 13, 171, 162, 25, 185, 167, 25, 378, 25, 220, 25, 220, 206, 25, 220, 25, 178, 141, 25, 178, 25, 178, 144, 25, 180, 25, 178, 146, 25, 175, 25, 191, 148, 25, 177, 25, 175, 149, 25, 175, 25, 178, 151, 25, 178, 152, 25, 191, 25, 192, 154, 25, 180, 25, 175, 25, 175, 155, 25, 178, 157, 25, 192, 25, 180, 25, 180, 25, 180, 160, 25, 180, 25, 192, 25, 180, 161, 25, 178, 25, 180, 163, 25, 378, 25, 378, 25, 175, 25, 185, 168, 25, 220, 25, 220, 25, 180, 32, 191, 144, 25, 220, 25, 192, 25, 181, 25, 181, 146, 13, 173, 24, 175, 149, 25, 185, 27, 181, 19, 175, 151, 24, 180, 25, 175, 32, 178, 17, 174, 155, 24, 180, 20, 178, 25, 178, 27, 178, 27, 180, 154, 31, 192, 27, 180, 27, 180, 20, 178, 24, 175, 144, 24, 192, 20, 180, 27, 180, 27, 178, 145, 31, 180, 19, 178, 20, 178, 31, 178, 148, 31, 378, 19, 175, 24, 174, 32, 192, 151, 31, 
180, 19, 180, 31, 192, 31, 180, 154, 27, 178, 32, 378, 23, 178, 25, 180, 155, 27, 192, 26, 180, 19, 175, 19, 220, 156, 19, 175, 19, 178, 157, 31, 378, 43, 192, 31, 192, 19, 180, 159, 25, 192, 31, 192, 19, 178, 31, 180, 160, 19, 178, 43, 192, 27, 178, 161, 26, 180, 26, 180, 162, 19, 180, 31, 180, 19, 192, 31, 180, 163, 31, 192, 19, 180, 19, 175, 31, 178, 19, 175, 31, 220, 165], [2, 184, 169, 141, 20, 113, 118, 22, 113, 118, 29, 113, 118, 38, 113, 118, 143, 17, 441, 24, 919, 33, 919, 40, 919, 147, 21, 113, 118, 28, 114, 118, 30, 113, 118, 39, 113, 118, 149, 16, 112, 128, 18, 111, 128, 25, 111, 128, 34, 111, 128, 165, 36, 219, 184, 13, 187, 139, 20, 179, 141, 27, 188, 143, 34, 196, 153, 20, 109, 132, 28, 211, 33, 211, 156, 30, 205, 35, 205, 158, 32, 205, 37, 211, 161, 27, 222, 34, 227, 165, 25, 222, 29, 222, 36, 235, 184, 26, 111, 128, 31, 111, 128, 153, 33, 407], [2, 184, 169, 141, 20, 113, 118, 22, 113, 118, 29, 113, 118, 38, 113, 118, 143, 17, 441, 24, 919, 33, 919, 40, 919, 147, 21, 113, 118, 28, 114, 118, 30, 113, 118, 39, 113, 118, 149, 16, 112, 128, 18, 111, 128, 25, 111, 128, 34, 111, 128, 165, 36, 219, 184, 13, 187, 139, 20, 179, 141, 27, 188, 143, 34, 196, 153, 20, 109, 132, 28, 211, 33, 211, 156, 30, 205, 35, 205, 158, 32, 205, 37, 211, 161, 27, 222, 34, 227, 165, 25, 222, 29, 222, 36, 235, 184, 26, 111, 128, 31, 111, 128, 153, 33, 407, 155, 32, 112, 122, 158, 37, 112, 121, 161, 41, 597, 163, 25, 469, 166, 36, 113, 122, 184, 25, 113, 119, 141, 25, 441, 145, 25, 113, 120, 147, 25, 111, 126, 29, 111, 130, 149, 27, 597, 154, 30, 221, 157, 25, 111, 124, 161, 25, 112, 121, 166, 29, 822, 308, 32, 112, 131, 149, 30, 112, 125, 159, 32, 180, 160, 29, 441, 166, 29, 322, 190, 32, 111, 123, 143, 25, 110, 130, 152, 37, 221, 156, 25, 113, 122, 158, 25, 113, 119, 161, 13, 322, 166, 30, 112, 121, 168, 25, 112, 122, 184, 37, 111, 130, 37, 112, 129, 25, 822, 143, 25, 192, 144, 25, 180, 148, 25, 378, 150, 25, 378, 153, 25, 192, 156, 25, 112, 123, 25, 279, 161, 25, 378, 165, 
25, 112, 125, 184, 25, 112, 122, 25, 112, 124, 25, 113, 119, 190, 25, 226, 139, 25, 180, 140, 25, 180, 144, 25, 226, 25, 192, 151, 25, 192, 25, 220, 152, 25, 220, 13, 180, 157, 25, 220, 32, 378, 13, 178, 158, 25, 220, 163, 25, 112, 122, 25, 220, 25, 196, 164, 25, 185, 167, 25, 220, 25, 114, 117, 25, 180, 184, 25, 180, 37, 378, 141, 25, 378, 143, 25, 220, 25, 180, 149, 25, 192, 25, 220, 149, 25, 220, 49, 220, 151, 25, 220, 37, 180, 154, 25, 192, 25, 192, 155, 25, 192, 158, 25, 113, 119, 25, 378, 13, 175, 161, 25, 192, 25, 178, 13, 171, 162, 25, 185, 167, 25, 378, 25, 220, 25, 220, 206, 25, 220, 25, 178, 141, 25, 178, 25, 178, 144, 25, 180, 25, 178, 146, 25, 175, 25, 191, 148, 25, 177, 25, 175, 149, 25, 175, 25, 178, 151, 25, 178, 152, 25, 191, 25, 192, 154, 25, 180, 25, 175, 25, 175, 155, 25, 178, 157, 25, 192, 25, 180, 25, 180, 25, 180, 160, 25, 180, 25, 192, 25, 180, 161, 25, 178, 25, 180, 163, 25, 378, 25, 378, 25, 175, 25, 185, 168, 25, 220, 25, 220, 25, 180, 32, 191, 144, 25, 220, 25, 192, 25, 181, 25, 181, 146, 13, 173, 24, 175, 149, 25, 185, 27, 181, 19, 175, 151, 24, 180, 25, 175, 32, 178, 17, 174, 155, 24, 180, 20, 178, 25, 178, 27, 178, 27, 180, 154, 31, 192, 27, 180, 27, 180, 20, 178, 24, 175, 144, 24, 192, 20, 180, 27, 180, 27, 178, 145, 31, 180, 19, 178, 20, 178, 31, 178, 148, 31, 378, 19, 175, 24, 174, 32, 192, 151, 31, 180, 19, 180, 31, 192, 31, 180, 154, 27, 178, 32, 378, 23, 178, 25, 180, 155, 27, 192, 26, 180, 19, 175, 19, 220, 156, 19, 175, 19, 178, 157, 31, 378, 43, 192, 31, 192, 19, 180, 159, 25, 192, 31, 192, 19, 178, 31, 180, 160, 19, 178, 43, 192, 27, 178, 161, 26, 180, 26, 180, 162, 19, 180, 31, 180, 19, 192, 31, 180, 163, 31, 192, 19, 180, 19, 175, 31, 178, 19, 175, 31, 220, 165]]}
|
|
|
|
gen_res/0.mid
DELETED
Binary file (3.32 kB)
|
|
generate_on_one_track.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from pathlib import Path
|
3 |
+
from random import shuffle
|
4 |
+
|
5 |
+
from torch import Tensor, argmax
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from torch.cuda import is_available as cuda_available, is_bf16_supported
|
8 |
+
from torch.backends.mps import is_available as mps_available
|
9 |
+
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, MistralForCausalLM
|
10 |
+
from transformers.trainer_utils import set_seed
|
11 |
+
from evaluate import load as load_metric
|
12 |
+
from miditok import REMI, TokenizerConfig
|
13 |
+
from miditok.pytorch_data import DatasetTok, DataCollator
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
# Tokenizer configuration (REMI).
# Each setting is kept as a module-level constant so it can be referenced
# individually; TOKENIZER_PARAMS bundles them for TokenizerConfig.
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)

TOKENIZER_PARAMS = dict(
    pitch_range=PITCH_RANGE,
    beat_res=BEAT_RES,
    num_velocities=NUM_VELOCITIES,
    special_tokens=SPECIAL_TOKENS,
    use_chords=USE_CHORDS,
    use_rests=USE_RESTS,
    use_tempos=USE_TEMPOS,
    use_time_signatures=USE_TIME_SIGNATURE,
    use_programs=USE_PROGRAMS,
    num_tempos=NUM_TEMPOS,
    tempo_range=TEMPO_RANGE,
)
config = TokenizerConfig(**TOKENIZER_PARAMS)
|
42 |
+
|
43 |
+
# Seed for reproducible sampling
set_seed(777)

# Load the trained REMI tokenizer from the Hugging Face Hub
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")

# Collect every MIDI prompt under input/ (both .mid and .midi extensions)
midi_paths = list(Path('input').glob('**/*.mid')) + list(Path('input').glob('**/*.midi'))

""" list(Path('Maestro').glob('**/*.mid')) + list(Path('Maestro').glob('**/*.midi')) """

# Tokenize the prompt MIDIs into sequences of 10..1024 tokens
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths, **kwargs_dataset)

# Load the fine-tuned model weights from the local ./runs checkpoint
model = MistralForCausalLM.from_pretrained("./runs")

# Collator that pads batches and mirrors inputs as labels.
# NOTE(review): the original script built a second DataCollator (without
# copy_inputs_as_labels) just above this line and immediately shadowed it;
# that dead instance has been removed — only this one was ever used.
collator = DataCollator(tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"], copy_inputs_as_labels=True)

# Output directory for the generated MIDI/json results
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
|
66 |
+
# Sampling settings for autoregressive continuation
_generation_kwargs = {
    "max_new_tokens": 512,   # extends samples by 512 tokens
    "num_beams": 1,          # no beam search
    "do_sample": True,       # but sample instead
    "temperature": 0.9,
    "top_k": 15,
    "top_p": 0.95,
    "epsilon_cutoff": 3e-4,
    "eta_cutoff": 1e-3,
}
generation_config = GenerationConfig(**_generation_kwargs)
|
76 |
+
|
77 |
+
# Pad on the left so that, along the time dimension, the final position of
# every sequence holds its true last token — this lets generation run
# efficiently by batch.
collator.pad_on_left = True
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
model.eval()
count = 0
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json).
    # NOTE(review): the original zipped batch["input_ids"] with res but never
    # used the prompt half of the pair; both have the same batch dimension,
    # so iterating the generated sequences directly is equivalent.
    for continuation in res:
        # Decode the entire sequence (prompt + continuation) back to MIDI
        midi = tokenizer.tokens_to_midi([deepcopy(continuation.tolist())])

        # Track name records that this file holds prompt and continuation together
        midi.tracks[0].name = f'Original sample and continuation ({len(continuation)} tokens)'

        # Write the MIDI file for the combined prompt and continuation
        midi.dump_midi(gen_results_path / f'{count}.mid')

        # Also save the token ids of the combined sequence
        tokens = [continuation.tolist()]
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        count += 1
|