import os |
import torch |
import librosa |
import look2hear.models |
import soundfile as sf |
from tqdm.auto import tqdm |
import argparse |
import numpy as np |
import warnings |
# Suppress all Python warnings process-wide for the rest of the run.
# NOTE(review): this also hides warnings from torch/librosa that may flag
# real problems; consider narrowing the filter by category or module.
warnings.filterwarnings("ignore")
def load_audio(file_path, sr=44100):
    """Load an audio file as a ``(channels, samples)`` float tensor.

    The file is decoded and resampled to ``sr`` Hz by ``librosa.load`` with
    channels preserved (``mono=False``). For single-channel files librosa
    returns a 1-D array; a leading channel axis is added so that callers can
    always index the result as ``[channel, sample]``.

    Args:
        file_path: Path of the audio file to read.
        sr: Target sample rate passed to ``librosa.load`` (default 44100).

    Returns:
        ``(audio, samplerate)``: a 2-D torch tensor and the sample rate
        reported by librosa.
    """
    audio, samplerate = librosa.load(file_path, mono=False, sr=sr)
    if audio.ndim == 1:
        # Mono file: add the channel axis so shape[1] is always the time axis.
        audio = audio[None, :]
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    return torch.from_numpy(audio), samplerate
def save_audio(file_path, audio, samplerate=44100):
    """Write ``audio`` to ``file_path`` as 16-bit PCM.

    ``audio`` is transposed before writing, so it is expected laid out as
    ``[channel, sample]``; soundfile takes frames along the first axis.
    NOTE(review): no explicit clipping to [-1, 1] is done here — confirm how
    soundfile/libsndfile handles out-of-range floats for PCM_16.
    """
    frames = audio.T
    sf.write(file_path, frames, samplerate, subtype="PCM_16")
def process_chunk(chunk):
    """Run the module-level ``model`` on a single chunk and return it on CPU.

    A leading batch axis is added and the input is moved to CUDA. The forward
    pass runs without autograd, and ``squeeze(0)`` is applied twice to the
    result (each call only drops dimension 0 if it has size 1).
    """
    with torch.no_grad():
        prediction = model(chunk[None].cuda())
        return prediction.squeeze(0).squeeze(0).cpu()
def _getWindowingArray(window_size, fade_size): |
fadein = torch.linspace(1, 1, fade_size) |
fadeout = torch.linspace(0, 0, fade_size) |
window = torch.ones(window_size) |
window[-fade_size:] *= fadeout |
window[:fade_size] *= fadein |
return window |
def dBgain(audio, volume_gain_dB):
    """Scale ``audio`` by a gain expressed in decibels.

    The linear factor is ``10 ** (volume_gain_dB / 20)`` (amplitude dB), so
    +20 dB multiplies by 10 and -20 dB divides by 10.
    """
    return audio * 10 ** (volume_gain_dB / 20)
def _overlap_add(test_data, C, step, fade_size):
    """Process ``test_data`` in overlapping chunks and blend the outputs.

    ``test_data`` is indexed ``[channel, sample]``. A chunk of ``C`` samples
    starts every ``step`` samples. A short final chunk is padded to ``C``:
    reflect padding if it is longer than about half a chunk, zeros otherwise.
    Each chunk goes through ``process_chunk``, is weighted by a window, and is
    accumulated. The sum is then divided by the accumulated weights.

    Returns:
        A float32 tensor shaped like ``test_data``. Positions with zero
        accumulated weight come out as NaN.
    """
    total = test_data.shape[1]
    if fade_size > 0:
        base_window = _getWindowingArray(C, fade_size)
    else:
        base_window = torch.ones(C)
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros_like(result)
    progress_bar = tqdm(total=total, desc="Processing audio chunks", leave=False)
    i = 0
    while i < total:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
        out = process_chunk(part)
        # Work on a copy: the edge tweaks below must not leak into later chunks.
        window = base_window.clone()
        if fade_size > 0:
            # Nothing overlaps the signal's outer edges, so don't fade them.
            if i == 0:
                window[:fade_size] = 1
            if i + C >= total:
                window[-fade_size:] = 1
        result[..., i:i + length] += out[..., :length] * window[:length]
        counter[..., i:i + length] += window[:length]
        progress_bar.update(min(step, total - i))
        i += step
    progress_bar.close()
    return (result / counter).squeeze(0)


def main(input_wav, output_wav, chunk_size=None, overlap=None,
         model_path="/kaggle/working/Apollo/model/pytorch_model.bin"):
    """Run chunked model inference on ``input_wav`` and write ``output_wav``.

    The look2hear model is loaded from ``model_path`` onto CUDA and stored in
    the module-level ``model`` used by ``process_chunk``. The input is cut into
    chunks of ``chunk_size`` seconds, with a new chunk every ``C // overlap``
    samples. The chunk outputs are blended by weighted overlap-add. The result
    is written as 16-bit PCM, and the model is released afterwards, even if
    processing fails.

    Args:
        input_wav: Path of the audio file to read.
        output_wav: Path of the audio file to write.
        chunk_size: Chunk length in seconds. ``None`` falls back to the
            module-level ``chunk_size`` (set by the CLI), else 3.
        overlap: Chunk starts per chunk length (``step = C // overlap``).
            ``None`` falls back to the module-level ``overlap`` (set by the
            CLI), else 2.
        model_path: Checkpoint path passed to ``BaseModel.from_pretrain``.

    Raises:
        ValueError: if ``chunk_size``/``overlap`` are non-positive or give a
            step of zero samples.
    """
    global model
    if chunk_size is None:
        chunk_size = globals().get("chunk_size", 3)
    if overlap is None:
        overlap = globals().get("overlap", 2)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap < 1:
        raise ValueError(f"overlap must be >= 1, got {overlap}")
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    model = look2hear.models.BaseModel.from_pretrain(model_path, sr=44100, win=20, feature_dim=256, layer=6).cuda()
    try:
        test_data, samplerate = load_audio(input_wav)
        if test_data.dim() == 1:
            # Mono input: add a channel axis so shape[1] is the time axis.
            test_data = test_data.unsqueeze(0)
        C = int(chunk_size * samplerate)
        N = overlap
        step = C // N
        if step < 1:
            raise ValueError(f"overlap={N} leaves no step for a {C}-sample chunk")
        # Bound the fade so ramps fit inside a chunk (<= C // 2) and so the
        # flat part of each window (>= C - fade) spans a full step: every
        # sample then gets non-zero total weight. With overlap == 1 this is 0.
        fade_size = min(2 * samplerate, C // 2, C - step)
        print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
        border = C - step
        padded = test_data.shape[1] > 2 * border and border > 0
        if padded:
            test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
        final_output = _overlap_add(test_data, C, step, fade_size).numpy()
        np.nan_to_num(final_output, copy=False, nan=0.0)
        if padded:
            final_output = final_output[..., border:-border]
        save_audio(output_wav, final_output, samplerate)
        print(f'Success! Output file saved as {output_wav}')
    finally:
        model.cpu()
        del model
        torch.cuda.empty_cache()
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Audio Inference Script")
    cli.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    cli.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    cli.add_argument("--chunk_size", type=int, default=3, help="chunk size value in seconds")
    cli.add_argument("--overlap", type=int, default=2, help="Overlap")
    cli_args = cli.parse_args()
    # main() reads these two names as module-level globals.
    chunk_size, overlap = cli_args.chunk_size, cli_args.overlap
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')
    main(cli_args.in_wav, cli_args.out_wav)