Fix attention mask warning by using tokenizer() and passing attention_mask to generate()
Browse files — query_slm.py (+23 −9)
query_slm.py
CHANGED
|
@@ -15,7 +15,7 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
| 15 |
class LegalSLM:
|
| 16 |
"""Wrapper class for the fine-tuned legal SLM."""
|
| 17 |
|
| 18 |
-
def __init__(self, model_dir: str = "."):
|
| 19 |
"""
|
| 20 |
Initialize the Legal SLM.
|
| 21 |
|
|
@@ -34,6 +34,12 @@ class LegalSLM:
|
|
| 34 |
self.model.to(self.device)
|
| 35 |
self.model.eval() # Set to evaluation mode
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
print("Model loaded successfully!")
|
| 38 |
|
| 39 |
def _sanitize_input(self, text: str, max_length: int = 1000) -> str:
|
|
@@ -99,20 +105,28 @@ class LegalSLM:
|
|
| 99 |
# Build prompt
|
| 100 |
prompt = f"Based on Australian legal documents, answer the following.\n\nQuestion: {sanitized_question}\nAnswer:"
|
| 101 |
|
| 102 |
-
# Tokenize prompt
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# Generate
|
| 107 |
with torch.no_grad():
|
| 108 |
outputs = self.model.generate(
|
| 109 |
-
|
| 110 |
-
|
|
|
|
| 111 |
temperature=temperature,
|
| 112 |
top_p=top_p,
|
| 113 |
top_k=top_k,
|
| 114 |
do_sample=True,
|
| 115 |
-
pad_token_id=self.tokenizer.eos_token_id,
|
| 116 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 117 |
repetition_penalty=1.2, # Reduce repetition
|
| 118 |
)
|
|
@@ -181,8 +195,8 @@ def main():
|
|
| 181 |
parser.add_argument(
|
| 182 |
'--model-dir',
|
| 183 |
type=str,
|
| 184 |
-
default='
|
| 185 |
-
help='Path to fine-tuned model directory'
|
| 186 |
)
|
| 187 |
parser.add_argument(
|
| 188 |
'--question',
|
|
|
|
| 15 |
class LegalSLM:
|
| 16 |
"""Wrapper class for the fine-tuned legal SLM."""
|
| 17 |
|
| 18 |
+
def __init__(self, model_dir: str = "."):
|
| 19 |
"""
|
| 20 |
Initialize the Legal SLM.
|
| 21 |
|
|
|
|
| 34 |
self.model.to(self.device)
|
| 35 |
self.model.eval() # Set to evaluation mode
|
| 36 |
|
| 37 |
+
# Fix attention mask warning: GPT-2 uses same token for pad and eos
|
| 38 |
+
# Set pad_token_id explicitly and ensure it's handled correctly
|
| 39 |
+
if self.tokenizer.pad_token is None:
|
| 40 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 41 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 42 |
+
|
| 43 |
print("Model loaded successfully!")
|
| 44 |
|
| 45 |
def _sanitize_input(self, text: str, max_length: int = 1000) -> str:
|
|
|
|
| 105 |
# Build prompt
|
| 106 |
prompt = f"Based on Australian legal documents, answer the following.\n\nQuestion: {sanitized_question}\nAnswer:"
|
| 107 |
|
| 108 |
+
# Tokenize prompt with attention mask to fix the warning
|
| 109 |
+
# Using tokenizer() instead of encode() to get attention_mask automatically
|
| 110 |
+
tokenized = self.tokenizer(
|
| 111 |
+
prompt,
|
| 112 |
+
return_tensors='pt',
|
| 113 |
+
padding=False, # No padding needed for single input
|
| 114 |
+
truncation=False
|
| 115 |
+
)
|
| 116 |
+
input_ids = tokenized['input_ids'].to(self.device)
|
| 117 |
+
attention_mask = tokenized['attention_mask'].to(self.device)
|
| 118 |
|
| 119 |
# Generate
|
| 120 |
with torch.no_grad():
|
| 121 |
outputs = self.model.generate(
|
| 122 |
+
input_ids,
|
| 123 |
+
attention_mask=attention_mask, # Pass attention mask to fix warning
|
| 124 |
+
max_length=input_ids.shape[1] + max_length,
|
| 125 |
temperature=temperature,
|
| 126 |
top_p=top_p,
|
| 127 |
top_k=top_k,
|
| 128 |
do_sample=True,
|
| 129 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 130 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 131 |
repetition_penalty=1.2, # Reduce repetition
|
| 132 |
)
|
|
|
|
| 195 |
parser.add_argument(
|
| 196 |
'--model-dir',
|
| 197 |
type=str,
|
| 198 |
+
default='.',
|
| 199 |
+
help='Path to fine-tuned model directory (default: current directory)'
|
| 200 |
)
|
| 201 |
parser.add_argument(
|
| 202 |
'--question',
|