SREDWise committed on
Commit
96377a1
1 Parent(s): d64d976

update handler.py as part of debug

Browse files

Resolving a generate() error about a duplicate pad_token_id: the error suggests we're passing it twice — once in default_params and once explicitly.

Files changed (1) hide show
  1. handler.py +13 -18
handler.py CHANGED
@@ -25,12 +25,6 @@ class EndpointHandler:
25
  }
26
 
27
  def __call__(self, data: Dict):
28
- """
29
- Args:
30
- data: Dictionary with either string input or structured messages
31
- Returns:
32
- Generated text
33
- """
34
  try:
35
  # Handle input
36
  if isinstance(data.get("inputs"), str):
@@ -41,19 +35,23 @@ class EndpointHandler:
41
  if not messages:
42
  return {"error": "No messages provided"}
43
 
44
- # Format input text
45
- input_text = ""
46
  for msg in messages:
47
  role = msg.get("role", "")
48
  content = msg.get("content", "")
49
- input_text += f"{role}: {content}\n"
 
50
 
51
  # Get generation parameters
52
  params = {**self.default_params}
53
  if "parameters" in data:
54
  params.update(data["parameters"])
 
 
 
55
 
56
- # Ensure proper tokenization with padding and attention mask
57
  tokenizer_output = self.tokenizer(
58
  input_text,
59
  return_tensors="pt",
@@ -63,22 +61,19 @@ class EndpointHandler:
63
  return_attention_mask=True
64
  )
65
 
66
- # Move tensors to the same device as the model
67
- input_ids = tokenizer_output["input_ids"]
68
- attention_mask = tokenizer_output["attention_mask"]
69
-
70
  # Generate response
71
  with torch.no_grad():
72
  outputs = self.model.generate(
73
- input_ids,
74
- attention_mask=attention_mask,
75
- pad_token_id=self.tokenizer.pad_token_id,
76
  **params
77
  )
78
 
79
- # Decode response
80
  generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
81
 
 
82
  return [{"generated_text": generated_text}]
83
 
84
  except Exception as e:
 
25
  }
26
 
27
  def __call__(self, data: Dict):
 
 
 
 
 
 
28
  try:
29
  # Handle input
30
  if isinstance(data.get("inputs"), str):
 
35
  if not messages:
36
  return {"error": "No messages provided"}
37
 
38
+ # Format input text as array
39
+ inputs = []
40
  for msg in messages:
41
  role = msg.get("role", "")
42
  content = msg.get("content", "")
43
+ inputs.append(f"{role}: {content}")
44
+ input_text = "\n".join(inputs)
45
 
46
  # Get generation parameters
47
  params = {**self.default_params}
48
  if "parameters" in data:
49
  params.update(data["parameters"])
50
+
51
+ # Remove pad_token_id from params if it's going to be set explicitly
52
+ params.pop('pad_token_id', None)
53
 
54
+ # Tokenize input
55
  tokenizer_output = self.tokenizer(
56
  input_text,
57
  return_tensors="pt",
 
61
  return_attention_mask=True
62
  )
63
 
 
 
 
 
64
  # Generate response
65
  with torch.no_grad():
66
  outputs = self.model.generate(
67
+ tokenizer_output["input_ids"],
68
+ attention_mask=tokenizer_output["attention_mask"],
69
+ pad_token_id=self.tokenizer.pad_token_id, # Set it only here
70
  **params
71
  )
72
 
73
+ # Decode response and ensure array output
74
  generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
75
 
76
+ # Always return an array as required by the endpoint
77
  return [{"generated_text": generated_text}]
78
 
79
  except Exception as e: