|
from typing import Any, Dict, List |
|
import os |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
# Base model the adapter was fine-tuned from. Not referenced in this
# chunk — presumably documentation/reference only; TODO confirm usage.
PRETRAINED_MODEL_NAME = "facebook/bart-large"

# Hub repository of the fine-tuned QA checkpoint. Also not referenced in
# this chunk — NOTE(review): verify it is used elsewhere or drop it.
ADAPTER_MODEL_NAME = "jammygrams/bart-qa"

# Adapter activated on the model in EndpointHandler.__init__.
ADAPTER_NAME = "narrativeqa"
|
|
|
class EndpointHandler():
    """Custom inference handler: loads a seq2seq model with a named
    adapter and answers generation requests via ``__call__``."""

    def __init__(
        self,
        path: str,
    ):
        """Load tokenizer and model from *path* (the local checkpoint
        directory supplied by the endpoint runtime) and activate the
        adapter.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        # NOTE(review): set_active_adapters exists only when the
        # adapter-transformers fork is installed; vanilla transformers
        # models have no such method — confirm the deployed package.
        self.model.set_active_adapters(ADAPTER_NAME)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """
        data args:
            inputs (:obj: `str` | `PIL.Image` | `np.array`)
            kwargs
        Return:
            output :obj:`list` | `dict`: will be serialized and returned
        """
        # Fall back to the raw payload when no "inputs" key is present.
        inputs = data.pop("inputs", data)

        tokenized_input = self.tokenizer([inputs], return_tensors="pt")

        # Unpack the full encoding so attention_mask accompanies
        # input_ids — identical for a single unpadded input, correct if
        # padding is ever introduced.
        prediction = self.model.generate(
            **tokenized_input,
            num_beams=5,
            return_dict_in_generate=True,
            output_scores=True,
            max_length=50,
        )
        output = self.tokenizer.decode(
            prediction["sequences"][0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # The runtime serializes this single-element list as the response.
        return [output]