|
from transformers.utils import logging |
|
from FlagEmbedding import FlagReranker |
|
|
|
# Log through the "transformers" library logger and raise that library's
# verbosity to INFO, so the handler's INFO-level messages are not filtered out.
logger = logging.get_logger("transformers")
logging.set_verbosity(logging.INFO)
logger.info("INFO")
|
|
|
|
|
class EndpointHandler:
    """Inference handler that scores candidate texts against a query.

    Wraps a FlagEmbedding ``FlagReranker`` and returns one relevance score
    per candidate text, in input order.
    """

    def __init__(self, path="", *, model_name="BAAI/bge-reranker-large", use_fp16=True):
        """Load the reranker.

        Args:
            path: Argument supplied by the caller/serving runtime. Not used:
                the model is loaded from ``model_name`` instead.
                NOTE(review): presumably the local repository directory —
                confirm whether the model should be loaded from it.
            model_name: Checkpoint identifier passed to ``FlagReranker``.
            use_fp16: Forwarded to ``FlagReranker`` as ``use_fp16``.
        """
        self.reranker = FlagReranker(model_name, use_fp16=use_fp16)

    def __call__(self, inputs):
        """Score every text in the request against the request's query.

        Args:
            inputs: Request payload shaped like
                ``{"inputs": {"query": ..., "texts": [...]}}``.

        Returns:
            ``{"scores": [...]}`` with one score per entry of ``texts``, in
            the same order. An empty ``texts`` yields ``{"scores": []}``.

        Raises:
            KeyError: If ``"inputs"``, ``"query"`` or ``"texts"`` is missing.
        """
        data = inputs["inputs"]
        logger.info("Inference started")
        logger.info(type(data))

        query = data["query"]
        texts = data["texts"]
        if not texts:
            # Nothing to score; don't hand the model an empty batch.
            return {"scores": []}

        # Score all pairs in one call: compute_score takes a list of
        # [query, passage] pairs and batches them internally, instead of
        # running one model call per text.
        pairs = [[query, text] for text in texts]
        scores = self.reranker.compute_score(pairs)

        # compute_score returns a bare scalar (not a one-element list) when
        # there is only one pair; keep the response shape a list.
        if not isinstance(scores, (list, tuple)):
            scores = [scores]
        scores = list(scores)

        logger.info("Scores: %s", scores)
        return {"scores": scores}
|
|
|
|