# bart-qa / handler.py
# Uploaded by jammygrams (commit d2de3cf)
from typing import Any, Dict, List
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Base checkpoint and adapter repo identifiers.
# NOTE(review): neither constant is referenced in the visible code — the
# endpoint loads everything from the `path` passed to EndpointHandler;
# confirm whether these are used elsewhere or are informational only.
PRETRAINED_MODEL_NAME = "facebook/bart-large"
ADAPTER_MODEL_NAME = "jammygrams/bart-qa"
# Name of the adapter activated on the loaded model in __init__.
ADAPTER_NAME = "narrativeqa"
class EndpointHandler():
    """Hugging Face Inference Endpoint handler for a BART QA adapter model.

    Loads a seq2seq model and tokenizer from ``path`` and answers free-text
    requests via beam-search generation.
    """

    def __init__(
        self,
        path: str,
    ):
        """Load the tokenizer and model from ``path`` and activate the adapter.

        Args:
            path: local directory (or hub id) of the deployed model repo.

        NOTE(review): ``set_active_adapters`` is provided by
        adapter-transformers models, not by vanilla ``transformers`` ones —
        confirm the deployed environment installs the adapters package.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.set_active_adapters(ADAPTER_NAME)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """Generate an answer for one request payload.

        Args:
            data: request body; the prompt is read from ``data["inputs"]``
                (falling back to ``data`` itself when the key is absent).

        Return:
            A single-element list containing the decoded answer string,
            which the endpoint runtime serializes as JSON.
        """
        inputs = data.pop("inputs", data)
        # Truncate to BART's 1024-position input limit so oversized prompts
        # don't raise inside generate().
        tokenized_input = self.tokenizer(
            [inputs],
            truncation=True,
            max_length=1024,
            return_tensors="pt",
        )
        prediction = self.model.generate(
            tokenized_input.input_ids,
            num_beams=5,
            return_dict_in_generate=True,
            output_scores=True,
            max_length=50,
        )
        output = self.tokenizer.decode(
            prediction["sequences"][0],  # best beam only
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return [output]