SREDWise committed on
Commit
d64d976
1 Parent(s): d6490ab

updated handler.py to resolve tokenization errors.

Browse files
Files changed (1) hide show
  1. handler.py +48 -37
handler.py CHANGED
@@ -31,48 +31,59 @@ class EndpointHandler:
31
  Returns:
32
  Generated text
33
  """
34
- # Handle input
35
- if isinstance(data.get("inputs"), str):
36
- input_text = data["inputs"]
37
- else:
38
- # Extract messages from input
39
- messages = data.get("inputs", {}).get("messages", [])
40
- if not messages:
41
- return {"error": "No messages provided"}
42
-
43
- # Format input text
44
- input_text = ""
45
- for msg in messages:
46
- role = msg.get("role", "")
47
- content = msg.get("content", "")
48
- input_text += f"{role}: {content}\n"
 
49
 
50
- # Get generation parameters
51
- params = {**self.default_params}
52
- if "parameters" in data:
53
- params.update(data["parameters"])
54
 
55
- # Tokenize input
56
- inputs = self.tokenizer(
57
- input_text,
58
- return_tensors="pt",
59
- padding=True,
60
- truncation=True,
61
- max_length=512
62
- )
63
-
64
- # Generate response
65
- with torch.no_grad():
66
- outputs = self.model.generate(
67
- inputs["input_ids"],
68
- attention_mask=inputs["attention_mask"],
69
- **params
70
  )
71
 
72
- # Decode response
73
- generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
74
 
75
- return [{"generated_text": generated_text}]
 
 
 
 
 
 
 
76
 
77
  def preprocess(self, request):
78
  """
 
31
  Returns:
32
  Generated text
33
  """
34
+ try:
35
+ # Handle input
36
+ if isinstance(data.get("inputs"), str):
37
+ input_text = data["inputs"]
38
+ else:
39
+ # Extract messages from input
40
+ messages = data.get("inputs", {}).get("messages", [])
41
+ if not messages:
42
+ return {"error": "No messages provided"}
43
+
44
+ # Format input text
45
+ input_text = ""
46
+ for msg in messages:
47
+ role = msg.get("role", "")
48
+ content = msg.get("content", "")
49
+ input_text += f"{role}: {content}\n"
50
 
51
+ # Get generation parameters
52
+ params = {**self.default_params}
53
+ if "parameters" in data:
54
+ params.update(data["parameters"])
55
 
56
+ # Ensure proper tokenization with padding and attention mask
57
+ tokenizer_output = self.tokenizer(
58
+ input_text,
59
+ return_tensors="pt",
60
+ padding=True,
61
+ truncation=True,
62
+ max_length=512,
63
+ return_attention_mask=True
 
 
 
 
 
 
 
64
  )
65
 
66
+ # Move tensors to the same device as the model
67
+ input_ids = tokenizer_output["input_ids"]
68
+ attention_mask = tokenizer_output["attention_mask"]
69
+
70
+ # Generate response
71
+ with torch.no_grad():
72
+ outputs = self.model.generate(
73
+ input_ids,
74
+ attention_mask=attention_mask,
75
+ pad_token_id=self.tokenizer.pad_token_id,
76
+ **params
77
+ )
78
 
79
+ # Decode response
80
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
81
+
82
+ return [{"generated_text": generated_text}]
83
+
84
+ except Exception as e:
85
+ print(f"Error in generation: {str(e)}")
86
+ return {"error": str(e)}
87
 
88
  def preprocess(self, request):
89
  """