JamesANZ committed · Commit d9ee2fd · 1 Parent(s): 0506d30

Improve generation quality without retraining: add candidate reranking, few-shot prompts, aggressive filtering, and quality scoring

Files changed (1): query_slm.py (+279, -31)
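
Note: the updated generation path in this commit can also be exercised directly from Python. The snippet below is a minimal sketch of calling the new generate_answer() signature; the LegalSLM constructor arguments are not part of this diff, so the constructor call shown here is a hypothetical placeholder.

    # Minimal sketch, assuming query_slm.py is importable; LegalSLM's constructor
    # arguments (e.g. a model path) are not shown in this diff and are hypothetical here.
    from query_slm import LegalSLM

    slm = LegalSLM()  # hypothetical constructor call
    answer = slm.generate_answer(
        "What is negligence in Australian law?",
        temperature=0.2,      # new default from this commit
        max_length=200,       # maximum new tokens
        num_candidates=3,     # candidates generated and reranked
        use_reranking=True,   # pick the highest-scoring candidate
    )
    print(answer)
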
query_slm.py CHANGED
@@ -72,20 +72,24 @@ class LegalSLM:
     def generate_answer(
         self,
         question: str,
-        temperature: float = 0.4,
-        max_length: int = 250,
-        top_p: float = 0.9,
-        top_k: int = 50
+        temperature: float = 0.2,
+        max_length: int = 200,
+        num_candidates: int = 3,
+        top_p: float = 0.85,
+        top_k: int = 30,
+        use_reranking: bool = True
     ) -> str:
         """
-        Generate an answer to a legal question.
+        Generate an answer to a legal question with quality improvements.
 
         Args:
             question: The legal question to answer
             temperature: Sampling temperature (lower = more deterministic)
-            max_length: Maximum length of generated response
+            max_length: Maximum length of generated response in tokens
+            num_candidates: Number of candidates to generate for reranking
             top_p: Nucleus sampling parameter
             top_k: Top-k sampling parameter
+            use_reranking: If True, generate multiple candidates and pick best
 
         Returns:
             Generated answer text
@@ -101,53 +105,277 @@ class LegalSLM:
             raise ValueError("Temperature must be between 0.0 and 2.0")
         if max_length < 1 or max_length > 1000:
             raise ValueError("max_length must be between 1 and 1000")
+        if num_candidates < 1 or num_candidates > 10:
+            raise ValueError("num_candidates must be between 1 and 10")
 
-        # Build prompt
-        prompt = f"Based on Australian legal documents, answer the following.\n\nQuestion: {sanitized_question}\nAnswer:"
-
-        # Tokenize prompt with attention mask to fix the warning
-        # Using tokenizer() instead of encode() to get attention_mask automatically
+        # Build prompt with few-shot examples for better quality
+        prompt = self._build_few_shot_prompt(sanitized_question)
+
+        # Tokenize prompt with attention mask
         tokenized = self.tokenizer(
             prompt,
             return_tensors='pt',
-            padding=False,  # No padding needed for single input
+            padding=False,
             truncation=False
         )
         input_ids = tokenized['input_ids'].to(self.device)
         attention_mask = tokenized['attention_mask'].to(self.device)
 
-        # Generate
+        # Generate candidates
+        num_to_generate = num_candidates if use_reranking else 1
+
         with torch.no_grad():
             outputs = self.model.generate(
                 input_ids,
-                attention_mask=attention_mask,  # Pass attention mask to fix warning
-                max_length=input_ids.shape[1] + max_length,
+                attention_mask=attention_mask,
+                max_new_tokens=max_length,
+                num_return_sequences=num_to_generate,
+                min_new_tokens=30,  # Force minimum answer length
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,
                 do_sample=True,
                 pad_token_id=self.tokenizer.pad_token_id,
                 eos_token_id=self.tokenizer.eos_token_id,
-                repetition_penalty=1.2,  # Reduce repetition
+                repetition_penalty=1.4,  # Higher to reduce repetition
+                no_repeat_ngram_size=4,  # Prevent 4-gram repetition
+                early_stopping=False,
             )
 
-        # Decode response
-        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Extract just the answer part (after "Answer:")
-        if "Answer:" in full_response:
-            answer = full_response.split("Answer:")[-1].strip()
-        else:
-            # Fallback: return everything after the prompt
-            # Safety check: ensure prompt is not longer than response
-            if len(prompt) <= len(full_response):
-                answer = full_response[len(prompt):].strip()
-            else:
-                # If prompt is longer (shouldn't happen, but handle gracefully)
-                answer = full_response.strip()
-
-        return answer
+        # Extract and process candidates
+        candidates = []
+        for output in outputs:
+            full_response = self.tokenizer.decode(output, skip_special_tokens=True)
+            answer = self._extract_answer(full_response, prompt)
+            answer = self._clean_and_filter_answer(answer)
+
+            # Only consider answers that pass quality checks
+            if answer and len(answer.strip()) > 30:
+                if use_reranking:
+                    score = self._score_answer_quality(answer)
+                    candidates.append((score, answer))
+                else:
+                    # Return first valid answer if not reranking
+                    return self._find_natural_stopping_point(answer, max_chars=600)
+
+        # If no valid candidates, return fallback
+        if not candidates:
+            return "I cannot provide a reliable answer to this question based on the available information."
+
+        # Return best candidate based on quality score
+        candidates.sort(key=lambda x: x[0], reverse=True)
+        best_answer = candidates[0][1]
+
+        # Final cleanup and length limit
+        best_answer = self._find_natural_stopping_point(best_answer, max_chars=600)
+        return best_answer
+
+    def _build_few_shot_prompt(self, question: str) -> str:
+        """
+        Build prompt with few-shot examples to guide the model.
+
+        Args:
+            question: The user's question
+
+        Returns:
+            Formatted prompt with examples
+        """
+        # Few-shot examples that demonstrate good answer format
+        examples = [
+            ("What is negligence in Australian law?",
+             "Negligence is a legal concept in Australian law that requires a duty of care, breach of that duty, and resulting damage."),
+            ("What is a contract?",
+             "A contract is a legally binding agreement between parties that creates mutual obligations enforceable by law."),
+        ]
+
+        prompt_parts = []
+        for q, a in examples:
+            prompt_parts.append(f"Question: {q}\nAnswer: {a}")
+
+        # Add the actual question
+        prompt_parts.append(f"Question: {question}\nAnswer:")
+
+        return "\n\n".join(prompt_parts)
+
+    def _extract_answer(self, full_response: str, prompt: str) -> str:
+        """
+        Extract answer from full response.
+
+        Args:
+            full_response: Complete model response
+            prompt: Original prompt
+
+        Returns:
+            Extracted answer text
+        """
+        # Try multiple extraction methods
+        if "Answer:" in full_response:
+            # Split by "Answer:" and take the last part (in case of multiple)
+            parts = full_response.split("Answer:")
+            if len(parts) > 1:
+                answer = parts[-1].strip()
+            else:
+                answer = full_response.strip()
+        else:
+            # Fallback: remove prompt from response
+            answer = full_response.replace(prompt, "").strip()
+
+        return answer
+
+    def _clean_and_filter_answer(self, answer: str) -> str:
+        """
+        Aggressively clean and filter gibberish from answer.
+
+        Args:
+            answer: Raw answer text
+
+        Returns:
+            Cleaned answer or empty string if too poor quality
+        """
+        if not answer:
+            return ""
+
+        # Remove problematic prefixes that don't add value
+        problematic_starts = [
+            ("Yes.", 4),
+            ("Yes,", 4),
+            ("Yes ", 4),
+            ("I do not know", 13),
+            ("I am not sure", 13),
+            ("I do and will not", 17),
+        ]
+
+        for prefix, length in problematic_starts:
+            if answer.strip().startswith(prefix):
+                after = answer[length:].strip()
+                # Only remove if there's substantial content after
+                if len(after) > 30:
+                    answer = after
+                break
+
+        # Remove everything after rambling markers
+        rambling_markers = ['---', '???', '...', '?---', '\n\nQuestion:', '\nQuestion:']
+        for marker in rambling_markers:
+            if marker in answer:
+                idx = answer.find(marker)
+                answer = answer[:idx].strip()
+
+        # Remove excessive whitespace
+        answer = ' '.join(answer.split())
+
+        # Remove incomplete sentences at the end
+        # Keep only complete sentences (ending with . ! or ?)
+        sentences = re.split(r'([.!?]\s+)', answer)
+        if len(sentences) > 1:
+            cleaned = []
+            for i in range(0, len(sentences) - 1, 2):
+                if i + 1 < len(sentences):
+                    cleaned.append(sentences[i] + sentences[i + 1])
+            if cleaned:
+                answer = ''.join(cleaned).strip()
+
+        # Filter out if too short
+        if len(answer) < 30:
+            return ""
+
+        # Check for excessive repetition (gibberish indicator)
+        words = answer.split()
+        if len(words) > 0:
+            unique_ratio = len(set(words)) / len(words)
+            if unique_ratio < 0.3:  # More than 70% repetition
+                return ""
+
+        return answer
+
+    def _score_answer_quality(self, answer: str) -> float:
+        """
+        Score answer quality (higher is better).
+
+        Args:
+            answer: Answer text to score
+
+        Returns:
+            Quality score
+        """
+        if not answer or len(answer) < 20:
+            return -100
+
+        score = 0
+
+        # Reward reasonable length (sweet spot around 200-400 chars)
+        length = len(answer)
+        if 100 <= length <= 500:
+            score += 50
+        elif 50 <= length < 100:
+            score += 30
+        elif length > 500:
+            score += 40  # Slightly less for very long
+        else:
+            score -= 20
+
+        # Penalize common gibberish patterns
+        gibberish_patterns = ['---', '???', '...', '?---', 'I do not know', 'I am not sure']
+        for pattern in gibberish_patterns:
+            if pattern in answer:
+                score -= 30
+
+        # Penalize if starts with "Yes." and nothing substantial
+        if answer.strip().startswith("Yes.") and len(answer.strip()) < 50:
+            score -= 40
+
+        # Reward complete sentences
+        sentence_count = answer.count('. ') + answer.count('? ') + answer.count('! ')
+        score += min(sentence_count * 5, 30)
+
+        # Reward diversity (less repetition)
+        words = answer.split()
+        if len(words) > 0:
+            unique_ratio = len(set(words)) / len(words)
+            score += unique_ratio * 30
+
+        # Penalize excessive question marks (uncertainty)
+        if answer.count('?') > 3:
+            score -= 20
+
+        # Reward legal terminology (common legal words)
+        legal_terms = ['law', 'legal', 'court', 'act', 'section', 'australia', 'australian',
+                       'right', 'obligation', 'contract', 'negligence', 'liability', 'duty']
+        term_count = sum(1 for term in legal_terms if term.lower() in answer.lower())
+        score += min(term_count * 3, 15)
+
+        return score
+
+    def _find_natural_stopping_point(self, text: str, max_chars: int = 600) -> str:
+        """
+        Find a natural stopping point in text to prevent abrupt cuts.
+
+        Args:
+            text: Text to truncate
+            max_chars: Maximum character length
+
+        Returns:
+            Text truncated at natural boundary
+        """
+        if len(text) <= max_chars:
+            return text
+
+        # Try to cut at sentence boundary
+        truncated = text[:max_chars]
+        sentence_endings = ['. ', '.\n', '? ', '!\n', '! ']
+
+        for ending in sentence_endings:
+            idx = truncated.rfind(ending)
+            # If found in last 30% of truncated text, use it
+            if idx > max_chars * 0.7:
+                return truncated[:idx + 1].strip()
+
+        # Fallback: cut at word boundary
+        words = truncated.rsplit(' ', 1)
+        if len(words) > 1:
+            return words[0] + '...'
+
+        return truncated + '...'
+
     def interactive_query(self):
         """Run interactive query loop."""
         print("\n" + "=" * 80)
@@ -169,7 +397,14 @@ class LegalSLM:
 
             print("\nGenerating answer...")
             try:
-                answer = self.generate_answer(question)
+                # Use reranking by default for better quality
+                answer = self.generate_answer(
+                    question,
+                    temperature=0.2,
+                    max_length=200,
+                    num_candidates=3,
+                    use_reranking=True
+                )
                 print(f"\nAnswer: {answer}\n")
             except ValueError as e:
                 print(f"\nInvalid input: {e}\n")
@@ -207,14 +442,25 @@ def main():
    parser.add_argument(
        '--temperature',
        type=float,
-       default=0.4,
-       help='Sampling temperature (default: 0.4)'
+       default=0.2,
+       help='Sampling temperature (default: 0.2, lower = more deterministic)'
    )
    parser.add_argument(
        '--max-length',
        type=int,
-       default=250,
-       help='Maximum response length in tokens (default: 250)'
+       default=200,
+       help='Maximum response length in tokens (default: 200)'
+   )
+   parser.add_argument(
+       '--num-candidates',
+       type=int,
+       default=3,
+       help='Number of candidates to generate for reranking (default: 3)'
+   )
+   parser.add_argument(
+       '--no-reranking',
+       action='store_true',
+       help='Disable candidate reranking (faster but lower quality)'
    )
 
    args = parser.parse_args()
@@ -250,7 +496,9 @@ def main():
        answer = slm.generate_answer(
            args.question,
            temperature=args.temperature,
-           max_length=args.max_length
+           max_length=args.max_length,
+           num_candidates=args.num_candidates,
+           use_reranking=not args.no_reranking
        )
        print(f"\nQuestion: {args.question}")
        print(f"Answer: {answer}\n")