# Provenance: pipeline.py, commit b00ee09 ("Update pipeline.py") by osanseviero; 1.64 kB.
from typing import Dict, List, Tuple
import numpy as np
from asteroid import separate
from asteroid.models import BaseModel
import os
class PreTrainedPipeline:
    """Audio-to-audio inference pipeline wrapping a pretrained Asteroid model.

    The model is loaded once at construction time; each call then runs
    ``asteroid.separate.numpy_separate`` on a single waveform and returns the
    resulting channels together with the sampling rate and per-channel labels.
    """

    def __init__(self, path: str = ""):
        """Load the model checkpoint found under ``path``.

        Args:
            path (:obj:`str`):
                Directory that contains the ``pytorch_model.bin`` checkpoint.
                Defaults to the current working directory (empty string).
        """
        # This is only called once, so all the heavy loading I/O happens here
        # rather than on every inference call.
        self.model = BaseModel.from_pretrained(os.path.join(path, "pytorch_model.bin"))
        # Cache the rate the model reports for itself; it is returned with
        # every result so callers know how to play back the output.
        self.sampling_rate = self.model.sample_rate

    def __call__(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of audio received. By default sampled at `self.sampling_rate`.
                The shape of this array is `T`, where `T` is the time axis
        Return:
            A :obj:`tuple` containing:
              - :obj:`np.ndarray`:
                 The return shape of the array must be `C'`x`T'`
              - a :obj:`int`: the sampling rate as an int in Hz.
              - a :obj:`List[str]`: the annotation for each out channel.
                    This can be the name of the instruments for audio source separation
                    or some annotation for speech enhancement. The length must be `C'`.
        """
        # Add two leading singleton axes so the waveform has shape (1, 1, T);
        # presumably (batch, channel, time) as Asteroid expects -- TODO confirm.
        batch = inputs.reshape((1, 1, -1))
        separated = separate.numpy_separate(self.model, batch)

        # Drop the leading (batch) axis; what remains is indexed by output
        # channel on its first axis.
        sources = separated[0]

        # No real source names are available here, so emit generic
        # placeholders, one per output channel.
        labels = [f"label_{i}" for i in range(sources.shape[0])]
        return sources, int(self.sampling_rate), labels