File size: 1,007 Bytes
ef30bfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Dict, List, Any, Optional
import transformers
import torch

# Default cap on newly generated tokens, applied when a request does not set
# its own length limit (``max_new_tokens`` or ``max_length``).
MAX_TOKENS = 4096


class EndpointHandler:
    """Callable wrapper around a ``transformers`` text-generation pipeline.

    The pipeline is built once in ``__init__``; each call forwards one request
    payload to it and returns the pipeline's result unchanged.
    """

    def __init__(self, path: str = ''):
        """
        Build the text-generation pipeline.

        :param path: accepted for interface compatibility but not read here;
            the model identifier below is hard-coded.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="ai-singapore/gemma2-9b-cpt-sealionv3-instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
        """
        Run the pipeline on one request payload.

        :param data:
            inputs: message format, passed to the pipeline as-is (required)
            parameters: optional keyword arguments for the pipeline call
        :return: whatever the pipeline returns, unmodified.
            NOTE(review): the return annotation is kept as-is for interface
            stability; confirm it against the actual text-generation output.
        :raises KeyError: if ``data`` has no ``"inputs"`` entry.

        ``data`` and its ``parameters`` dict are not modified.
        """
        inputs = data["inputs"]

        # Copy so that adding the default below never leaks into the caller's
        # dict; a missing or null ``parameters`` is treated as "no overrides".
        generation_kwargs: Dict[str, Any] = dict(data.get("parameters") or {})

        # Apply the default cap whenever the caller set no length limit of
        # their own -- not only when ``parameters`` is absent altogether.
        # ``max_length`` counts as a caller-chosen limit, so it is left alone.
        caller_set_limit = (
            "max_new_tokens" in generation_kwargs
            or "max_length" in generation_kwargs
        )
        if not caller_set_limit:
            generation_kwargs["max_new_tokens"] = MAX_TOKENS

        return self.pipeline(inputs, **generation_kwargs)