Create model.py
model.py
ADDED
from transformers import EncodecModel, AutoProcessor
import torch
from audiocraft.data.audio import audio_read, audio_write
import datetime
import IPython
import os
import julius

from typing import List, Optional, Tuple, Union


class EncodecNoQuantizeModel(EncodecModel):
    """EncodecModel variant that bypasses the residual vector quantizer and
    passes the continuous encoder embeddings straight to the decoder."""

    def _encode_frame(
        self, input_values: torch.Tensor, bandwidth: float, padding_mask: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encodes the given input with the underlying encoder and returns the continuous embeddings
        (quantization is skipped). If `config.normalize` is set to `True` the input is first
        normalized. The padding mask is required to compute the correct scale.
        """
        length = input_values.shape[-1]
        duration = length / self.config.sampling_rate

        if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
            raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")

        scale = None
        if self.config.normalize:
            # if the padding is non zero
            input_values = input_values * padding_mask
            mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
            scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        # Quantization skipped: return the raw embeddings instead of codebook indices.
        # codes = self.quantizer.encode(embeddings, bandwidth)
        # codes = codes.transpose(0, 1)
        return embeddings, scale

    def _decode_frame(self, embeddings: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantization skipped: decode directly from the continuous embeddings.
        # codes = codes.transpose(0, 1)
        # embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale.view(-1, 1, 1)
        return outputs


MODEL_SAMPLING_RATE = 48000


def load_model():
    # load the model + processor (for pre-processing the audio)
    model = EncodecNoQuantizeModel.from_pretrained("facebook/encodec_48khz").to("cuda:0")
    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

    return model, processor


@torch.no_grad()
def invert_audio(
        model, processor, input_audio_path, out_path,
        normalize=True, flip_input=True, flip_output=False):
    """Read an audio file, encode it without quantization, negate the embeddings'
    deviation from their mean, decode, and write the result to `out_path`."""

    model.config.normalize = normalize

    audio_sample_1, sampling_rate_1 = audio_read(input_audio_path)
    if sampling_rate_1 != MODEL_SAMPLING_RATE:
        audio_sample_1 = julius.resample_frac(audio_sample_1, sampling_rate_1, MODEL_SAMPLING_RATE)

    # audio_sample [2, 9399305]
    if flip_input:
        audio_sample_1 = torch.flip(audio_sample_1, dims=(1,))

    # pre-process the inputs
    inputs_1 = processor(raw_audio=audio_sample_1, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
    inputs_1["input_values"] = inputs_1["input_values"].to("cuda:0")
    inputs_1["padding_mask"] = inputs_1["padding_mask"].to("cuda:0")

    # explicitly encode then decode the audio inputs
    print("Encoding...")
    encoder_outputs_1 = model.encode(
        inputs_1["input_values"],
        inputs_1["padding_mask"],
        bandwidth=max(model.config.target_bandwidths))

    # EMBEDDINGS (not quantized):
    # encoder_outputs.audio_codes.shape
    # [216, 1, 128, 150]

    avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
    # [1, 1, 128, 1]
    avg_repeat = avg.repeat(
        encoder_outputs_1.audio_codes.shape[0],
        encoder_outputs_1.audio_codes.shape[1],
        1,
        encoder_outputs_1.audio_codes.shape[3])
    # [216, 1, 128, 150]
    diff_repeat = encoder_outputs_1.audio_codes - avg_repeat

    # TODO: the power-factor calculation is a no-op while POWER_FACTOR stays at 1
    POWER_FACTOR = 1
    max_abs_diff = torch.max(torch.abs(diff_repeat))
    diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
    latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power

    # difference inversion done here!
    latents = latents * -1.0

    print("Decoding...")
    audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]

    # [1, 2, 10264800]
    if flip_output:
        audio_values = torch.flip(audio_values, dims=(2,))

    output_dir = "/home/romainpaulusisep_gmail_com/data/outputs"
    decoded_wav = audio_values.squeeze(0).to("cpu")

    print("Saving output file...")
    out_path_ = audio_write(
        out_path,
        sample_rate=MODEL_SAMPLING_RATE,
        wav=decoded_wav,
        normalize=False)

    return out_path_
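
A hypothetical usage sketch, not part of the committed file: the paths below are placeholders, and a CUDA device is assumed because load_model hard-codes "cuda:0".

# Hypothetical usage (paths are placeholders; requires a CUDA device).
model, processor = load_model()
result_path = invert_audio(
    model, processor,
    input_audio_path="some_track.wav",   # any file audiocraft's audio_read can open
    out_path="some_track_inverted",      # stem only; audio_write appends the format suffix
    normalize=True, flip_input=True, flip_output=False)
print(result_path)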