Text Classification
PyTorch
English
eurovoc
Inference Endpoints
eurovoc_en / handler.py
scampion's picture
initial commit
b552d82 verified
raw
history blame
2.79 kB
from typing import Dict, List, Any
import numpy as np
import pickle
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer
import torch
from eurovoc import EurovocTagger
# Hugging Face checkpoint name: used to load the tokenizer below and passed to
# EurovocTagger.from_pretrained as ``bert_model_name``.
BERT_MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
# Serves two roles in this file: the tokenizer's ``max_length`` (in tokens) and
# the size, in characters, of each slice the input text is split into.
MAX_LEN = 512
# Upper bound on how many characters of the input are scored: text beyond
# 50 slices of MAX_LEN characters is ignored.
TEXT_MAX_LEN = MAX_LEN * 50
# Created once at import time; every EndpointHandler instance uses this object.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
class EndpointHandler:
    """Inference Endpoints handler that tags text with EuroVoc labels.

    The handler loads a pickled label binarizer and an ``EurovocTagger`` model
    from a directory, then scores incoming text slice by slice and reports the
    best-scoring labels.
    """

    # Class-level placeholder; ``__init__`` replaces it on every instance with
    # the fitted binarizer unpickled from the model directory.
    mlb = MultiLabelBinarizer()

    def __init__(self, path=""):
        """Load the label binarizer and the tagger weights from *path*.

        Args:
            path: Directory containing ``mlb.pickle`` and whatever files
                ``EurovocTagger.from_pretrained`` needs.
        """
        # NOTE(security): unpickling executes code embedded in the file, so
        # ``path`` must only ever point at a trusted model repository.
        with open(f"{path}/mlb.pickle", "rb") as mlb_file:
            self.mlb = pickle.load(mlb_file)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EurovocTagger.from_pretrained(
            path,
            bert_model_name=BERT_MODEL_NAME,
            n_classes=len(self.mlb.classes_),
            map_location=self.device,
        )
        self.model.eval()
        self.model.freeze()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score the request text and return the best labels.

        Args:
            data: Request payload. The keys below are read and removed
                (popped) from the dict:

                * ``inputs``: the text to classify (falls back to ``data``
                  itself when the key is absent).
                * ``topk``: maximum number of labels to return (default 5).
                * ``threshold``: labels must score strictly above this value
                  (default 0.16).
                * ``debug``: when truthy, also return the raw scores and the
                  input text (default False).

        Returns:
            ``{"results": [{"label": ..., "score": ...}, ...]}`` sorted by
            descending score; in debug mode the dict additionally carries
            ``"values"`` (the raw prediction array) and ``"input"``.
        """
        text = data.pop("inputs", data)
        topk = data.pop("topk", 5)
        threshold = data.pop("threshold", 0.16)
        debug = data.pop("debug", False)

        prediction = self.get_prediction(text)
        # Row 0 of the prediction is paired positionally with the binarizer's
        # classes, i.e. one score per known label.
        scored = [
            {"label": label, "score": float(score)}
            for label, score in zip(self.mlb.classes_, prediction[0].tolist())
        ]
        scored.sort(key=lambda entry: entry["score"], reverse=True)
        results = [entry for entry in scored if entry["score"] > threshold][:topk]

        if debug:
            return {"results": results, "values": prediction, "input": text}
        return {"results": results}

    def get_prediction(self, text):
        """Return the scores for *text*, averaged over its slices.

        The text is cut into consecutive slices of ``MAX_LEN`` characters
        (only the first ``TEXT_MAX_LEN`` characters are considered); each
        slice is scored separately and the element-wise mean is returned.

        Args:
            text: The text to score.

        Returns:
            A numpy array holding the mean of the per-slice predictions
            (presumably shaped ``(1, n_classes)`` -- the caller indexes row 0).
        """
        limit = min(len(text), TEXT_MAX_LEN)
        chunks = [text[start:start + MAX_LEN] for start in range(0, limit, MAX_LEN)]
        if not chunks:
            # Empty input would otherwise average over zero predictions and
            # yield a NaN scalar that the caller cannot index. Score it as a
            # single empty slice, the same way whitespace-only text is handled.
            chunks = [text]
        predictions = [self._get_prediction(chunk) for chunk in chunks]
        return np.array(predictions).mean(axis=0)

    def _get_prediction(self, text):
        """Tokenize one slice of text and run it through the model.

        Args:
            text: A slice of at most ``MAX_LEN`` characters.

        Returns:
            The model's prediction for the slice as a numpy array on the CPU.
        """
        item = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')

        # The tokenizer returns CPU tensors. Send them to wherever the model
        # weights actually live so a CUDA-resident model does not fail with a
        # device mismatch; fall back to self.device for a parameterless model.
        first_param = next(self.model.parameters(), None)
        target = first_param.device if first_param is not None else self.device

        # Inference only: make sure no autograd graph is recorded.
        with torch.no_grad():
            _, prediction = self.model(
                item["input_ids"].to(target),
                item["attention_mask"].to(target),
            )
        return prediction.cpu().detach().numpy()