stella_1.5B_custom / handler.py
lgbird's picture
Upload 2 files
a6b176e verified
raw
history blame
872 Bytes
from typing import Dict, List, Any
from sentence_transformers import SentenceTransformer
class EndpointHandler:
    """Callable request handler that embeds text with a SentenceTransformer model.

    The model is loaded once in ``__init__`` and reused by every call.
    """

    def __init__(self, path=""):
        """Load the model located at ``path`` and move it to the GPU.

        Args:
            path: Model location handed straight to ``SentenceTransformer``.
        """
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time -- only use with trusted models.
        # NOTE(review): .cuda() presumably fails on hosts without a CUDA
        # device -- confirm the deployment target always has a GPU.
        self.model = SentenceTransformer(path, trust_remote_code=True).cuda()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Encode the request inputs as a query or as a document.

        Args:
            data: Request payload. Recognised keys:
                ``inputs`` -- the text(s) to encode (documented as a list of
                    ``str``); when the key is absent the payload itself is
                    passed to the model.
                ``type`` -- ``'query'`` or ``'doc'`` (default ``'doc'``).
                Both keys are removed from ``data`` in place.

        Returns:
            Whatever ``self.model.encode`` returns for the inputs; it is
            handed back unchanged for the serving layer to serialize.

        Raises:
            ValueError: If ``type`` is neither ``'query'`` nor ``'doc'``.
        """
        # Fall back to the whole payload when no "inputs" key was sent.
        inputs = data.pop("inputs", data)
        request_type = data.pop("type", "doc")

        if request_type == "query":
            # Queries are encoded with the model's "s2p_query" prompt.
            return self.model.encode(inputs, prompt_name="s2p_query")
        if request_type == "doc":
            # Documents are encoded without any prompt.
            return self.model.encode(inputs)

        # ValueError is still an Exception, so existing handlers keep working,
        # and the message names the value that was rejected.
        raise ValueError(f"Invalid request type: {request_type!r}")