cmagganas committed
Commit
e385c7e
1 Parent(s): abed0b8

Create model.py

Files changed (1)
  1. model.py +125 -0
model.py ADDED
@@ -0,0 +1,125 @@
+ from transformers import EncodecModel, AutoProcessor
+ import torch
+ from audiocraft.data.audio import audio_read, audio_write
+ import julius
+ from typing import Optional, Tuple
+
+
+ class EncodecNoQuantizeModel(EncodecModel):
+
+     def _encode_frame(
+         self, input_values: torch.Tensor, bandwidth: float, padding_mask: torch.Tensor
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """
+         Encodes the given input using the underlying VQ-VAE encoder. If `config.normalize`
+         is set to `True`, the input is first normalized; the padding mask is required to
+         compute the correct scale. Unlike the parent class, the continuous encoder
+         embeddings are returned unquantized, so `bandwidth` is unused here.
+         """
+         length = input_values.shape[-1]
+         duration = length / self.config.sampling_rate
+
+         if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
+             raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")
+
+         scale = None
+         if self.config.normalize:
+             # zero out the padded samples before computing the scale
+             input_values = input_values * padding_mask
+             mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
+             scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
+             input_values = input_values / scale
+
+         embeddings = self.encoder(input_values)
+         # skip the quantizer and return the continuous embeddings:
+         # codes = self.quantizer.encode(embeddings, bandwidth)
+         # codes = codes.transpose(0, 1)
+         return embeddings, scale
+
+     def _decode_frame(self, embeddings: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # decode the continuous embeddings directly; no dequantization step:
+         # codes = codes.transpose(0, 1)
+         # embeddings = self.quantizer.decode(codes)
+         outputs = self.decoder(embeddings)
+         if scale is not None:
+             outputs = outputs * scale.view(-1, 1, 1)
+         return outputs
+
+
+ MODEL_SAMPLING_RATE = 48000
+
+ def load_model():
+     # load the model + processor (for pre-processing the audio)
+     model = EncodecNoQuantizeModel.from_pretrained("facebook/encodec_48khz").to("cuda:0")
+     processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
+
+     return model, processor
+
+ @torch.no_grad()
+ def invert_audio(
+         model, processor, input_audio_path, out_path,
+         normalize=True, flip_input=True, flip_output=False):
+
+     model.config.normalize = normalize
+
+     audio_sample_1, sampling_rate_1 = audio_read(input_audio_path)
+     if sampling_rate_1 != MODEL_SAMPLING_RATE:
+         audio_sample_1 = julius.resample_frac(audio_sample_1, sampling_rate_1, MODEL_SAMPLING_RATE)
+
+     # audio_sample_1: [channels, samples], e.g. [2, 9399305]
+     if flip_input:
+         audio_sample_1 = torch.flip(audio_sample_1, dims=(1,))
+
+     # pre-process the inputs
+     inputs_1 = processor(raw_audio=audio_sample_1, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
+     inputs_1["input_values"] = inputs_1["input_values"].to("cuda:0")
+     inputs_1["padding_mask"] = inputs_1["padding_mask"].to("cuda:0")
+
+     # explicitly encode, then decode the audio inputs
+     print("Encoding...")
+     encoder_outputs_1 = model.encode(
+         inputs_1["input_values"],
+         inputs_1["padding_mask"],
+         bandwidth=max(model.config.target_bandwidths))
+
+     # with quantization disabled, `audio_codes` holds the continuous embeddings,
+     # shaped [chunks, batch, dim, steps], e.g.:
+     # [216, 1, 128, 150]
+
+     avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
+     # mean embedding over chunks and time steps: [1, 1, 128, 1]
+     avg_repeat = avg.repeat(
+         encoder_outputs_1.audio_codes.shape[0],
+         encoder_outputs_1.audio_codes.shape[1],
+         1,
+         encoder_outputs_1.audio_codes.shape[3])
+     # the mean broadcast back to the full shape: [216, 1, 128, 150]
+     diff_repeat = encoder_outputs_1.audio_codes - avg_repeat
+
+     # TODO: the power scaling below is an identity while POWER_FACTOR == 1
+     POWER_FACTOR = 1
+     max_abs_diff = torch.max(torch.abs(diff_repeat))
+     diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
+     # reapply the sign of the original differences
+     latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power
+
+     # difference inversion done here: negate each deviation from the mean
+     latents = latents * -1.0
+
+     print("Decoding...")
+     audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]
+
+     # audio_values: [batch, channels, samples], e.g. [1, 2, 10264800]
+     if flip_output:
+         audio_values = torch.flip(audio_values, dims=(2,))
+
+     decoded_wav = audio_values.squeeze(0).to("cpu")
+
+     print("Saving output file...")
+     out_path_ = audio_write(
+         out_path,
+         sample_rate=MODEL_SAMPLING_RATE,
+         wav=decoded_wav,
+         normalize=False)
+
+     return out_path_
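
For reference, a minimal usage sketch (not part of the committed file): it assumes a CUDA device is available, since load_model moves the model to "cuda:0", and the file paths below are hypothetical placeholders. audiocraft's audio_write takes a stem name and appends the audio suffix itself.

    # hypothetical usage; "song.wav" and "song_inverted" are placeholder paths
    model, processor = load_model()
    out_path = invert_audio(
        model, processor,
        input_audio_path="song.wav",   # any file readable by audio_read
        out_path="song_inverted",      # stem name; audio_write appends the suffix
        normalize=True,
        flip_input=True,
        flip_output=False)
    print(f"Inverted audio written to {out_path}")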