SREDWise committed
Commit 8cf54e4
Parent: 96377a1

update handler.py for debugging


A tensor size mismatch error typically occurs when there is an inconsistency between the model configuration and input processing. Fixed by:
- Using fixed padding with max_length (a minimal sketch of the effect follows this list)
- Adding debug printing
- Simplifying input handling
- Using consistent tensor dimensions
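
As a rough illustration of the padding fix (not code from this repo; the gpt2 checkpoint is only a stand-in), the sketch below tokenizes two prompts of very different lengths and shows that padding='max_length' gives both the same fixed shape:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model choice
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token

for text in ["short prompt", "a much longer prompt " * 50]:
    out = tokenizer(
        text,
        padding="max_length",  # pad every input up to max_length...
        truncation=True,       # ...and truncate anything longer
        max_length=512,
        return_tensors="pt",
    )
    print(out["input_ids"].shape)  # torch.Size([1, 512]) both times

With padding=True instead, each call would pad only to the longest sequence in the batch, so tensor shapes would vary from request to request.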

Files changed (1): handler.py (+19, -27)
handler.py CHANGED
@@ -30,54 +30,46 @@ class EndpointHandler:
             if isinstance(data.get("inputs"), str):
                 input_text = data["inputs"]
             else:
-                # Extract messages from input
-                messages = data.get("inputs", {}).get("messages", [])
-                if not messages:
-                    return {"error": "No messages provided"}
-
-                # Format input text as array
-                inputs = []
-                for msg in messages:
-                    role = msg.get("role", "")
-                    content = msg.get("content", "")
-                    inputs.append(f"{role}: {content}")
-                input_text = "\n".join(inputs)
+                input_text = data.get("inputs")[0] if isinstance(data.get("inputs"), list) else str(data.get("inputs"))
 
-            # Get generation parameters
-            params = {**self.default_params}
-            if "parameters" in data:
-                params.update(data["parameters"])
-
-            # Remove pad_token_id from params if it's going to be set explicitly
-            params.pop('pad_token_id', None)
+            # Print debug information
+            print(f"Input text: {input_text}")
 
-            # Tokenize input
+            # Tokenize with fixed dimensions
             tokenizer_output = self.tokenizer(
                 input_text,
-                return_tensors="pt",
-                padding=True,
+                padding='max_length',  # Changed to max_length
                 truncation=True,
-                max_length=512,
+                max_length=512,  # Fixed length
+                return_tensors="pt",
                 return_attention_mask=True
             )
 
+            # Print tensor shapes for debugging
+            print(f"Input ids shape: {tokenizer_output['input_ids'].shape}")
+            print(f"Attention mask shape: {tokenizer_output['attention_mask'].shape}")
+
             # Generate response
             with torch.no_grad():
                 outputs = self.model.generate(
                     tokenizer_output["input_ids"],
                     attention_mask=tokenizer_output["attention_mask"],
-                    pad_token_id=self.tokenizer.pad_token_id,  # Set it only here
-                    **params
+                    max_length=512,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.7,
+                    top_k=50
                 )
 
-            # Decode response and ensure array output
+            # Decode response
             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            # Always return an array as required by the endpoint
             return [{"generated_text": generated_text}]
 
         except Exception as e:
             print(f"Error in generation: {str(e)}")
+            print(f"Model config: {self.model.config}")
             return {"error": str(e)}
 
     def preprocess(self, request):
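
For reference, a hypothetical local smoke test of the patched handler (the constructor and call signatures below follow the usual Hugging Face Inference Endpoints convention of EndpointHandler(path) invoked as handler(data); check the repo's actual definitions before relying on them):

from handler import EndpointHandler

handler = EndpointHandler(path=".")

# The simplified input handling accepts a plain string...
print(handler({"inputs": "Hello, how are you?"}))

# ...or a list, in which case only the first element is used.
print(handler({"inputs": ["Hello, how are you?", "this one is ignored"]}))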