|
import logging
from typing import List, Optional, Union

from infinity.tasks import (
    TextClassificationEndpoint,
    TextClassificationOutput,
    TextClassificationParams,
)
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer, Pipeline, pipeline
|
|
|
class BankingEndpoint(TextClassificationEndpoint):
    """Text-classification endpoint backed by an ONNX Runtime sequence-classification model.

    The model and tokenizer are loaded from ``DEFAULT_MODEL_ID`` in
    :meth:`initialize` (presumably a DistilBERT checkpoint fine-tuned on the
    Banking77 intent dataset, judging by its name). :meth:`handle` must only
    be called after :meth:`initialize` has run.
    """

    __slots__ = ("_pipeline",)

    # Single source of truth for the checkpoint; model and tokenizer must come
    # from the same repo. Subclasses may override this to serve another model.
    DEFAULT_MODEL_ID = "philschmid/distilbert-onnx-banking77"

    # Class-level (not in __slots__) so it does not cost a per-instance slot.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        super().__init__()
        # Populated by initialize(); None means "not ready yet".
        self._pipeline: Optional[Pipeline] = None

    def initialize(self, **kwargs):
        """Load the ONNX model and tokenizer and build the classification pipeline.

        Args:
            **kwargs: Accepted for compatibility with the endpoint interface;
                currently unused.
        """
        model_id = self.DEFAULT_MODEL_ID
        self._logger.info("Initializing %s from %s", type(self).__name__, model_id)

        model = ORTModelForSequenceClassification.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)

        self._logger.info("Initialized %s", type(self).__name__)

    def handle(
        self,
        inputs: Union[str, List[str]],
        parameters: TextClassificationParams,
    ) -> List[TextClassificationOutput]:
        """Classify one text or a batch of texts.

        Args:
            inputs: A single string or a list of strings to classify.
            parameters: Extra keyword options forwarded verbatim to the
                transformers pipeline call; ``None`` or empty means no options.

        Returns:
            Whatever the underlying text-classification pipeline returns for
            ``inputs``.

        Raises:
            RuntimeError: If called before :meth:`initialize`.
        """
        if self._pipeline is None:
            # Without this guard the failure is "'NoneType' object is not callable".
            raise RuntimeError(
                f"{type(self).__name__}.handle() called before initialize()"
            )
        # `**None` raises TypeError, so treat a missing parameter set as no options.
        return self._pipeline(inputs, **(parameters or {}))
|
|