ShiromiyaGamer committed on
Commit
b95d9b4
1 Parent(s): 30d9e05

Upload inference.py

Files changed (1)
  1. apollo/inference.py +124 -0
apollo/inference.py ADDED
@@ -0,0 +1,124 @@
import os
import argparse
import warnings

import numpy as np
import torch
import librosa
import soundfile as sf
from tqdm.auto import tqdm

import look2hear.models

warnings.filterwarnings("ignore")

def load_audio(file_path):
    # Load at 44.1 kHz without downmixing, so a stereo file keeps shape (2, T).
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    # audio = dBgain(audio, -6)  # optional input attenuation
    return torch.from_numpy(audio), samplerate

def save_audio(file_path, audio, samplerate=44100):
    # audio = dBgain(audio, +6)  # optional gain restore, paired with load_audio
    # soundfile expects (frames, channels), hence the transpose.
    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")

def process_chunk(chunk):
    # Add a batch dimension, run the model on the GPU, then drop the
    # batch/source dimensions and move the result back to the CPU.
    chunk = chunk.unsqueeze(0).cuda()
    with torch.no_grad():
        return model(chunk).squeeze(0).squeeze(0).cpu()

def _getWindowingArray(window_size, fade_size):
    # Deliberately degenerate fades: the head of each chunk keeps full weight
    # (fade-in of all ones) while the tail is zeroed (fade-out of all zeros),
    # discarding the unreliable ending of each chunk instead of cross-fading it.
    fadein = torch.linspace(1, 1, fade_size)
    fadeout = torch.linspace(0, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window

def dBgain(audio, volume_gain_dB):
    # Convert decibels to a linear amplitude factor: gain = 10^(dB / 20).
    gain = 10 ** (volume_gain_dB / 20)
    return audio * gain

def main(input_wav, output_wav, chunk_size, overlap):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # restrict execution to the first GPU

    global model
    model = look2hear.models.BaseModel.from_pretrain(
        "/kaggle/working/Apollo/model/pytorch_model.bin",
        sr=44100, win=20, feature_dim=256, layer=6,
    ).cuda()

    test_data, samplerate = load_audio(input_wav)

    C = chunk_size * samplerate  # chunk length in samples
    N = overlap
    step = C // N  # hop between consecutive chunks
    fade_size = 2 * 44100  # 2 seconds; gapless coverage needs fade_size <= C - step
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step

    # Reflect-pad both ends so the first and last chunks see full context.
    padded = test_data.shape[1] > 2 * border and border > 0
    if padded:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    # Overlap-add buffers: `result` accumulates windowed chunk outputs and
    # `counter` accumulates the window weights used for normalization.
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            # Short final chunk: reflect-pad if more than half a chunk remains,
            # otherwise zero-pad up to the full chunk length.
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        out = process_chunk(part)

        # Clone so the per-chunk edits below cannot mutate the shared window.
        window = windowingArray.clone()
        if i == 0:  # first chunk: no fade-in
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # last chunk: no fade-out
            window[-fade_size:] = 1

        result[..., i:i + length] += out[..., :length] * window[..., :length]
        counter[..., i:i + length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    # Normalize by the accumulated window weights; samples with zero weight
    # divide to NaN and are zeroed out below.
    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove the reflect-padding added earlier.
    if padded:
        final_output = final_output[..., border:-border]

    save_audio(output_wav, final_output, samplerate)
    print(f'Success! Output file saved as {output_wav}')

    # Free GPU memory.
    model.cpu()
    del model
    torch.cuda.empty_cache()
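# Worked example of the chunking arithmetic above, using the script defaults
# (chunk_size=3, overlap=2, sr=44100):
#   C      = 3 * 44100 = 132300 samples per chunk
#   step   = C // 2    = 66150  (consecutive chunks overlap by C - step samples)
#   border = C - step  = 66150  samples of reflect-padding on each side
# Each output sample is the sum of the windowed chunk outputs that cover it,
# divided by the matching accumulated weight in `counter` (overlap-add).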

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    parser.add_argument("--chunk_size", type=int, default=3, help="Chunk size in seconds")
    parser.add_argument("--overlap", type=int, default=2, help="Overlap factor (hop = chunk_size / overlap)")
    args = parser.parse_args()

    print(f'chunk_size = {args.chunk_size}, overlap = {args.overlap}')

    main(args.in_wav, args.out_wav, args.chunk_size, args.overlap)
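For reference, a typical invocation (file paths are illustrative; the model
checkpoint path is hard-coded in main()):

    python apollo/inference.py --in_wav input.wav --out_wav output.wav --chunk_size 3 --overlap 2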