Files changed (1)
  1. vqvae.py +228 -0
vqvae.py ADDED
@@ -0,0 +1,228 @@
import numpy as np
import torch as t
import torch.nn as nn

from jukebox.vqvae.encdec import Encoder, Decoder, assert_shape
from jukebox.vqvae.bottleneck import NoBottleneck, Bottleneck
from jukebox.utils.logger import average_metrics
from jukebox.utils.audio_utils import spectral_convergence, spectral_loss, multispectral_loss, audio_postprocess

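# Enable/disable gradient updates for a collection of parameters (used to
# freeze or unfreeze parts of the model).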
def dont_update(params):
    for param in params:
        param.requires_grad = False

def update(params):
    for param in params:
        param.requires_grad = True

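# Total downsampling per level: `down` strided convolutions of stride `stride`
# compound to a hop length of stride ** down.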
def calculate_strides(strides, downs):
    return [stride ** down for stride, down in zip(strides, downs)]

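# Reconstruction losses, each normalised by a precomputed dataset bandwidth so
# their magnitudes are comparable. 'linf' averages only the top-k largest
# squared residuals per example; 'lmix' is a weighted mix of the other three.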
def _loss_fn(loss_fn, x_target, x_pred, hps):
    if loss_fn == 'l1':
        return t.mean(t.abs(x_pred - x_target)) / hps.bandwidth['l1']
    elif loss_fn == 'l2':
        return t.mean((x_pred - x_target) ** 2) / hps.bandwidth['l2']
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        values, _ = t.topk(residual, hps.linf_k, dim=1)
        return t.mean(values) / hps.bandwidth['l2']
    elif loss_fn == 'lmix':
        loss = 0.0
        if hps.lmix_l1:
            loss += hps.lmix_l1 * _loss_fn('l1', x_target, x_pred, hps)
        if hps.lmix_l2:
            loss += hps.lmix_l2 * _loss_fn('l2', x_target, x_pred, hps)
        if hps.lmix_linf:
            loss += hps.lmix_linf * _loss_fn('linf', x_target, x_pred, hps)
        return loss
    else:
        raise ValueError(f"Unknown loss_fn {loss_fn}")

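# Hierarchical VQ-VAE over raw audio: one encoder/decoder pair per level, each
# downsampling time by a different hop length, with a shared multi-level
# vector-quantisation bottleneck.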
class VQVAE(nn.Module):
    def __init__(self, input_shape, levels, downs_t, strides_t,
                 emb_width, l_bins, mu, commit, spectral, multispectral,
                 multipliers=None, use_bottleneck=True, **block_kwargs):
        super().__init__()

        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape

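        # Hop length per level is the cumulative product of the per-level
        # downsampling factors; z_shapes are the resulting code-sequence lengths.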
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
        self.levels = levels

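        # Optionally widen/deepen each level's blocks by a per-level multiplier.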
        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers

        def _block_kwargs(level):
            this_block_kwargs = dict(block_kwargs)
            this_block_kwargs["width"] *= self.multipliers[level]
            this_block_kwargs["depth"] *= self.multipliers[level]
            return this_block_kwargs

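        # Level k's encoder/decoder stack uses the first k+1 downsampling stages.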
        encoder = lambda level: Encoder(x_channels, emb_width, level + 1,
                                        downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
        decoder = lambda level: Decoder(x_channels, emb_width, level + 1,
                                        downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            self.encoders.append(encoder(level))
            self.decoders.append(decoder(level))

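        # One bottleneck quantises all levels; NoBottleneck passes latents
        # through unquantised.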
        if use_bottleneck:
            self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
        else:
            self.bottleneck = NoBottleneck(levels)

        self.downs_t = downs_t
        self.strides_t = strides_t
        self.l_bins = l_bins
        self.commit = commit
        self.spectral = spectral
        self.multispectral = multispectral

    def preprocess(self, x):
        # x: NTC [-1,1] -> NCT [-1,1]
        assert len(x.shape) == 3
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        # x: NCT [-1,1] -> NTC [-1,1]
        x = x.permute(0, 2, 1)
        return x

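    # Dequantise codes for levels [start_level, end_level) and decode audio
    # with the start_level decoder alone.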
    def _decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        assert len(zs) == end_level - start_level
        xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level)
        assert len(xs_quantised) == end_level - start_level

        # Use only lowest level
        decoder, x_quantised = self.decoders[start_level], xs_quantised[0:1]
        x_out = decoder(x_quantised, all_levels=False)
        x_out = self.postprocess(x_out)
        return x_out

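    # Batched decode: split the codes into bs_chunks along the batch dimension
    # to bound memory use.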
    def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
        z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
        x_outs = []
        for i in range(bs_chunks):
            zs_i = [z_chunk[i] for z_chunk in z_chunks]
            x_out = self._decode(zs_i, start_level=start_level, end_level=end_level)
            x_outs.append(x_out)
        return t.cat(x_outs, dim=0)

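    # Run every level's encoder, quantise, and return the codes for
    # [start_level, end_level).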
    def _encode(self, x, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        x_in = self.preprocess(x)
        xs = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            x_out = encoder(x_in)
            xs.append(x_out[-1])
        zs = self.bottleneck.encode(xs)
        return zs[start_level:end_level]

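    # Batched encode, mirroring decode: chunk the batch, encode each chunk,
    # then concatenate the codes per level.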
    def encode(self, x, start_level=0, end_level=None, bs_chunks=1):
        x_chunks = t.chunk(x, bs_chunks, dim=0)
        zs_list = []
        for x_i in x_chunks:
            zs_i = self._encode(x_i, start_level=start_level, end_level=end_level)
            zs_list.append(zs_i)
        zs = [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)]
        return zs

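    # Decode uniformly random codes at every level.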
    def sample(self, n_samples):
        zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device='cuda') for z_shape in self.z_shapes]
        return self.decode(zs)

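    # Training forward pass: encode, quantise, and decode at every level, then
    # combine reconstruction, spectral, and commitment losses.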
    def forward(self, x, hps, loss_fn='l1'):
        metrics = {}

        N = x.shape[0]

        # Encode/Decode
        x_in = self.preprocess(x)
        xs = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            x_out = encoder(x_in)
            xs.append(x_out[-1])

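        # Quantise each level's latents; returns codes, dequantised latents,
        # commitment losses, and codebook usage metrics.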
        zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)
        x_outs = []
        for level in range(self.levels):
            decoder = self.decoders[level]
            x_out = decoder(xs_quantised[level:level+1], all_levels=False)
            assert_shape(x_out, x_in.shape)
            x_outs.append(x_out)

        # Loss
        def _spectral_loss(x_target, x_out, hps):
            if hps.use_nonrelative_specloss:
                sl = spectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
            else:
                sl = spectral_convergence(x_target, x_out, hps)
            sl = t.mean(sl)
            return sl

        def _multispectral_loss(x_target, x_out, hps):
            sl = multispectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
            sl = t.mean(sl)
            return sl

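        # Accumulate per-level losses, from the coarsest level down to the finest.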
        recons_loss = t.zeros(()).to(x.device)
        spec_loss = t.zeros(()).to(x.device)
        multispec_loss = t.zeros(()).to(x.device)
        x_target = audio_postprocess(x.float(), hps)

        for level in reversed(range(self.levels)):
            x_out = self.postprocess(x_outs[level])
            x_out = audio_postprocess(x_out, hps)
            this_recons_loss = _loss_fn(loss_fn, x_target, x_out, hps)
            this_spec_loss = _spectral_loss(x_target, x_out, hps)
            this_multispec_loss = _multispectral_loss(x_target, x_out, hps)
            metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
            metrics[f'spectral_loss_l{level + 1}'] = this_spec_loss
            metrics[f'multispectral_loss_l{level + 1}'] = this_multispec_loss
            recons_loss += this_recons_loss
            spec_loss += this_spec_loss
            multispec_loss += this_multispec_loss

        commit_loss = sum(commit_losses)
        loss = recons_loss + self.spectral * spec_loss + self.multispectral * multispec_loss + self.commit * commit_loss

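        # Gradient-free diagnostics on the finest-level reconstruction
        # (x_out is the level-0 output after the loop above).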
        with t.no_grad():
            sc = t.mean(spectral_convergence(x_target, x_out, hps))
            l2_loss = _loss_fn("l2", x_target, x_out, hps)
            l1_loss = _loss_fn("l1", x_target, x_out, hps)
            linf_loss = _loss_fn("linf", x_target, x_out, hps)

        quantiser_metrics = average_metrics(quantiser_metrics)

        metrics.update(dict(
            recons_loss=recons_loss,
            spectral_loss=spec_loss,
            multispectral_loss=multispec_loss,
            spectral_convergence=sc,
            l2_loss=l2_loss,
            l1_loss=l1_loss,
            linf_loss=linf_loss,
            commit_loss=commit_loss,
            **quantiser_metrics))

        for key, val in metrics.items():
            metrics[key] = val.detach()

        return x_out, loss, metrics
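
# A minimal usage sketch (not part of the original file). It assumes an
# already-constructed `vqvae` on a CUDA device with hyperparameters matching
# the constructor above; shapes and values are illustrative only:
#
#   x = t.rand(4, vqvae.sample_length, 1) * 2 - 1   # NTC audio in [-1, 1]
#   zs = vqvae.encode(x.cuda())                     # one code tensor per level
#   x_hat = vqvae.decode(zs)                        # NTC reconstruction
#   x_rand = vqvae.sample(n_samples=2)              # decode random codes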