File size: 4,248 Bytes
b95d9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
import torch
import librosa
import look2hear.models
import soundfile as sf
from tqdm.auto import tqdm
import argparse
import numpy as np
import warnings
# Suppress every warning category for the whole script run.
warnings.simplefilter("ignore")
def load_audio(file_path, sr=44100):
    """Load an audio file as a tensor shaped (channels, samples).

    The file is decoded with ``librosa.load`` keeping all channels
    (``mono=False``) and resampled to ``sr`` Hz. librosa returns a 1-D array
    for a single-channel file; that case is promoted to shape (1, samples)
    so callers can always index ``shape[1]`` for the sample count.

    Args:
        file_path: path of the audio file to read.
        sr: target sample rate handed to ``librosa.load`` (44.1 kHz default).

    Returns:
        ``(audio, samplerate)`` -- a ``torch.Tensor`` built from librosa's
        array (float32 by librosa's default) and the rate librosa reports.
    """
    audio, samplerate = librosa.load(file_path, mono=False, sr=sr)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    # Mono files come back as (samples,); give them an explicit channel axis.
    audio = np.atleast_2d(audio)
    return torch.from_numpy(audio), samplerate
def save_audio(file_path, audio, samplerate=44100, subtype="PCM_16"):
    """Write ``audio`` to ``file_path`` using soundfile.

    Args:
        file_path: destination path.
        audio: array shaped (channels, samples); it is transposed to
            (samples, channels) before writing. A 1-D array is unaffected
            by the transpose and is written as-is.
        samplerate: sample rate in Hz.
        subtype: soundfile sample subtype; 16-bit PCM by default.
    """
    sf.write(file_path, audio.T, samplerate, subtype=subtype)
def process_chunk(chunk):
    """Run the module-level ``model`` on a single chunk and return it on CPU.

    ``chunk`` is presumably (channels, samples), as sliced by ``main`` --
    TODO confirm. A batch axis is prepended, the tensor is moved to the
    default CUDA device, and inference runs without autograd. On the way
    back, ``squeeze(0)`` is applied twice; each call only removes the
    leading axis when its size is 1 (torch semantics).

    Relies on the global ``model`` that ``main`` assigns before calling it.
    """
    batch = chunk.unsqueeze(0).cuda()
    with torch.no_grad():
        restored = model(batch)
        restored = restored.squeeze(0)
        restored = restored.squeeze(0)
        return restored.cpu()
def _getWindowingArray(window_size, fade_size):
    """Return a float32 overlap-add weight window of ``window_size`` samples.

    Every weight is 1 except the last ``fade_size`` samples, which are 0, so
    the unreliable ending of each model chunk is discarded when chunks are
    overlap-added. Despite the name, no gradual fade is applied: the window
    is a hard step.

    ``fade_size`` is clamped to ``[0, window_size]``: 0 or a negative value
    yields an all-ones window; a value >= ``window_size`` yields all zeros.
    """
    window = torch.ones(window_size)
    tail = max(0, min(fade_size, window_size))
    if tail:
        # Explicit start index: ``window[-0:]`` would select the whole array.
        window[window_size - tail:] = 0
    return window
def dBgain(audio, volume_gain_dB):
    """Scale ``audio`` by a gain given in decibels.

    The linear amplitude factor is ``10 ** (dB / 20)``: 0 dB leaves the
    signal unchanged, +20 dB multiplies it by 10, -20 dB divides it by 10.
    Works for anything that supports ``*`` with a float (scalars, numpy
    arrays, torch tensors).
    """
    return audio * 10 ** (volume_gain_dB / 20)
def main(input_wav, output_wav):
    """Restore ``input_wav`` chunk by chunk with the Apollo model; write ``output_wav``.

    Reads the module-level globals ``chunk_size`` (seconds) and ``overlap``,
    which the ``__main__`` block sets from the command line.

    The signal is cut into chunks of ``C = chunk_size * samplerate`` samples
    advanced by ``step = C // overlap``. When the input is long enough it is
    reflect-padded by ``C - step`` on both sides (removed again at the end).
    Every non-final chunk contributes only its first ``C - fade_size``
    samples (weight 1); its last ``fade_size`` samples are dropped because
    the chunk ending is unreliable. The final chunk(s) contribute in full,
    since nothing follows to cover their tail. Overlapping contributions are
    averaged by dividing by the summed weights.

    ``fade_size`` is 2 s but capped at ``C - step``. The cap guarantees that
    the kept part of each chunk spans at least one hop, so no sample is left
    with zero weight (which would come out as silence).

    Raises:
        ValueError: if ``overlap`` < 1, or the chunk is too short to yield a
            hop of at least one sample.
    """
    N = overlap
    if N < 1:
        raise ValueError(f"overlap must be a positive integer, got {N}")
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    global model
    model = look2hear.models.BaseModel.from_pretrain("/kaggle/working/Apollo/model/pytorch_model.bin", sr=44100, win=20, feature_dim=256, layer=6).cuda()
    try:
        test_data, samplerate = load_audio(input_wav)
        if test_data.dim() == 1:
            # Mono file decoded as (samples,): add the channel axis.
            test_data = test_data.unsqueeze(0)
        C = chunk_size * samplerate  # chunk_size seconds to samples
        step = C // N
        if step < 1:
            raise ValueError(f"chunk_size={chunk_size} s is too short for overlap={N}")
        border = C - step
        # Drop up to 2 s of each chunk's tail, but never more than C - step,
        # otherwise consecutive chunks leave unweighted gaps between them.
        fade_size = min(2 * samplerate, border)
        print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
        # Pad the input if necessary
        padded = test_data.shape[1] > 2 * border and border > 0
        if padded:
            test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
        total = test_data.shape[1]
        # Two fixed windows, never mutated: tail-dropping for ordinary chunks,
        # all-ones for the final chunk(s). The window has no fade-in, so the
        # first chunk needs no special case.
        body_window = _getWindowingArray(C, fade_size) if fade_size > 0 else torch.ones(C)
        last_window = torch.ones(C)
        result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
        counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
        i = 0
        progress_bar = tqdm(total=total, desc="Processing audio chunks", leave=False)
        while i < total:
            part = test_data[:, i:i + C]
            length = part.shape[-1]
            if length < C:
                # Reflect padding requires pad < length; fall back to zeros.
                if length > C // 2 + 1:
                    part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                else:
                    part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
            out = process_chunk(part)
            window = last_window if i + C >= total else body_window
            result[..., i:i + length] += out[..., :length] * window[:length]
            counter[..., i:i + length] += window[:length]
            progress_bar.update(min(step, total - i))
            i += step
        progress_bar.close()
        final_output = result / counter
        final_output = final_output.squeeze(0).numpy()
        # Defensive: any sample left with zero total weight would be NaN.
        np.nan_to_num(final_output, copy=False, nan=0.0)
        # Remove padding if added earlier
        if padded:
            final_output = final_output[..., border:-border]
        save_audio(output_wav, final_output, samplerate)
        print(f'Success! Output file saved as {output_wav}')
    finally:
        # Memory clearing, also when processing fails.
        model.cpu()
        del model
        torch.cuda.empty_cache()
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Audio Inference Script")
    cli.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    cli.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    cli.add_argument("--chunk_size", type=int, default=3, help="chunk size value in seconds")
    cli.add_argument("--overlap", type=int, default=2, help="Overlap")
    cli_args = cli.parse_args()
    # main() reads chunk_size and overlap as module-level globals.
    chunk_size, overlap = cli_args.chunk_size, cli_args.overlap
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')
    main(cli_args.in_wav, cli_args.out_wav)
|