# bart-large-cnn / handler.py
# Source page metadata (retained from the hosting page, not part of the code):
#   author: philipp-zettl
#   commit: "Update handler.py" (2741c37, verified)
#   size: 994 Bytes
from typing import Dict, List, Any
from optimum.pipelines import pipeline
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM
class EndpointHandler:
    """Callable request handler wrapping a summarization pipeline.

    The handler is constructed once with the model directory and is then
    invoked with one request payload (a dict) per call.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the ``ORTModelForSeq2SeqLM`` model.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = ORTModelForSeq2SeqLM.from_pretrained(path)
        self.pipeline = pipeline(
            "summarization", model=model, tokenizer=tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the summarization pipeline on one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is handed to the
                pipeline as its positional argument; the optional
                ``data["parameters"]`` mapping is forwarded as keyword
                arguments. If ``"inputs"`` is absent, the payload itself
                (without ``"parameters"``) is used as the input.

        Returns:
            Whatever the pipeline returns; it is serialized and sent back
            by the caller. (Presumably a list of dicts, per the
            annotation -- confirm against the pipeline documentation.)
        """
        # Read with .get() rather than .pop(): the payload belongs to the
        # caller and must not lose keys as a side effect of being handled.
        if "inputs" in data:
            inputs = data["inputs"]
        else:
            # Fallback: treat the payload itself as the input, minus the
            # "parameters" entry, using a copy so `data` stays intact.
            inputs = {
                key: value
                for key, value in data.items()
                if key != "parameters"
            }

        parameters = data.get("parameters")
        if parameters is None:
            parameters = {}

        # Forwarding an empty mapping is identical to passing no kwargs.
        return self.pipeline(inputs, **parameters)