from typing import Dict, List, Any, Optional | |
import transformers | |
import torch | |
# Default `max_new_tokens` value that EndpointHandler passes to the
# text-generation pipeline.
MAX_TOKENS = 4_096
class EndpointHandler:
    """Serve text generation with the SEA-LION v3 Gemma-2 9B instruct model.

    The class has the shape of a Hugging Face Inference Endpoints custom
    handler: ``__init__(path)`` builds the model once, and ``__call__(data)``
    answers one request. NOTE(review): confirm that this is the deployment
    target.
    """

    # Hub model id loaded by ``__init__``. Subclasses may override it to
    # serve a different checkpoint with the same request handling.
    MODEL_ID = "ai-singapore/gemma2-9b-cpt-sealionv3-instruct"

    def __init__(self, path: str = "") -> None:
        """Build the text-generation pipeline.

        :param path: Repository path handed in by the caller. It is not
            used; the weights are always loaded from ``MODEL_ID``.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model=self.MODEL_ID,
            # bfloat16 stores 2 bytes per weight, half of float32.
            model_kwargs={"torch_dtype": torch.bfloat16},
            # Place the weights across the available devices automatically.
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Run generation for one request payload.

        :param data: Request body. ``data["inputs"]`` is forwarded to the
            pipeline unchanged. Per the original "message format" note, it is
            presumably a list of chat messages. The optional
            ``data["parameters"]`` mapping is forwarded as generation kwargs.
            ``data`` itself is not modified.
        :return: The pipeline output, unchanged. For the text-generation
            pipeline this is a list of dicts with a ``"generated_text"`` entry
            (nested one level deeper for batched inputs).
        :raises KeyError: If ``data`` has no ``"inputs"`` key.
        """
        inputs = data["inputs"]
        # Copy the parameters so that adding the default below never mutates
        # the caller's dict. A missing or null "parameters" becomes {}.
        parameters: Dict[str, Any] = {**(data.get("parameters") or {})}
        # Apply the default length cap unless the caller set its own limit.
        # Previously the cap was skipped whenever any parameter was sent
        # (e.g. only "temperature"). The pipeline then fell back to the
        # model's generation config, which may allow far fewer tokens.
        if "max_new_tokens" not in parameters and "max_length" not in parameters:
            parameters["max_new_tokens"] = MAX_TOKENS
        return self.pipeline(inputs, **parameters)