stupidog04's picture
Update pipeline.py
4a87051
raw
history blame
1.39 kB
from typing import Dict

from torchvision import transforms
from pair_classification import PairClassificationPipeline
class PreTrainedPipeline():
    """Inference entry point wrapping the custom ``pair-classification`` pipeline.

    The pipeline is first assembled from the base ViT checkpoint, then the
    model stored at ``path`` is loaded and installed into it, so calls run
    with the weights found at ``path``.
    """

    def __init__(self, path):
        """Build the pipeline and load the model found at ``path``.

        Args:
            path: Location handed to ``from_pretrained`` to load the model
                that should serve predictions.
        """
        model_flag = 'google/vit-base-patch16-224-in21k'
        # NOTE(review): `pipeline`, `label2id` and `id2label` are neither
        # defined nor imported in the visible part of this file. `pipeline`
        # is presumably `transformers.pipeline` and the label maps must come
        # from the training setup -- confirm and bring them into scope.
        self.pipe = pipeline(
            "pair-classification",
            model=model_flag,
            feature_extractor=model_flag,
            model_kwargs={
                'num_labels': len(label2id),
                'label2id': label2id,
                'id2label': id2label,
                # Presumably two 3-channel images stacked along the channel
                # axis -- TODO confirm against the pair_classification module.
                'num_channels': 6,
                'ignore_mismatched_sizes': True,
            },
        )
        self.model = self.pipe.model.from_pretrained(path)
        # `from_pretrained` returns a new model object instead of updating
        # the pipeline's model in place. Without this assignment, __call__
        # would keep running the model built from `model_flag` and the
        # weights loaded from `path` would never be used.
        self.pipe.model = self.model

    def __call__(self, inputs) -> Dict[str, str]:
        """Run the pair-classification pipeline on ``inputs``.

        Args:
            inputs: Passed unchanged to the underlying pipeline; the accepted
                format is whatever the ``pair-classification`` pipeline
                expects (not visible in this file).

        Returns:
            The pipeline's output, returned as-is.
            NOTE(review): annotated as ``Dict[str, str]``; confirm against
            the actual return value of the pipeline.
        """
        return self.pipe(inputs)