from typing import Dict, List, Any, Optional

import transformers
import torch

# Default generation budget, used when a request does not set its own
# length limit (``max_new_tokens`` or ``max_length``).
MAX_TOKENS = 4096


class EndpointHandler(object):
    """Callable wrapper around a ``text-generation`` pipeline.

    The pipeline is built once in ``__init__`` and reused for every request
    passed to ``__call__``.
    """

    def __init__(self, path=''):
        """Build the text-generation pipeline.

        :param path: accepted for interface compatibility; it is not used,
            the model identifier below is hard-coded.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="ai-singapore/gemma2-9b-cpt-sealionv3-instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Run the pipeline on a single request payload.

        :param data: request payload with the keys
            ``inputs``: the value handed to the pipeline as its first
            argument (message format), required;
            ``parameters``: optional dict of keyword arguments forwarded to
            the pipeline call.
        :return: whatever the pipeline call returns, unmodified.
            NOTE(review): presumably a list of generation results; confirm
            the exact shape against the pipeline documentation.
        :raises KeyError: if ``data`` has no ``inputs`` key.
        """
        # Read without pop() so the caller's payload is left untouched.
        inputs = data["inputs"]
        parameters: Optional[Dict] = data.get("parameters")

        # Copy so the default can be added without mutating the caller's dict.
        # A non-mapping ``parameters`` still fails with TypeError, as before.
        call_kwargs: Dict[str, Any] = {**parameters} if parameters else {}

        # Apply the default budget unless the caller chose a length limit of
        # their own; an explicit ``max_length`` must not be overridden either.
        if "max_new_tokens" not in call_kwargs and "max_length" not in call_kwargs:
            call_kwargs["max_new_tokens"] = MAX_TOKENS

        return self.pipeline(inputs, **call_kwargs)