SREDWise commited on
Commit
6e6f8a3
·
verified ·
1 Parent(s): 2629eb3

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +87 -0
handler.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path: str):
7
+ # Load model and tokenizer
8
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
9
+ self.model = AutoModelForCausalLM.from_pretrained(
10
+ path,
11
+ torch_dtype=torch.float32, # Use float32 for CPU
12
+ device_map="auto"
13
+ )
14
+
15
+ # Set up generation parameters
16
+ self.default_params = {
17
+ "max_length": 1000,
18
+ "temperature": 0.7,
19
+ "top_p": 0.7,
20
+ "top_k": 50,
21
+ "repetition_penalty": 1.0,
22
+ "do_sample": True,
23
+ "pad_token_id": self.tokenizer.pad_token_id,
24
+ "eos_token_id": self.tokenizer.eos_token_id
25
+ }
26
+
27
+ def __call__(self, data: Dict):
28
+ """
29
+ Args:
30
+ data: Dictionary with "inputs" and optional "parameters"
31
+ Returns:
32
+ Generated text
33
+ """
34
+ # Extract messages from input
35
+ messages = data.get("inputs", {}).get("messages", [])
36
+ if not messages:
37
+ return {"error": "No messages provided"}
38
+
39
+ # Format input text
40
+ input_text = ""
41
+ for msg in messages:
42
+ role = msg.get("role", "")
43
+ content = msg.get("content", "")
44
+ input_text += f"{role}: {content}\n"
45
+
46
+ # Get generation parameters
47
+ params = {**self.default_params}
48
+ if "parameters" in data:
49
+ params.update(data["parameters"])
50
+
51
+ # Tokenize input
52
+ inputs = self.tokenizer(
53
+ input_text,
54
+ return_tensors="pt",
55
+ padding=True,
56
+ truncation=True,
57
+ max_length=512
58
+ )
59
+
60
+ # Generate response
61
+ with torch.no_grad():
62
+ outputs = self.model.generate(
63
+ inputs["input_ids"],
64
+ attention_mask=inputs["attention_mask"],
65
+ **params
66
+ )
67
+
68
+ # Decode response
69
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
70
+
71
+ return [{"generated_text": generated_text}]
72
+
73
+ def preprocess(self, request):
74
+ """
75
+ Prepare request for inference
76
+ """
77
+ if request.content_type != "application/json":
78
+ raise ValueError("Content type must be application/json")
79
+
80
+ data = request.json
81
+ return data
82
+
83
+ def postprocess(self, data):
84
+ """
85
+ Post-process model output
86
+ """
87
+ return data