from transformers import EncodecModel, AutoProcessor
import torch
from audiocraft.data.audio import audio_read, audio_write
import datetime
import IPython
import os
import julius

from typing import List, Optional, Tuple, Union

class EncodecNoQuantizeModel(EncodecModel):
    """EncodecModel variant that bypasses the residual vector quantizer:
    `_encode_frame` returns the continuous encoder embeddings and
    `_decode_frame` feeds them straight to the decoder."""

    def _encode_frame(
        self, input_values: torch.Tensor, bandwidth: float, padding_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encodes the given input using the underlying encoder, skipping quantization. If `config.normalize` is set to
        `True` the input is first normalized. The padding mask is required to compute the correct scale.
        """
        length = input_values.shape[-1]
        duration = length / self.config.sampling_rate

        if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
            raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")

        scale = None
        if self.config.normalize:
            # zero out padded samples so they do not affect the loudness scale
            input_values = input_values * padding_mask
            mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
            scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        # Quantization is deliberately skipped; the continuous embeddings are
        # returned in place of the discrete codes.
        # codes = self.quantizer.encode(embeddings, bandwidth)
        # codes = codes.transpose(0, 1)
        return embeddings, scale

    def _decode_frame(self, embeddings: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        # codes = codes.transpose(0, 1)
        # embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale.view(-1, 1, 1)
        return outputs
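
# Note on the override above: with the quantizer bypassed, EncodecModel.encode()
# packs the continuous encoder embeddings into the `audio_codes` field of its
# output, so downstream code receives float latents (for the 48 kHz model the
# latent dimension is 128) instead of integer codebook indices. The exact shape,
# [n_chunks, batch, 128, frames_per_chunk], is an assumption based on the
# EnCodec 48 kHz configuration rather than something asserted by this script.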
    

MODEL_SAMPLING_RATE = 48000

def load_model():
    # Load the model (moved to the first GPU; a CUDA device is assumed)
    # and the processor used for pre-processing the audio.
    model = EncodecNoQuantizeModel.from_pretrained("facebook/encodec_48khz").to("cuda:0")
    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

    return model, processor

@torch.no_grad()
def invert_audio(
        model, processor, input_audio, sampling_rate,
        normalize=True, flip_input=True, flip_output=False):

    model.config.normalize = normalize

    # Check and resample the input audio if necessary
    if sampling_rate != MODEL_SAMPLING_RATE:
        input_audio = julius.resample_frac(input_audio, sampling_rate, MODEL_SAMPLING_RATE)

    # Flip the audio if required
    if flip_input:
        input_audio = torch.flip(input_audio, dims=(1,))

    # Pre-process the inputs
    inputs_1 = processor(raw_audio=input_audio, sampling_rate=MODEL_SAMPLING_RATE, return_tensors="pt")
    inputs_1["input_values"] = inputs_1["input_values"].to("cuda:0")
    inputs_1["padding_mask"] = inputs_1["padding_mask"].to("cuda:0")

    # Explicitly encode then decode the audio inputs
    print("Encoding...")
    encoder_outputs_1 = model.encode(
        inputs_1["input_values"],
        inputs_1["padding_mask"],
        bandwidth=max(model.config.target_bandwidths))

    # Average the continuous latents over the chunk and time dimensions,
    # keeping one mean vector per (batch, latent channel), then broadcast it
    # back to the full latent shape.
    avg = torch.mean(encoder_outputs_1.audio_codes, (0, 3), True)
    avg_repeat = avg.repeat(
        encoder_outputs_1.audio_codes.shape[0],
        encoder_outputs_1.audio_codes.shape[1],
        1,
        encoder_outputs_1.audio_codes.shape[3])
    # Deviation of every latent from the mean latent.
    diff_repeat = encoder_outputs_1.audio_codes - avg_repeat

    # Optionally reshape the magnitude distribution of the deviations;
    # POWER_FACTOR = 1 leaves the magnitudes unchanged.
    POWER_FACTOR = 1
    max_abs_diff = torch.max(torch.abs(diff_repeat))
    diff_abs_power = ((torch.abs(diff_repeat) / max_abs_diff) ** POWER_FACTOR) * max_abs_diff
    # Reapply the original signs to the transformed magnitudes.
    latents = (diff_repeat >= 0) * diff_abs_power - (diff_repeat < 0) * diff_abs_power

    # Invert the deviation so the decoder sees (mean - deviation) instead of
    # (mean + deviation).
    latents = latents * -1.0

    print("Decoding...")
    audio_values = model.decode(latents, encoder_outputs_1.audio_scales, inputs_1["padding_mask"])[0]

    if flip_output:
        audio_values = torch.flip(audio_values, dims=(2,))

    # Move the decoded audio back to the CPU as a [channels, samples] tensor.
    decoded_wav = audio_values.squeeze(0).to("cpu")

    return decoded_wav
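

# A minimal end-to-end usage sketch (the file names are hypothetical and a CUDA
# device is assumed): read a clip with audiocraft's audio_read, run the
# quantizer-free EnCodec round trip, and write the result with audio_write.
if __name__ == "__main__":
    model, processor = load_model()
    wav, sr = audio_read("input.wav")  # wav: [channels, samples]
    inverted = invert_audio(model, processor, wav, sr)
    audio_write("inverted_output", inverted, MODEL_SAMPLING_RATE)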