# lp-music-caps / handler.py

import json
import os
from typing import Any, Dict

import librosa
import numpy as np
import torch

from model.bart import BartCaptionModel


def preprocess_audio(audio_signal, sr, duration=10, target_sr=16000):
    """Mono-mix the signal, resample it to target_sr, and split it into
    fixed-length chunks of `duration` seconds. Signals shorter than one
    chunk are zero-padded; otherwise the trailing remainder is dropped."""
    n_samples = int(duration * target_sr)
    audio = librosa.to_mono(audio_signal)
    audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    if audio.shape[-1] < n_samples:  # pad short signals up to one full chunk
        pad = np.zeros(n_samples)
        pad[: audio.shape[-1]] = audio
        audio = pad
    n_chunks = int(audio.shape[-1] // n_samples)
    audio = torch.from_numpy(
        np.stack(np.split(audio[: n_chunks * n_samples], n_chunks)).astype("float32")
    )
    return audio
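
# A quick shape check for preprocess_audio (illustrative only; the input
# below is synthetic noise, not data from this repo): a 30 s stereo clip at
# 44.1 kHz resamples to 16 kHz and splits into three 10 s chunks of 160000
# samples each.
#
#   y = np.random.randn(2, 44100 * 30)
#   chunks = preprocess_audio(y, sr=44100)
#   chunks.shape  # torch.Size([3, 160000])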


class EndpointHandler:
    def __init__(self, path=""):
        # Fetch the pretrained transfer checkpoint on first use.
        if not os.path.isfile("transfer.pth"):
            torch.hub.download_url_to_file(
                "https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth",
                "transfer.pth",
            )
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = BartCaptionModel(max_length=128)
        pretrained_object = torch.load("./transfer.pth", map_location="cpu")
        state_dict = pretrained_object["state_dict"]
        self.model.load_state_dict(state_dict)
        self.model.eval()  # inference only: disable dropout
        if torch.cuda.is_available():
            torch.cuda.set_device(self.device)
            self.model = self.model.cuda(self.device)

    def _captioning(self, audio_tensor):
        if self.device is not None:
            audio_tensor = audio_tensor.to(self.device)
        with torch.no_grad():
            output = self.model.generate(
                samples=audio_tensor,
                num_beams=5,
            )
        # Each row of audio_tensor is one 10-second chunk; label every
        # caption with its chunk's time span.
        inference = ""
        for chunk, text in enumerate(output):
            time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
            inference += f"{time}\n{text} \n \n"
        return inference
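
    # The string built above pairs each 10 s chunk with its caption; the
    # captions here are invented for illustration:
    #
    #   [0:00-10:00]
    #   A mellow acoustic guitar piece with a gentle melody ...
    #
    #   [10:00-20:00]
    #   The track builds as percussion enters ...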

    def __call__(self, data: Dict[str, Any]) -> str:
        # The request carries the audio as a JSON string under "inputs": a
        # flat sample list plus the shape, dtype, and sampling rate needed
        # to rebuild the array.
        data = json.loads(data["inputs"])
        array = np.array(data["audio_list"], dtype=data["audio_dtype"])
        input_audio = array.reshape(data["audio_shape"])
        sr = data["sampling_rate"]
        preprocessed_audio = preprocess_audio(input_audio, sr)
        return self._captioning(preprocessed_audio)
"""
if __name__ == "__main__":
import numpy as np
from scipy.io.wavfile import write as wav_write
from huggingface_hub import InferenceApi
handler = EndpointHandler()
audio_path = "folk.wav"
np_audio, sr = librosa.load(audio_path, sr=44100)
np_list = np_audio.tolist()
np_shape = np_audio.shape
np_dtype = np_audio.dtype.name
request = json.dumps({
"audio_list": np_list,
"audio_shape": np_shape,
"audio_dtype": np_dtype,
"sampling_rate": sr
})
print(f"Loaded {audio_path} with sample rate {sr}")
print(handler.__call__({"payload": request}))
"""