Spaces:
Configuration error
Configuration error
import statistics | |
from typing import List | |
from typing import Tuple | |
import torch | |
from loguru import logger | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from app.config.models.configs import Document | |
class BCEReranker:
    """Score (query, document) pairs with a sequence-classification model.

    The tokenizer and model are loaded once at construction time and the
    model is switched to evaluation mode.
    """

    def __init__(self, model_name: str = "maidalun1020/bce-reranker-base_v1") -> None:
        """Load the tokenizer and the model.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                tokenizer and the model. Defaults to the BCE reranker base
                model that was previously hard-coded.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
        logger.info("Initialized BCE Reranker")

    def get_scores(self, query: str, docs: List[Document]) -> List[float]:
        """Return one raw model score per document for the given query.

        Args:
            query: The search query paired with every document.
            docs: Documents whose ``page_content`` is scored against the query.

        Returns:
            The flattened model logits as Python floats, in the same order as
            ``docs``. An empty list is returned when ``docs`` is empty.
        """
        logger.info("Reranking documents ... ")
        # Nothing to score: avoid sending an empty batch through the
        # tokenizer and the model.
        if not docs:
            return []

        features = [[query, doc.page_content] for doc in docs]
        inputs = self.tokenizer(
            features,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        # Inference only, so gradient tracking is disabled for the forward pass.
        with torch.no_grad():
            logits = self.model(**inputs, return_dict=True).logits
        # Flatten to 1-D; assumes the model emits a single logit per pair so
        # that the result lines up with ``docs`` -- TODO confirm.
        return logits.view(-1).float().tolist()
def rerank(
    rerank_model: BCEReranker, query: str, docs: List[Document]
) -> Tuple[float, List[Document]]:
    """Score ``docs`` against ``query`` and return them best-first.

    Each document's ``metadata["score"]`` is overwritten in place with the
    score produced by ``rerank_model.get_scores``.

    Args:
        rerank_model: Reranker used to score the documents.
        query: The search query.
        docs: Documents to score and reorder.

    Returns:
        A tuple ``(summary_score, ranked_docs)`` where ``summary_score`` is
        the arithmetic mean of the (up to) ten highest scores and
        ``ranked_docs`` is a new list of the same documents sorted by
        descending score. For an empty ``docs`` list the result is
        ``(0.0, [])``.
    """
    logger.info("Reranking...")
    # statistics.mean() raises StatisticsError on empty data, so bail out
    # before scoring when there is nothing to rank.
    if not docs:
        return 0.0, []

    scores = rerank_model.get_scores(query, docs)
    for score, doc in zip(scores, docs):
        doc.metadata["score"] = score

    sorted_scores = sorted(scores, reverse=True)
    logger.info(sorted_scores)

    # Mean (not median) of the ten best scores.
    top_mean = statistics.mean(sorted_scores[:10])
    # sorted() already returns a new list; no extra copy is needed.
    ranked_docs = sorted(docs, key=lambda it: it.metadata["score"], reverse=True)
    return top_mean, ranked_docs