# lr-bert-base-uncased / handler.py
# Author: Uan Sholanbayev (commit 582955e)
from typing import List, Dict, Any
import numpy as np
from transformers import BertTokenizer, BertModel
import torch
import pickle
def unpickle_obj(filepath):
    """Load and return a pickled Python object from *filepath*.

    Logs the path once the object has been deserialized.

    SECURITY NOTE: ``pickle.load`` executes arbitrary code from the file;
    only use this on trusted artifacts shipped with the model repository.
    """
    with open(filepath, 'rb') as f_in:
        data = pickle.load(f_in)
    print(f"unpickled {filepath}")
    return data
class EndpointHandler():
    """Custom Hugging Face Inference Endpoint handler.

    Embeds incoming queries and texts with a BERT encoder, builds pairwise
    (text - query) difference vectors, and scores them with a pickled
    classifier exposing ``predict_proba`` (scikit-learn style).
    """

    def __init__(self, path=""):
        """Load the pickled classifier, tokenizer and BERT encoder from *path*."""
        # NOTE(review): unpickling is only safe for trusted repo artifacts.
        self.model = unpickle_obj(f"{path}/bert_lr.pkl")
        self.tokenizer = BertTokenizer.from_pretrained(path, local_files_only=True)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert = BertModel.from_pretrained(path).to(self.device)

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Return mean-pooled last-hidden-state embeddings, one row per text.

        Shape is ``(len(texts), hidden_dim)``; inputs are truncated/padded
        to at most 512 tokens.
        """
        inputs = self.tokenizer(texts, return_tensors='pt', truncation=True,
                                padding=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.bert(**inputs)
        # Mean over the sequence dimension gives one fixed-size vector per text.
        return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle a request payload ``{"inputs": {"queries": [...], "texts": [...]}}``.

        Falls back to treating *data* itself as the inputs dict when no
        "inputs" key is present. Returns ``[{"outputs": probs}]`` where
        *probs* is a nested list (JSON-serializable) of class probabilities
        for every (text, query) pair, in text-major order.
        """
        inputs = data.pop("inputs", data)
        queries = inputs['queries']
        texts = inputs['texts']
        queries_vec = self.get_embeddings(queries)
        texts_vec = self.get_embeddings(texts)
        # Broadcasting yields all pairwise differences, flattened to
        # (len(texts) * len(queries), hidden_dim). get_embeddings already
        # returns ndarrays, so no extra np.array() wrapping is needed.
        diff = (texts_vec[:, np.newaxis] - queries_vec).reshape(-1, queries_vec.shape[1])
        # FIX: .tolist() makes the result JSON-serializable — a raw
        # np.ndarray is rejected by the default JSON encoder used by the
        # inference endpoint runtime.
        return [{
            "outputs": self.model.predict_proba(diff).tolist()
        }]