|
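"""Inference Endpoint handler for LP-MusicCaps music captioning.

Downloads the pretrained BART checkpoint (transfer.pth), splits incoming audio
into 10-second chunks at 16 kHz, and returns one generated caption per chunk.
"""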
import json
import os
from typing import Any, Dict

import librosa
import numpy as np
import torch

from model.bart import BartCaptionModel
from utils.audio_utils import load_audio, STR_CH_FIRST

|
def preprocess_audio(audio_signal, sr, duration=10, target_sr=16000):
    """Mono-mix and resample the signal, then split it into fixed-length
    ``duration``-second chunks stacked as a float32 tensor."""
    n_samples = int(duration * target_sr)
    audio = librosa.to_mono(audio_signal)
    audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)

    # Zero-pad clips shorter than one chunk so at least one window exists.
    if audio.shape[-1] < n_samples:
        pad = np.zeros(n_samples)
        pad[: audio.shape[-1]] = audio
        audio = pad

    # Keep only complete chunks; any trailing remainder is dropped.
    n_chunks = int(audio.shape[-1] // n_samples)
    audio = torch.from_numpy(
        np.stack(np.split(audio[: n_chunks * n_samples], n_chunks)).astype("float32")
    )
    return audio
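
# A minimal sketch of the chunking contract above; the dummy signal, its
# length, and its sample rate are illustrative only:
#
#   dummy = np.random.randn(22050 * 25).astype("float32")  # 25 s mono clip
#   chunks = preprocess_audio(dummy, sr=22050)
#   chunks.shape  # torch.Size([2, 160000]): two full 10 s windows at 16 kHz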
|
|
class EndpointHandler:
    def __init__(self, path=""):
        # Fetch the pretrained checkpoint on first startup.
        if not os.path.isfile("transfer.pth"):
            torch.hub.download_url_to_file(
                "https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth",
                "transfer.pth",
            )
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = BartCaptionModel(max_length=128)
        pretrained_object = torch.load("./transfer.pth", map_location="cpu")
        self.model.load_state_dict(pretrained_object["state_dict"])
        if torch.cuda.is_available():
            torch.cuda.set_device(self.device)
            self.model = self.model.cuda(self.device)
        self.model.eval()  # inference only: disable dropout and friends
|
    def _captioning(self, audio_tensor):
        if self.device is not None:
            audio_tensor = audio_tensor.to(self.device)

        with torch.no_grad():
            output = self.model.generate(
                samples=audio_tensor,
                num_beams=5,
            )

        # Concatenate one caption per chunk, each prefixed with the time span
        # it covers (every chunk spans 10 seconds of audio).
        inference = ""
        for chunk, text in enumerate(output):
            time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
            inference += f"{time}\n{text}\n\n"
        return inference
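
    # Illustrative shape of the returned string for a two-chunk clip (the
    # caption text itself is hypothetical):
    #
    #   [0:00-10:00]
    #   <caption for the first 10-second chunk>
    #
    #   [10:00-20:00]
    #   <caption for the second 10-second chunk>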
|
    def __call__(self, data: Dict[str, Any]) -> str:
        # The request body carries a JSON string under "inputs"; the audio is
        # shipped as a flat list plus its shape and dtype for reconstruction.
        data = json.loads(data["inputs"])
        array = np.array(data["audio_list"], dtype=data["audio_dtype"])
        input_audio = array.reshape(data["audio_shape"])
        sr = data["sampling_rate"]

        preprocessed_audio = preprocess_audio(input_audio, sr)
        return self._captioning(preprocessed_audio)
|
""" |
|
if __name__ == "__main__": |
|
import numpy as np |
|
from scipy.io.wavfile import write as wav_write |
|
from huggingface_hub import InferenceApi |
|
|
|
handler = EndpointHandler() |
|
audio_path = "folk.wav" |
|
np_audio, sr = librosa.load(audio_path, sr=44100) |
|
|
|
np_list = np_audio.tolist() |
|
np_shape = np_audio.shape |
|
np_dtype = np_audio.dtype.name |
|
|
|
request = json.dumps({ |
|
"audio_list": np_list, |
|
"audio_shape": np_shape, |
|
"audio_dtype": np_dtype, |
|
"sampling_rate": sr |
|
}) |
|
|
|
print(f"Loaded {audio_path} with sample rate {sr}") |
|
print(handler.__call__({"payload": request})) |
|
""" |