TristanBehrens commited on
Commit
9774a25
1 Parent(s): 6a02221

Adds a notebook for generating music that you can listen to.

Browse files
Files changed (1) hide show
  1. colab_jsfakes_generation.ipynb +258 -0
colab_jsfakes_generation.ipynb ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "DWLOSBkp0A2U"
7
+ },
8
+ "source": [
9
+ "# GPT-2 for music.\n",
10
+ "\n",
11
+ "This notebook shows you how to generate music with GPT-2\n",
12
+ "\n",
13
+ "---\n",
14
+ "\n",
15
+ "## Install depencencies."
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "metadata": {
22
+ "id": "6J_AnhV8D5p6"
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "!pip install transformers\n",
27
+ "!pip install note_seq"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "metadata": {
33
+ "id": "RzhHhFll0JVl"
34
+ },
35
+ "source": [
36
+ "## Load the tokenizer and the model from 🤗 Hub."
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "metadata": {
43
+ "id": "g3ih12FMD7bs"
44
+ },
45
+ "outputs": [],
46
+ "source": [
47
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
48
+ "\n",
49
+ "tokenizer = AutoTokenizer.from_pretrained(\"TristanBehrens/js-fakes-4bars\")\n",
50
+ "model = AutoModelForCausalLM.from_pretrained(\"TristanBehrens/js-fakes-4bars\")"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {
56
+ "id": "GxRBk--Q0P1q"
57
+ },
58
+ "source": [
59
+ "## How to generate."
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {
66
+ "id": "ZZSLX96ID7t8"
67
+ },
68
+ "outputs": [],
69
+ "source": [
70
+ "# Encode the conditioning tokens.\n",
71
+ "input_ids = tokenizer.encode(\"PIECE_START STYLE=JSFAKES GENRE=JSFAKES TRACK_START INST=48 BAR_START NOTE_ON=60\", return_tensors=\"pt\")\n",
72
+ "print(input_ids)\n",
73
+ "\n",
74
+ "# Generate more tokens.\n",
75
+ "generated_ids = model.generate(input_ids, max_length=500)\n",
76
+ "generated_sequence = tokenizer.decode(generated_ids[0])\n",
77
+ "print(generated_sequence)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "metadata": {
83
+ "id": "YfHXFugA0WdI"
84
+ },
85
+ "source": [
86
+ "## Convert the generated tokens to music that you can listen to."
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "metadata": {
93
+ "id": "L3QMj8NyEBqs"
94
+ },
95
+ "outputs": [],
96
+ "source": [
97
+ "import note_seq\n",
98
+ "\n",
99
+ "NOTE_LENGTH_16TH_120BPM = 0.25 * 60 / 120\n",
100
+ "BAR_LENGTH_120BPM = 4.0 * 60 / 120\n",
101
+ "\n",
102
+ "def token_sequence_to_note_sequence(token_sequence, use_program=True, use_drums=True, instrument_mapper=None, only_piano=False):\n",
103
+ "\n",
104
+ " if isinstance(token_sequence, str):\n",
105
+ " token_sequence = token_sequence.split()\n",
106
+ "\n",
107
+ " note_sequence = empty_note_sequence()\n",
108
+ "\n",
109
+ " # Render all notes.\n",
110
+ " current_program = 1\n",
111
+ " current_is_drum = False\n",
112
+ " current_instrument = 0\n",
113
+ " track_count = 0\n",
114
+ " for token_index, token in enumerate(token_sequence):\n",
115
+ "\n",
116
+ " if token == \"PIECE_START\":\n",
117
+ " pass\n",
118
+ " elif token == \"PIECE_END\":\n",
119
+ " print(\"The end.\")\n",
120
+ " break\n",
121
+ " elif token == \"TRACK_START\":\n",
122
+ " current_bar_index = 0\n",
123
+ " track_count += 1\n",
124
+ " pass\n",
125
+ " elif token == \"TRACK_END\":\n",
126
+ " pass\n",
127
+ " elif token == \"KEYS_START\":\n",
128
+ " pass\n",
129
+ " elif token == \"KEYS_END\":\n",
130
+ " pass\n",
131
+ " elif token.startswith(\"KEY=\"):\n",
132
+ " pass\n",
133
+ " elif token.startswith(\"INST\"):\n",
134
+ " instrument = token.split(\"=\")[-1]\n",
135
+ " if instrument != \"DRUMS\" and use_program:\n",
136
+ " if instrument_mapper is not None:\n",
137
+ " if instrument in instrument_mapper:\n",
138
+ " instrument = instrument_mapper[instrument]\n",
139
+ " current_program = int(instrument)\n",
140
+ " current_instrument = track_count\n",
141
+ " current_is_drum = False\n",
142
+ " if instrument == \"DRUMS\" and use_drums:\n",
143
+ " current_instrument = 0\n",
144
+ " current_program = 0\n",
145
+ " current_is_drum = True\n",
146
+ " elif token == \"BAR_START\":\n",
147
+ " current_time = current_bar_index * BAR_LENGTH_120BPM\n",
148
+ " current_notes = {}\n",
149
+ " elif token == \"BAR_END\":\n",
150
+ " current_bar_index += 1\n",
151
+ " pass\n",
152
+ " elif token.startswith(\"NOTE_ON\"):\n",
153
+ " pitch = int(token.split(\"=\")[-1])\n",
154
+ " note = note_sequence.notes.add()\n",
155
+ " note.start_time = current_time\n",
156
+ " note.end_time = current_time + 4 * NOTE_LENGTH_16TH_120BPM\n",
157
+ " note.pitch = pitch\n",
158
+ " note.instrument = current_instrument\n",
159
+ " note.program = current_program\n",
160
+ " note.velocity = 80\n",
161
+ " note.is_drum = current_is_drum\n",
162
+ " current_notes[pitch] = note\n",
163
+ " elif token.startswith(\"NOTE_OFF\"):\n",
164
+ " pitch = int(token.split(\"=\")[-1])\n",
165
+ " if pitch in current_notes:\n",
166
+ " note = current_notes[pitch]\n",
167
+ " note.end_time = current_time\n",
168
+ " elif token.startswith(\"TIME_DELTA\"):\n",
169
+ " delta = float(token.split(\"=\")[-1]) * NOTE_LENGTH_16TH_120BPM\n",
170
+ " current_time += delta\n",
171
+ " elif token.startswith(\"DENSITY=\"):\n",
172
+ " pass\n",
173
+ " elif token == \"[PAD]\":\n",
174
+ " pass\n",
175
+ " else:\n",
176
+ " #print(f\"Ignored token {token}.\")\n",
177
+ " pass\n",
178
+ "\n",
179
+ " # Make the instruments right.\n",
180
+ " instruments_drums = []\n",
181
+ " for note in note_sequence.notes:\n",
182
+ " pair = [note.program, note.is_drum]\n",
183
+ " if pair not in instruments_drums:\n",
184
+ " instruments_drums += [pair]\n",
185
+ " note.instrument = instruments_drums.index(pair)\n",
186
+ "\n",
187
+ " if only_piano:\n",
188
+ " for note in note_sequence.notes:\n",
189
+ " if not note.is_drum:\n",
190
+ " note.instrument = 0\n",
191
+ " note.program = 0\n",
192
+ "\n",
193
+ " return note_sequence\n",
194
+ "\n",
195
+ "def empty_note_sequence(qpm=120.0, total_time=0.0):\n",
196
+ " note_sequence = note_seq.protobuf.music_pb2.NoteSequence()\n",
197
+ " note_sequence.tempos.add().qpm = qpm\n",
198
+ " note_sequence.ticks_per_quarter = note_seq.constants.STANDARD_PPQ\n",
199
+ " note_sequence.total_time = total_time\n",
200
+ " return note_sequence"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "metadata": {
207
+ "id": "ZYpukydNESDF"
208
+ },
209
+ "outputs": [],
210
+ "source": [
211
+ "input_ids = tokenizer.encode(\"PIECE_START STYLE=JSFAKES GENRE=JSFAKES TRACK_START INST=48 BAR_START NOTE_ON=61\", return_tensors=\"pt\")\n",
212
+ "generated_ids = model.generate(input_ids, max_length=500, temperature=1.0)\n",
213
+ "generated_sequence = tokenizer.decode(generated_ids[0])\n",
214
+ "\n",
215
+ "note_sequence = token_sequence_to_note_sequence(generated_sequence)\n",
216
+ "\n",
217
+ "synth = note_seq.midi_synth.synthesize\n",
218
+ "note_seq.plot_sequence(note_sequence)\n",
219
+ "note_seq.play_sequence(note_sequence, synth)"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "metadata": {
225
+ "id": "d1x6HeF90kkO"
226
+ },
227
+ "source": [
228
+ "# Thank you!"
229
+ ]
230
+ }
231
+ ],
232
+ "metadata": {
233
+ "colab": {
234
+ "collapsed_sections": [],
235
+ "name": "colab_jsfakes_generation.ipynb",
236
+ "provenance": []
237
+ },
238
+ "kernelspec": {
239
+ "display_name": "Python 3 (ipykernel)",
240
+ "language": "python",
241
+ "name": "python3"
242
+ },
243
+ "language_info": {
244
+ "codemirror_mode": {
245
+ "name": "ipython",
246
+ "version": 3
247
+ },
248
+ "file_extension": ".py",
249
+ "mimetype": "text/x-python",
250
+ "name": "python",
251
+ "nbconvert_exporter": "python",
252
+ "pygments_lexer": "ipython3",
253
+ "version": "3.9.7"
254
+ }
255
+ },
256
+ "nbformat": 4,
257
+ "nbformat_minor": 1
258
+ }