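"""Chunked inference script for a look2hear audio restoration model (the default config
targets Apollo).

The script loads a pretrained checkpoint via look2hear.models.BaseModel.from_pretrain,
splits the input audio into overlapping chunks, runs each chunk through the model, and
recombines the outputs with an overlap-add scheme before writing a 16-bit PCM WAV.

Example invocation (the script filename below is illustrative, not fixed by this file):

    python inference.py \
        --in_wav input.wav --out_wav output.wav \
        --ckpt model/pytorch_model.bin --config config/apollo.yaml \
        --chunk_size 10 --overlap 2
"""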
import os
import torch
import librosa
import look2hear.models
import soundfile as sf
from tqdm.auto import tqdm
import argparse
import numpy as np
import yaml
from ml_collections import ConfigDict
#from omegaconf import OmegaConf

import warnings
warnings.filterwarnings("ignore")

def get_config(config_path):
    with open(config_path) as f:
        #config = OmegaConf.load(config_path)
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
        return config

def load_audio(file_path):
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
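    # librosa resamples to 44.1 kHz and returns float32 in [-1, 1];
    # mono files come back as a 1-D array, stereo as (channels, samples).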
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    #audio = dBgain(audio, -6)
    return torch.from_numpy(audio), samplerate

def save_audio(file_path, audio, samplerate=44100):
    #audio = dBgain(audio, +6)
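    # soundfile expects (frames, channels) data, hence the transpose; output is 16-bit PCM.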
    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")

def process_chunk(chunk):
    # chunk: (channels, samples). Add a batch dimension, run the model without gradients,
    # then squeeze the leading dimensions back off the output.
    chunk = chunk.unsqueeze(0).cpu()
    with torch.no_grad():
        return model(chunk).squeeze(0).squeeze(0).cpu()

def _getWindowingArray(window_size, fade_size):
    # IMPORTANT NOTE:
    # This is not a true crossfade window: the "fadein" is flat at 1 (no fade-in at all) and
    # the "fadeout" is flat at 0, so its only effect is to zero out the last fade_size samples
    # of each chunk, discarding the unreliable chunk ending.
    fadein = torch.ones(fade_size)
    fadeout = torch.zeros(fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window

def dBgain(audio, volume_gain_dB):
    # Convert a gain in dB to a linear amplitude factor, e.g. -6 dB -> 10 ** (-6 / 20) ~= 0.5.
    gain = 10 ** (volume_gain_dB / 20)
    return audio * gain


def main(input_wav, output_wav, ckpt_path):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    global model
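    # Model hyperparameters come from the YAML config; the checkpoint is loaded and run on CPU.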
    feature_dim = config['model']['feature_dim']
    sr = config['model']['sr']
    win = config['model']['win']
    layer = config['model']['layer']
    model = look2hear.models.BaseModel.from_pretrain(ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer).cpu()

    test_data, samplerate = load_audio(input_wav)
    
    C = chunk_size * samplerate  # chunk_size seconds to samples
    N = overlap
    step = C // N
    fade_size = 3 * 44100 # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
    
    border = C - step
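    # Overlap-add plan: a chunk of C samples is processed every `step` samples, so consecutive
    # chunks overlap by `border = C - step` samples. The input is reflect-padded by `border`
    # on both sides so the signal edges also get full chunk coverage; the padding is removed
    # again after processing.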
    
    # handle mono inputs correctly
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0) 

    # Pad the input if necessary
    if test_data.shape[1] > 2 * border and (border > 0):
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
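        # The final chunk is usually shorter than C: reflect-pad it when more than half a chunk
        # remains (reflect padding must be shorter than the signal itself), otherwise zero-pad.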
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        out = process_chunk(part)

        window = windowingArray.clone()  # per-chunk copy so the edits below don't leak into later chunks
        if i == 0:  # First audio chunk, no fadein
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # Last audio chunk, no fadeout
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    # Normalize by how many windowed chunks contributed to each sample; positions that were
    # never covered (counter == 0) come out as NaN and are replaced with silence below.
    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove padding if added earlier
    if test_data.shape[1] > 2 * border and (border > 0):
        final_output = final_output[..., border:-border]

    save_audio(output_wav, final_output, samplerate)
    print(f'Success! Output file saved as {output_wav}')

    # Memory clearing
    model.cpu()
    del model
    torch.cuda.empty_cache()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file", default="model/pytorch_model.bin")
    parser.add_argument("--config", type=str, help="Path to model config file", default="config/apollo.yaml")
    parser.add_argument("--chunk_size", type=int, help="chunk size value in seconds", default=10)
    parser.add_argument("--overlap", type=int, help="Overlap", default=2)
    args = parser.parse_args()
    
    ckpt_path = args.ckpt
    chunk_size = args.chunk_size
    overlap = args.overlap
    config = get_config(args.config)
    print(config['model'])
    print(f'ckpt_path = {ckpt_path}')
    #print(f'config = {config}')
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')
    

    main(args.in_wav, args.out_wav, ckpt_path)