Spaces:
No application file
No application file
Upload vqvae.py
Browse files
vqvae.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch as t
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from jukebox.vqvae.encdec import Encoder, Decoder, assert_shape
|
6 |
+
from jukebox.vqvae.bottleneck import NoBottleneck, Bottleneck
|
7 |
+
from jukebox.utils.logger import average_metrics
|
8 |
+
from jukebox.utils.audio_utils import spectral_convergence, spectral_loss, multispectral_loss, audio_postprocess
|
9 |
+
|
10 |
+
def dont_update(params):
|
11 |
+
for param in params:
|
12 |
+
param.requires_grad = False
|
13 |
+
|
14 |
+
def update(params):
|
15 |
+
for param in params:
|
16 |
+
param.requires_grad = True
|
17 |
+
|
18 |
+
def calculate_strides(strides, downs):
|
19 |
+
return [stride ** down for stride, down in zip(strides, downs)]
|
20 |
+
|
21 |
+
def _loss_fn(loss_fn, x_target, x_pred, hps):
|
22 |
+
if loss_fn == 'l1':
|
23 |
+
return t.mean(t.abs(x_pred - x_target)) / hps.bandwidth['l1']
|
24 |
+
elif loss_fn == 'l2':
|
25 |
+
return t.mean((x_pred - x_target) ** 2) / hps.bandwidth['l2']
|
26 |
+
elif loss_fn == 'linf':
|
27 |
+
residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
|
28 |
+
values, _ = t.topk(residual, hps.linf_k, dim=1)
|
29 |
+
return t.mean(values) / hps.bandwidth['l2']
|
30 |
+
elif loss_fn == 'lmix':
|
31 |
+
loss = 0.0
|
32 |
+
if hps.lmix_l1:
|
33 |
+
loss += hps.lmix_l1 * _loss_fn('l1', x_target, x_pred, hps)
|
34 |
+
if hps.lmix_l2:
|
35 |
+
loss += hps.lmix_l2 * _loss_fn('l2', x_target, x_pred, hps)
|
36 |
+
if hps.lmix_linf:
|
37 |
+
loss += hps.lmix_linf * _loss_fn('linf', x_target, x_pred, hps)
|
38 |
+
return loss
|
39 |
+
else:
|
40 |
+
assert False, f"Unknown loss_fn {loss_fn}"
|
41 |
+
|
42 |
+
class VQVAE(nn.Module):
|
43 |
+
def __init__(self, input_shape, levels, downs_t, strides_t,
|
44 |
+
emb_width, l_bins, mu, commit, spectral, multispectral,
|
45 |
+
multipliers=None, use_bottleneck=True, **block_kwargs):
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
self.sample_length = input_shape[0]
|
49 |
+
x_shape, x_channels = input_shape[:-1], input_shape[-1]
|
50 |
+
self.x_shape = x_shape
|
51 |
+
|
52 |
+
self.downsamples = calculate_strides(strides_t, downs_t)
|
53 |
+
self.hop_lengths = np.cumprod(self.downsamples)
|
54 |
+
self.z_shapes = z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
|
55 |
+
self.levels = levels
|
56 |
+
|
57 |
+
if multipliers is None:
|
58 |
+
self.multipliers = [1] * levels
|
59 |
+
else:
|
60 |
+
assert len(multipliers) == levels, "Invalid number of multipliers"
|
61 |
+
self.multipliers = multipliers
|
62 |
+
def _block_kwargs(level):
|
63 |
+
this_block_kwargs = dict(block_kwargs)
|
64 |
+
this_block_kwargs["width"] *= self.multipliers[level]
|
65 |
+
this_block_kwargs["depth"] *= self.multipliers[level]
|
66 |
+
return this_block_kwargs
|
67 |
+
|
68 |
+
encoder = lambda level: Encoder(x_channels, emb_width, level + 1,
|
69 |
+
downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
|
70 |
+
decoder = lambda level: Decoder(x_channels, emb_width, level + 1,
|
71 |
+
downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
|
72 |
+
self.encoders = nn.ModuleList()
|
73 |
+
self.decoders = nn.ModuleList()
|
74 |
+
for level in range(levels):
|
75 |
+
self.encoders.append(encoder(level))
|
76 |
+
self.decoders.append(decoder(level))
|
77 |
+
|
78 |
+
if use_bottleneck:
|
79 |
+
self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
|
80 |
+
else:
|
81 |
+
self.bottleneck = NoBottleneck(levels)
|
82 |
+
|
83 |
+
self.downs_t = downs_t
|
84 |
+
self.strides_t = strides_t
|
85 |
+
self.l_bins = l_bins
|
86 |
+
self.commit = commit
|
87 |
+
self.spectral = spectral
|
88 |
+
self.multispectral = multispectral
|
89 |
+
|
90 |
+
def preprocess(self, x):
|
91 |
+
# x: NTC [-1,1] -> NCT [-1,1]
|
92 |
+
assert len(x.shape) == 3
|
93 |
+
x = x.permute(0,2,1).float()
|
94 |
+
return x
|
95 |
+
|
96 |
+
def postprocess(self, x):
|
97 |
+
# x: NTC [-1,1] <- NCT [-1,1]
|
98 |
+
x = x.permute(0,2,1)
|
99 |
+
return x
|
100 |
+
|
101 |
+
def _decode(self, zs, start_level=0, end_level=None):
|
102 |
+
# Decode
|
103 |
+
if end_level is None:
|
104 |
+
end_level = self.levels
|
105 |
+
assert len(zs) == end_level - start_level
|
106 |
+
xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level)
|
107 |
+
assert len(xs_quantised) == end_level - start_level
|
108 |
+
|
109 |
+
# Use only lowest level
|
110 |
+
decoder, x_quantised = self.decoders[start_level], xs_quantised[0:1]
|
111 |
+
x_out = decoder(x_quantised, all_levels=False)
|
112 |
+
x_out = self.postprocess(x_out)
|
113 |
+
return x_out
|
114 |
+
|
115 |
+
def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
|
116 |
+
z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
|
117 |
+
x_outs = []
|
118 |
+
for i in range(bs_chunks):
|
119 |
+
zs_i = [z_chunk[i] for z_chunk in z_chunks]
|
120 |
+
x_out = self._decode(zs_i, start_level=start_level, end_level=end_level)
|
121 |
+
x_outs.append(x_out)
|
122 |
+
return t.cat(x_outs, dim=0)
|
123 |
+
|
124 |
+
def _encode(self, x, start_level=0, end_level=None):
|
125 |
+
# Encode
|
126 |
+
if end_level is None:
|
127 |
+
end_level = self.levels
|
128 |
+
x_in = self.preprocess(x)
|
129 |
+
xs = []
|
130 |
+
for level in range(self.levels):
|
131 |
+
encoder = self.encoders[level]
|
132 |
+
x_out = encoder(x_in)
|
133 |
+
xs.append(x_out[-1])
|
134 |
+
zs = self.bottleneck.encode(xs)
|
135 |
+
return zs[start_level:end_level]
|
136 |
+
|
137 |
+
def encode(self, x, start_level=0, end_level=None, bs_chunks=1):
|
138 |
+
x_chunks = t.chunk(x, bs_chunks, dim=0)
|
139 |
+
zs_list = []
|
140 |
+
for x_i in x_chunks:
|
141 |
+
zs_i = self._encode(x_i, start_level=start_level, end_level=end_level)
|
142 |
+
zs_list.append(zs_i)
|
143 |
+
zs = [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)]
|
144 |
+
return zs
|
145 |
+
|
146 |
+
def sample(self, n_samples):
|
147 |
+
zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device='cuda') for z_shape in self.z_shapes]
|
148 |
+
return self.decode(zs)
|
149 |
+
|
150 |
+
def forward(self, x, hps, loss_fn='l1'):
|
151 |
+
metrics = {}
|
152 |
+
|
153 |
+
N = x.shape[0]
|
154 |
+
|
155 |
+
# Encode/Decode
|
156 |
+
x_in = self.preprocess(x)
|
157 |
+
xs = []
|
158 |
+
for level in range(self.levels):
|
159 |
+
encoder = self.encoders[level]
|
160 |
+
x_out = encoder(x_in)
|
161 |
+
xs.append(x_out[-1])
|
162 |
+
|
163 |
+
zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)
|
164 |
+
x_outs = []
|
165 |
+
for level in range(self.levels):
|
166 |
+
decoder = self.decoders[level]
|
167 |
+
x_out = decoder(xs_quantised[level:level+1], all_levels=False)
|
168 |
+
assert_shape(x_out, x_in.shape)
|
169 |
+
x_outs.append(x_out)
|
170 |
+
|
171 |
+
# Loss
|
172 |
+
def _spectral_loss(x_target, x_out, hps):
|
173 |
+
if hps.use_nonrelative_specloss:
|
174 |
+
sl = spectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
|
175 |
+
else:
|
176 |
+
sl = spectral_convergence(x_target, x_out, hps)
|
177 |
+
sl = t.mean(sl)
|
178 |
+
return sl
|
179 |
+
|
180 |
+
def _multispectral_loss(x_target, x_out, hps):
|
181 |
+
sl = multispectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
|
182 |
+
sl = t.mean(sl)
|
183 |
+
return sl
|
184 |
+
|
185 |
+
recons_loss = t.zeros(()).to(x.device)
|
186 |
+
spec_loss = t.zeros(()).to(x.device)
|
187 |
+
multispec_loss = t.zeros(()).to(x.device)
|
188 |
+
x_target = audio_postprocess(x.float(), hps)
|
189 |
+
|
190 |
+
for level in reversed(range(self.levels)):
|
191 |
+
x_out = self.postprocess(x_outs[level])
|
192 |
+
x_out = audio_postprocess(x_out, hps)
|
193 |
+
this_recons_loss = _loss_fn(loss_fn, x_target, x_out, hps)
|
194 |
+
this_spec_loss = _spectral_loss(x_target, x_out, hps)
|
195 |
+
this_multispec_loss = _multispectral_loss(x_target, x_out, hps)
|
196 |
+
metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
|
197 |
+
metrics[f'spectral_loss_l{level + 1}'] = this_spec_loss
|
198 |
+
metrics[f'multispectral_loss_l{level + 1}'] = this_multispec_loss
|
199 |
+
recons_loss += this_recons_loss
|
200 |
+
spec_loss += this_spec_loss
|
201 |
+
multispec_loss += this_multispec_loss
|
202 |
+
|
203 |
+
commit_loss = sum(commit_losses)
|
204 |
+
loss = recons_loss + self.spectral * spec_loss + self.multispectral * multispec_loss + self.commit * commit_loss
|
205 |
+
|
206 |
+
with t.no_grad():
|
207 |
+
sc = t.mean(spectral_convergence(x_target, x_out, hps))
|
208 |
+
l2_loss = _loss_fn("l2", x_target, x_out, hps)
|
209 |
+
l1_loss = _loss_fn("l1", x_target, x_out, hps)
|
210 |
+
linf_loss = _loss_fn("linf", x_target, x_out, hps)
|
211 |
+
|
212 |
+
quantiser_metrics = average_metrics(quantiser_metrics)
|
213 |
+
|
214 |
+
metrics.update(dict(
|
215 |
+
recons_loss=recons_loss,
|
216 |
+
spectral_loss=spec_loss,
|
217 |
+
multispectral_loss=multispec_loss,
|
218 |
+
spectral_convergence=sc,
|
219 |
+
l2_loss=l2_loss,
|
220 |
+
l1_loss=l1_loss,
|
221 |
+
linf_loss=linf_loss,
|
222 |
+
commit_loss=commit_loss,
|
223 |
+
**quantiser_metrics))
|
224 |
+
|
225 |
+
for key, val in metrics.items():
|
226 |
+
metrics[key] = val.detach()
|
227 |
+
|
228 |
+
return x_out, loss, metrics
|