# Source: project/app/ranking.py — commit 7bd11ed (author: kabylake)
import statistics
from typing import List
from typing import Tuple
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from app.config.models.configs import Document
class BCEReranker:
    """Cross-encoder reranker backed by a Hugging Face sequence-classification model.

    Each ``(query, document)`` pair is scored with the model's raw output
    logit; no sigmoid or other normalization is applied.
    """

    def __init__(
        self,
        model_name: str = "maidalun1020/bce-reranker-base_v1",
        max_length: int = 512,
    ) -> None:
        """Load the tokenizer and model and switch the model to eval mode.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained``. Defaults to the BCE base reranker.
            max_length: Maximum token length of each query/document pair;
                longer pairs are truncated.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.max_length = max_length
        # Inference only: disables training-time behavior such as dropout.
        self.model.eval()
        logger.info("Initialized BCE Reranker")

    def get_scores(self, query: str, docs: List[Document]) -> List[float]:
        """Return one relevance score per document, in the order of ``docs``.

        Args:
            query: Query text, paired with every document.
            docs: Documents whose ``page_content`` is scored against ``query``.

        Returns:
            A list of floats (raw logits), one per document; ``[]`` when
            ``docs`` is empty.
        """
        logger.info("Reranking documents ... ")
        if not docs:
            # Nothing to score; avoid calling the tokenizer/model on an empty batch.
            return []

        features = [[query, doc.page_content] for doc in docs]
        inputs = self.tokenizer(
            features,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Keep the inputs on the same device as the model (e.g. if it was moved to GPU).
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}

        # Only the forward pass needs autograd disabled.
        with torch.no_grad():
            logits = self.model(**inputs, return_dict=True).logits

        # Assumes one logit per pair (num_labels == 1) — TODO confirm for other
        # models; view(-1) would otherwise flatten several logits per document.
        # .float() upcasts in case the model runs in reduced precision.
        return logits.view(-1).float().tolist()
def rerank(
    rerank_model: BCEReranker,
    query: str,
    docs: List[Document],
    top_k: int = 10,
) -> Tuple[float, List[Document]]:
    """Score ``docs`` against ``query`` and return them ordered by relevance.

    Side effect: every document's score is written into
    ``doc.metadata["score"]`` (the documents are mutated in place).

    Args:
        rerank_model: Model used to score the query/document pairs.
        query: Query text.
        docs: Candidate documents.
        top_k: How many of the highest scores to average into the summary
            score. Must be at least 1.

    Returns:
        ``(mean_top_score, ranked_docs)``:
        - ``mean_top_score`` is the arithmetic mean (not the median) of the
          ``top_k`` highest scores, or NaN when ``docs`` is empty.
        - ``ranked_docs`` is a new list of the documents sorted by descending
          score. Documents with equal scores keep their input order.

    Raises:
        ValueError: If ``top_k`` < 1, or if the model returns a different
            number of scores than there are documents.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    logger.info("Reranking...")
    if not docs:
        # statistics.mean([]) would raise; there is nothing to rank.
        return float("nan"), []

    scores = rerank_model.get_scores(query, docs)
    if len(scores) != len(docs):
        # zip() would silently truncate and leave some docs without a fresh score.
        raise ValueError(
            f"Reranker returned {len(scores)} scores for {len(docs)} documents"
        )
    for score, doc in zip(scores, docs):
        doc.metadata["score"] = score

    sorted_scores = sorted(scores, reverse=True)
    logger.info(sorted_scores)
    mean_top_score = statistics.mean(sorted_scores[:top_k])

    # sorted() is stable, and reverse=True preserves that stability for ties.
    ranked_docs = sorted(docs, key=lambda doc: doc.metadata["score"], reverse=True)
    return mean_top_score, ranked_docs