|
from typing import Dict, List, Any |
|
from sentence_transformers import SentenceTransformer |
|
|
|
class EndpointHandler():
    """Inference-endpoint handler that embeds text with a SentenceTransformer.

    The model is loaded once in ``__init__`` and reused by every ``__call__``.
    """

    def __init__(self, path="", device="cuda"):
        """Load the SentenceTransformer model stored at ``path``.

        Args:
            path: Model directory / identifier passed to ``SentenceTransformer``.
            device: Torch device to load the model onto. Defaults to ``"cuda"``
                (the previously hard-coded behaviour); pass ``"cpu"`` to run on
                a host without a GPU.
        """
        # NOTE(review): trust_remote_code=True runs Python code shipped with the
        # model repository at ``path``; only use it with trusted repositories.
        self.model = SentenceTransformer(path, device=device, trust_remote_code=True)

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Embed the request's inputs either as queries or as documents.

        Args:
            data: Request payload. Recognised keys:
                ``inputs`` -- the text(s) to embed (a list of strings); if the
                    key is absent, the payload dict itself is used as input.
                ``type`` -- ``'query'`` or ``'doc'``; defaults to ``'doc'``.
                Both keys are popped, so ``data`` is mutated.

        Returns:
            The embeddings converted to plain Python lists (JSON-serializable).
            This assumes ``encode`` returns a numpy array, which is its default.

        Raises:
            ValueError: If ``type`` is neither ``'query'`` nor ``'doc'``.
        """
        inputs = data.pop("inputs", data)
        request_type = data.pop("type", "doc")

        if request_type == "query":
            # Queries use the model's "s2p_query" prompt -- presumably defined in
            # the model's prompt configuration; confirm against the model card.
            embeddings = self.model.encode(inputs, prompt_name="s2p_query")
        elif request_type == "doc":
            # Documents are encoded without any prompt.
            embeddings = self.model.encode(inputs)
        else:
            raise ValueError(
                f"Invalid request type: {request_type!r} (expected 'query' or 'doc')"
            )

        # Convert the numpy result so the return value matches the declared list
        # type and serializes as plain JSON without a custom encoder.
        return embeddings.tolist()
|
|