Improve generation quality without retraining: add candidate reranking, few-shot prompts, aggressive filtering, and quality scoring
query_slm.py  (+279 −31)

Old version of the changed sections (removed lines marked with "-"):

@@ -72,20 +72,24 @@ class LegalSLM:
     def generate_answer(
         self,
         question: str,
-        temperature: float = 0.…
-        max_length: int = …
-        …
-        …
     ) -> str:
         """
-        Generate an answer to a legal question.

         Args:
             question: The legal question to answer
             temperature: Sampling temperature (lower = more deterministic)
-            max_length: Maximum length of generated response
             top_p: Nucleus sampling parameter
             top_k: Top-k sampling parameter

         Returns:
             Generated answer text

@@ -101,53 +105,277 @@ class LegalSLM:
             raise ValueError("Temperature must be between 0.0 and 2.0")
         if max_length < 1 or max_length > 1000:
             raise ValueError("max_length must be between 1 and 1000")

-        # Build prompt
-        prompt = …

-        # Tokenize prompt with attention mask
-        # Using tokenizer() instead of encode() to get attention_mask automatically
         tokenized = self.tokenizer(
             prompt,
             return_tensors='pt',
-            padding=False,
             truncation=False
         )
         input_ids = tokenized['input_ids'].to(self.device)
         attention_mask = tokenized['attention_mask'].to(self.device)

-        # Generate
         with torch.no_grad():
             outputs = self.model.generate(
                 input_ids,
-                attention_mask=attention_mask,
-                …
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,
                 do_sample=True,
                 pad_token_id=self.tokenizer.pad_token_id,
                 eos_token_id=self.tokenizer.eos_token_id,
-                repetition_penalty=1.…
             )

-        # …
-        …

-        …
         if "Answer:" in full_response:
-            …
-            …
-            …
-            …
-            if len(prompt) <= len(full_response):
-                answer = full_response[len(prompt):].strip()
         else:
-            # If prompt is longer (shouldn't happen, but handle gracefully)
             answer = full_response.strip()

         return answer

     def interactive_query(self):
         """Run interactive query loop."""
         print("\n" + "=" * 80)

@@ -169,7 +397,14 @@ class LegalSLM:

             print("\nGenerating answer...")
             try:
-                …
                 print(f"\nAnswer: {answer}\n")
             except ValueError as e:
                 print(f"\nInvalid input: {e}\n")

@@ -207,14 +442,25 @@ def main():
     parser.add_argument(
         '--temperature',
         type=float,
-        default=0.…
-        help='Sampling temperature (default: 0.…
     )
     parser.add_argument(
         '--max-length',
         type=int,
-        default=…
-        help='Maximum response length in tokens (default:…
     )

     args = parser.parse_args()

@@ -250,7 +496,9 @@ def main():
         answer = slm.generate_answer(
             args.question,
             temperature=args.temperature,
-            max_length=args.max_length
         )
         print(f"\nQuestion: {args.question}")
         print(f"Answer: {answer}\n")

New version of the changed sections (added lines marked with "+"):

@@ -72,20 +72,24 @@ class LegalSLM:
     def generate_answer(
         self,
         question: str,
+        temperature: float = 0.2,
+        max_length: int = 200,
+        num_candidates: int = 3,
+        top_p: float = 0.85,
+        top_k: int = 30,
+        use_reranking: bool = True
     ) -> str:
         """
+        Generate an answer to a legal question with quality improvements.

         Args:
             question: The legal question to answer
             temperature: Sampling temperature (lower = more deterministic)
+            max_length: Maximum length of generated response in tokens
+            num_candidates: Number of candidates to generate for reranking
             top_p: Nucleus sampling parameter
             top_k: Top-k sampling parameter
+            use_reranking: If True, generate multiple candidates and pick best

         Returns:
             Generated answer text

@@ -101,53 +105,277 @@ class LegalSLM:
             raise ValueError("Temperature must be between 0.0 and 2.0")
         if max_length < 1 or max_length > 1000:
             raise ValueError("max_length must be between 1 and 1000")
+        if num_candidates < 1 or num_candidates > 10:
+            raise ValueError("num_candidates must be between 1 and 10")

+        # Build prompt with few-shot examples for better quality
+        prompt = self._build_few_shot_prompt(sanitized_question)

+        # Tokenize prompt with attention mask
         tokenized = self.tokenizer(
             prompt,
             return_tensors='pt',
+            padding=False,
             truncation=False
         )
         input_ids = tokenized['input_ids'].to(self.device)
         attention_mask = tokenized['attention_mask'].to(self.device)

+        # Generate candidates
+        num_to_generate = num_candidates if use_reranking else 1
+
         with torch.no_grad():
             outputs = self.model.generate(
                 input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=max_length,
+                num_return_sequences=num_to_generate,
+                min_new_tokens=30,  # Force minimum answer length
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,
                 do_sample=True,
                 pad_token_id=self.tokenizer.pad_token_id,
                 eos_token_id=self.tokenizer.eos_token_id,
+                repetition_penalty=1.4,  # Higher to reduce repetition
+                no_repeat_ngram_size=4,  # Prevent 4-gram repetition
+                early_stopping=False,
             )

+        # Extract and process candidates
+        candidates = []
+        for output in outputs:
+            full_response = self.tokenizer.decode(output, skip_special_tokens=True)
+            answer = self._extract_answer(full_response, prompt)
+            answer = self._clean_and_filter_answer(answer)
+
+            # Only consider answers that pass quality checks
+            if answer and len(answer.strip()) > 30:
+                if use_reranking:
+                    score = self._score_answer_quality(answer)
+                    candidates.append((score, answer))
+                else:
+                    # Return first valid answer if not reranking
+                    return self._find_natural_stopping_point(answer, max_chars=600)
+
+        # If no valid candidates, return fallback
+        if not candidates:
+            return "I cannot provide a reliable answer to this question based on the available information."
+
+        # Return best candidate based on quality score
+        candidates.sort(key=lambda x: x[0], reverse=True)
+        best_answer = candidates[0][1]
+
+        # Final cleanup and length limit
+        best_answer = self._find_natural_stopping_point(best_answer, max_chars=600)
+        return best_answer
+
+    def _build_few_shot_prompt(self, question: str) -> str:
+        """
+        Build prompt with few-shot examples to guide the model.

+        Args:
+            question: The user's question
+
+        Returns:
+            Formatted prompt with examples
+        """
+        # Few-shot examples that demonstrate good answer format
+        examples = [
+            ("What is negligence in Australian law?",
+             "Negligence is a legal concept in Australian law that requires a duty of care, breach of that duty, and resulting damage."),
+            ("What is a contract?",
+             "A contract is a legally binding agreement between parties that creates mutual obligations enforceable by law."),
+        ]
+
+        prompt_parts = []
+        for q, a in examples:
+            prompt_parts.append(f"Question: {q}\nAnswer: {a}")
+
+        # Add the actual question
+        prompt_parts.append(f"Question: {question}\nAnswer:")
+
+        return "\n\n".join(prompt_parts)
+
+    def _extract_answer(self, full_response: str, prompt: str) -> str:
+        """
+        Extract answer from full response.
+
+        Args:
+            full_response: Complete model response
+            prompt: Original prompt
+
+        Returns:
+            Extracted answer text
+        """
+        # Try multiple extraction methods
         if "Answer:" in full_response:
+            # Split by "Answer:" and take the last part (in case of multiple)
+            parts = full_response.split("Answer:")
+            if len(parts) > 1:
+                answer = parts[-1].strip()
             else:
                 answer = full_response.strip()
+        else:
+            # Fallback: remove prompt from response
+            answer = full_response.replace(prompt, "").strip()

         return answer

+    def _clean_and_filter_answer(self, answer: str) -> str:
+        """
+        Aggressively clean and filter gibberish from answer.
+
+        Args:
+            answer: Raw answer text
+
+        Returns:
+            Cleaned answer or empty string if too poor quality
+        """
+        if not answer:
+            return ""
+
+        # Remove problematic prefixes that don't add value
+        problematic_starts = [
+            ("Yes.", 4),
+            ("Yes,", 4),
+            ("Yes ", 4),
+            ("I do not know", 13),
+            ("I am not sure", 13),
+            ("I do and will not", 17),
+        ]
+
+        for prefix, length in problematic_starts:
+            if answer.strip().startswith(prefix):
+                after = answer[length:].strip()
+                # Only remove if there's substantial content after
+                if len(after) > 30:
+                    answer = after
+                break
+
+        # Remove everything after rambling markers
+        rambling_markers = ['---', '???', '...', '?---', '\n\nQuestion:', '\nQuestion:']
+        for marker in rambling_markers:
+            if marker in answer:
+                idx = answer.find(marker)
+                answer = answer[:idx].strip()
+
+        # Remove excessive whitespace
+        answer = ' '.join(answer.split())
+
+        # Remove incomplete sentences at the end
+        # Keep only complete sentences (ending with . ! or ?)
+        sentences = re.split(r'([.!?]\s+)', answer)
+        if len(sentences) > 1:
+            cleaned = []
+            for i in range(0, len(sentences) - 1, 2):
+                if i + 1 < len(sentences):
+                    cleaned.append(sentences[i] + sentences[i + 1])
+            if cleaned:
+                answer = ''.join(cleaned).strip()
+
+        # Filter out if too short
+        if len(answer) < 30:
+            return ""
+
+        # Check for excessive repetition (gibberish indicator)
+        words = answer.split()
+        if len(words) > 0:
+            unique_ratio = len(set(words)) / len(words)
+            if unique_ratio < 0.3:  # More than 70% repetition
+                return ""
+
+        return answer
+
+    def _score_answer_quality(self, answer: str) -> float:
+        """
+        Score answer quality (higher is better).
+
+        Args:
+            answer: Answer text to score
+
+        Returns:
+            Quality score
+        """
+        if not answer or len(answer) < 20:
+            return -100
+
+        score = 0
+
+        # Reward reasonable length (sweet spot around 200-400 chars)
+        length = len(answer)
+        if 100 <= length <= 500:
+            score += 50
+        elif 50 <= length < 100:
+            score += 30
+        elif length > 500:
+            score += 40  # Slightly less for very long
+        else:
+            score -= 20
+
+        # Penalize common gibberish patterns
+        gibberish_patterns = ['---', '???', '...', '?---', 'I do not know', 'I am not sure']
+        for pattern in gibberish_patterns:
+            if pattern in answer:
+                score -= 30
+
+        # Penalize if starts with "Yes." and nothing substantial
+        if answer.strip().startswith("Yes.") and len(answer.strip()) < 50:
+            score -= 40
+
+        # Reward complete sentences
+        sentence_count = answer.count('. ') + answer.count('? ') + answer.count('! ')
+        score += min(sentence_count * 5, 30)
+
+        # Reward diversity (less repetition)
+        words = answer.split()
+        if len(words) > 0:
+            unique_ratio = len(set(words)) / len(words)
+            score += unique_ratio * 30
+
+        # Penalize excessive question marks (uncertainty)
+        if answer.count('?') > 3:
+            score -= 20
+
+        # Reward legal terminology (common legal words)
+        legal_terms = ['law', 'legal', 'court', 'act', 'section', 'australia', 'australian',
+                       'right', 'obligation', 'contract', 'negligence', 'liability', 'duty']
+        term_count = sum(1 for term in legal_terms if term.lower() in answer.lower())
+        score += min(term_count * 3, 15)
+
+        return score
+
+    def _find_natural_stopping_point(self, text: str, max_chars: int = 600) -> str:
+        """
+        Find a natural stopping point in text to prevent abrupt cuts.
+
+        Args:
+            text: Text to truncate
+            max_chars: Maximum character length
+
+        Returns:
+            Text truncated at natural boundary
+        """
+        if len(text) <= max_chars:
+            return text
+
+        # Try to cut at sentence boundary
+        truncated = text[:max_chars]
+        sentence_endings = ['. ', '.\n', '? ', '!\n', '! ']
+
+        for ending in sentence_endings:
+            idx = truncated.rfind(ending)
+            # If found in last 30% of truncated text, use it
+            if idx > max_chars * 0.7:
+                return truncated[:idx + 1].strip()
+
+        # Fallback: cut at word boundary
+        words = truncated.rsplit(' ', 1)
+        if len(words) > 1:
+            return words[0] + '...'
+
+        return truncated + '...'
+
     def interactive_query(self):
         """Run interactive query loop."""
         print("\n" + "=" * 80)

@@ -169,7 +397,14 @@ class LegalSLM:

             print("\nGenerating answer...")
             try:
+                # Use reranking by default for better quality
+                answer = self.generate_answer(
+                    question,
+                    temperature=0.2,
+                    max_length=200,
+                    num_candidates=3,
+                    use_reranking=True
+                )
                 print(f"\nAnswer: {answer}\n")
             except ValueError as e:
                 print(f"\nInvalid input: {e}\n")

@@ -207,14 +442,25 @@ def main():
     parser.add_argument(
         '--temperature',
         type=float,
+        default=0.2,
+        help='Sampling temperature (default: 0.2, lower = more deterministic)'
     )
     parser.add_argument(
         '--max-length',
         type=int,
+        default=200,
+        help='Maximum response length in tokens (default: 200)'
+    )
+    parser.add_argument(
+        '--num-candidates',
+        type=int,
+        default=3,
+        help='Number of candidates to generate for reranking (default: 3)'
+    )
+    parser.add_argument(
+        '--no-reranking',
+        action='store_true',
+        help='Disable candidate reranking (faster but lower quality)'
     )

     args = parser.parse_args()

@@ -250,7 +496,9 @@ def main():
         answer = slm.generate_answer(
             args.question,
             temperature=args.temperature,
+            max_length=args.max_length,
+            num_candidates=args.num_candidates,
+            use_reranking=not args.no_reranking
         )
         print(f"\nQuestion: {args.question}")
         print(f"Answer: {answer}\n")
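
For quick reference, a minimal usage sketch of the parameters this change introduces. It is illustrative only: the LegalSLM constructor arguments are not shown in this diff and are assumed, while the generate_answer keyword arguments match the signature added above and the question is passed positionally, as interactive_query does.

    # Sketch only: LegalSLM() construction is an assumption; see query_slm.py for the real initialiser.
    from query_slm import LegalSLM

    slm = LegalSLM()  # hypothetical no-argument construction
    answer = slm.generate_answer(
        "What elements must be proven to establish negligence?",
        temperature=0.2,      # lower temperature keeps output more deterministic
        max_length=200,       # cap on new tokens per candidate
        num_candidates=3,     # candidates generated, filtered and scored
        use_reranking=True,   # return the highest-scoring candidate
    )
    print(answer)

From the command line, the equivalent run uses the new flags added to main(): --temperature 0.2 --max-length 200 --num-candidates 3, with --no-reranking available to skip candidate scoring; the question itself is supplied through the script's existing question argument.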