ShiromiyaGamer committed
Commit b95d9b4 • 1 Parent(s): 30d9e05
Upload inference.py
Files changed: apollo/inference.py (+124, -0)
apollo/inference.py  ADDED
@@ -0,0 +1,124 @@
import os
import torch
import librosa
import look2hear.models
import soundfile as sf
from tqdm.auto import tqdm
import argparse
import numpy as np

import warnings
warnings.filterwarnings("ignore")

def load_audio(file_path):
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    #audio = dBgain(audio, -6)
    return torch.from_numpy(audio), samplerate

def save_audio(file_path, audio, samplerate=44100):
    #audio = dBgain(audio, +6)
    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")

def process_chunk(chunk):
    chunk = chunk.unsqueeze(0).cuda()
    with torch.no_grad():
        return model(chunk).squeeze(0).squeeze(0).cpu()

def _getWindowingArray(window_size, fade_size):
    # No real fades: the fade-in stays at 1 and the fade-out is zeroed,
    # so the unreliable tail of each chunk is discarded rather than cross-faded.
    fadein = torch.linspace(1, 1, fade_size)
    fadeout = torch.linspace(0, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window

def dBgain(audio, volume_gain_dB):
    gain = 10 ** (volume_gain_dB / 20)
    gained_audio = audio * gain
    return gained_audio


def main(input_wav, output_wav):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    # Load the pretrained Apollo checkpoint and move it to the GPU
    global model
    model = look2hear.models.BaseModel.from_pretrain("/kaggle/working/Apollo/model/pytorch_model.bin", sr=44100, win=20, feature_dim=256, layer=6).cuda()

    test_data, samplerate = load_audio(input_wav)

    C = chunk_size * samplerate  # chunk_size seconds converted to samples
    N = overlap
    step = C // N
    fade_size = 2 * 44100  # 2 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step

    # Pad the input if necessary
    if test_data.shape[1] > 2 * border and (border > 0):
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    # Overlap-add buffers: accumulated windowed output and per-sample window weight
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        out = process_chunk(part)

        window = windowingArray
        if i == 0:  # First audio chunk, no fade-in
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # Last audio chunk, no fade-out
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    # Normalize the overlap-add sum by the accumulated window weights
    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove padding if added earlier
    if test_data.shape[1] > 2 * border and (border > 0):
        final_output = final_output[..., border:-border]

    save_audio(output_wav, final_output, samplerate)
    print(f'Success! Output file saved as {output_wav}')

    # Memory clearing
    model.cpu()
    del model
    torch.cuda.empty_cache()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    parser.add_argument("--chunk_size", type=int, help="Chunk size in seconds", default=3)
    parser.add_argument("--overlap", type=int, help="Overlap factor", default=2)
    args = parser.parse_args()

    chunk_size = args.chunk_size
    overlap = args.overlap
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')

    main(args.in_wav, args.out_wav)
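For context (not part of the committed file): the while loop in main is a standard overlap-add scheme. Each chunk is multiplied by a window and accumulated into result, counter accumulates the window weight per sample, and the final result / counter normalizes wherever chunks overlap. A minimal toy sketch of that idea, with made-up sizes and a flat window:

import torch

x = torch.arange(16, dtype=torch.float32).unsqueeze(0)  # toy signal, shape (channels=1, samples=16)
C, step = 8, 4                                           # chunk length and hop (overlap factor 2)
window = torch.ones(C)                                   # flat window for simplicity

result = torch.zeros_like(x)
counter = torch.zeros_like(x)
i = 0
while i < x.shape[1]:
    part = x[:, i:i + C]
    length = part.shape[-1]
    result[:, i:i + length] += part * window[:length]    # accumulate the windowed chunk
    counter[:, i:i + length] += window[:length]          # accumulate the window weight per sample
    i += step

assert torch.allclose(result / counter, x)               # normalization recovers the signal exactly

The script itself is meant to be run from the command line with the arguments defined in the argparse block above: --in_wav and --out_wav paths, plus optionally --chunk_size (seconds) and --overlap.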