File size: 1,370 Bytes
4a87051
 
702d195
 
 
 
 
 
4a87051
 
 
 
 
 
 
 
 
 
702d195
8f0d639
702d195
 
 
 
 
 
 
 
4a87051
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torchvision import transforms
from pair_classification import PairClassificationPipeline

class PreTrainedPipeline:
    """Inference wrapper around a custom "pair-classification" pipeline.

    The "pair-classification" task name is presumably registered by the
    ``pair_classification`` module imported at the top of this file — TODO
    confirm.
    """

    def __init__(self, path):
        """
        Build the pair-classification pipeline from the checkpoint at ``path``.

        Args:
            path (:obj:`str`):
                Location of the fine-tuned model checkpoint. Its saved config
                is used as-is (no label or channel overrides), just as the
                previous ``from_pretrained(path)`` call did.
        """
        # Preprocessing still comes from the base ViT checkpoint, since the
        # fine-tuned directory may not ship its own preprocessor config.
        model_flag = 'google/vit-base-patch16-224-in21k'
        # NOTE(review): `pipeline` is not imported anywhere in this file;
        # presumably `from transformers import pipeline` is intended — confirm.
        #
        # Loading the model from `path` directly means the pipeline actually
        # runs the fine-tuned weights. Previously the pipeline wrapped a base
        # model built from `model_flag`, and the weights loaded from `path`
        # were stored on `self.model` but never used by `__call__`. This also
        # drops the dependency on `label2id` / `id2label`, which are not
        # defined in this file.
        self.pipe = pipeline(
            "pair-classification",
            model=path,
            feature_extractor=model_flag,
        )
        # Kept for callers that read `self.model`; it is now the same object
        # the pipeline runs.
        self.model = self.pipe.model

    def __call__(self, inputs):
        """
        Run pair classification on ``inputs``.

        Args:
            inputs:
                Passed unchanged to the underlying pair-classification
                pipeline; the accepted format is defined by that pipeline's
                preprocessing (presumably an image pair — TODO confirm).
        Return:
            Whatever the pair-classification pipeline returns.
        """
        return self.pipe(inputs)