File size: 2,789 Bytes
c480b17 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from typing import Dict, List, Any
import numpy as np
import pickle
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer
import torch
from eurovoc import EurovocTagger
# Hugging Face model id; used for the tokenizer here and passed to the tagger
# as ``bert_model_name`` when the handler loads the model.
BERT_MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
# Used twice: as the tokenizer ``max_length`` (tokens) and as the slice width
# (characters) when the input text is cut into chunks.
MAX_LEN = 512
# Only the first TEXT_MAX_LEN characters of an input are chunked and scored,
# i.e. at most 50 chunks of MAX_LEN characters.
TEXT_MAX_LEN = MAX_LEN * 50
# Built once at import time and shared by every prediction call.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
class EndpointHandler:
    """Inference handler that scores a text against the labels of a tagger model.

    The handler loads a pickled label binarizer and an ``EurovocTagger``
    checkpoint from ``path``. Calling the instance with a request payload
    returns the highest-scoring labels for the supplied text.
    """

    # Class-level placeholder; every instance overwrites it in ``__init__``
    # with the fitted binarizer loaded from ``mlb.pickle``.
    mlb = MultiLabelBinarizer()

    def __init__(self, path=""):
        """Load the label binarizer and the tagger model from ``path``.

        Args:
            path: Directory containing ``mlb.pickle`` and the model checkpoint.
        """
        # NOTE(security): unpickling executes arbitrary code. ``path`` must
        # point at a trusted model directory, never at user-supplied content.
        with open(f"{path}/mlb.pickle", "rb") as mlb_file:
            self.mlb = pickle.load(mlb_file)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EurovocTagger.from_pretrained(
            path,
            bert_model_name=BERT_MODEL_NAME,
            n_classes=len(self.mlb.classes_),
            map_location=self.device,
        )
        # ``map_location`` may only control where the checkpoint is
        # deserialized; move the module explicitly so that it and the inputs
        # built in ``_get_prediction`` are guaranteed to share a device.
        self.model.to(self.device)
        self.model.eval()
        self.model.freeze()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score the request text and return the best labels.

        Args:
            data: Request payload. Recognised keys (all removed from ``data``):
                ``inputs`` (the text to tag; falls back to ``data`` itself when
                absent), ``topk`` (maximum number of labels, default 5),
                ``threshold`` (labels must score strictly above it, default
                0.16) and ``debug`` (default ``False``).

        Returns:
            ``{"results": [...]}`` where each entry is
            ``{"label": ..., "score": float}``, sorted by descending score.
            With ``debug`` set, the dict also carries ``values`` (the raw
            averaged prediction array) and ``input`` (the text that was scored).
        """
        text = data.pop("inputs", data)
        topk = data.pop("topk", 5)
        threshold = data.pop("threshold", 0.16)
        debug = data.pop("debug", False)

        prediction = self.get_prediction(text)

        # Drop low scores before sorting so only the survivors are ordered;
        # the sort is stable, so the outcome matches sort-then-filter.
        # Assumes row 0 of the prediction holds one score per class, in the
        # same order as ``mlb.classes_`` - TODO confirm against the model.
        results = [
            {"label": label, "score": float(score)}
            for label, score in zip(self.mlb.classes_, prediction[0].tolist())
            if score > threshold
        ]
        results.sort(key=lambda entry: entry["score"], reverse=True)
        results = results[:topk]

        if debug:
            return {"results": results, "values": prediction, "input": text}
        return {"results": results}

    def get_prediction(self, text):
        """Return the prediction averaged over fixed-width chunks of ``text``.

        Only the first ``TEXT_MAX_LEN`` characters are considered. They are cut
        into slices of ``MAX_LEN`` characters, each slice is scored on its own,
        and the per-chunk predictions are averaged.

        Args:
            text: The text to score.

        Returns:
            The element-wise mean of the per-chunk prediction arrays.
        """
        # NOTE: slicing is by characters while the tokenizer limit is in
        # tokens, so a chunk normally tokenizes to far fewer than MAX_LEN
        # tokens and is padded up to that length.
        limit = min(len(text), TEXT_MAX_LEN)
        chunks = [text[start:start + MAX_LEN] for start in range(0, limit, MAX_LEN)]
        if not chunks:
            # Empty input yields no chunks; averaging nothing would produce a
            # NaN scalar and break callers that index the result. Score a
            # single empty chunk so the output keeps its usual shape.
            chunks = [""]

        predictions = [self._get_prediction(chunk) for chunk in chunks]
        return np.array(predictions).mean(axis=0)

    def _get_prediction(self, text):
        """Tokenize one chunk and return the model prediction as a NumPy array.

        Args:
            text: A single chunk of text; anything beyond ``MAX_LEN`` tokens is
                truncated by the tokenizer.

        Returns:
            The second element of the model's output tuple, detached and
            copied to host memory as a NumPy array.
        """
        item = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # The tokenizer returns CPU tensors; put them where the model lives.
        input_ids = item["input_ids"].to(self.device)
        attention_mask = item["attention_mask"].to(self.device)

        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            _, prediction = self.model(input_ids, attention_mask)
        return prediction.cpu().detach().numpy()
|