JamesANZ committed on
Commit
0506d30
·
1 Parent(s): c7fbee4

Fix attention mask warning by using tokenizer() and passing attention_mask to generate()

Browse files
Files changed (1) hide show
  1. query_slm.py +23 -9
query_slm.py CHANGED
@@ -15,7 +15,7 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
15
  class LegalSLM:
16
  """Wrapper class for the fine-tuned legal SLM."""
17
 
18
- def __init__(self, model_dir: str = "models/legal_slm"):
19
  """
20
  Initialize the Legal SLM.
21
 
@@ -34,6 +34,12 @@ class LegalSLM:
34
  self.model.to(self.device)
35
  self.model.eval() # Set to evaluation mode
36
 
 
 
 
 
 
 
37
  print("Model loaded successfully!")
38
 
39
  def _sanitize_input(self, text: str, max_length: int = 1000) -> str:
@@ -99,20 +105,28 @@ class LegalSLM:
99
  # Build prompt
100
  prompt = f"Based on Australian legal documents, answer the following.\n\nQuestion: {sanitized_question}\nAnswer:"
101
 
102
- # Tokenize prompt
103
- inputs = self.tokenizer.encode(prompt, return_tensors='pt')
104
- inputs = inputs.to(self.device)
 
 
 
 
 
 
 
105
 
106
  # Generate
107
  with torch.no_grad():
108
  outputs = self.model.generate(
109
- inputs,
110
- max_length=inputs.shape[1] + max_length,
 
111
  temperature=temperature,
112
  top_p=top_p,
113
  top_k=top_k,
114
  do_sample=True,
115
- pad_token_id=self.tokenizer.eos_token_id,
116
  eos_token_id=self.tokenizer.eos_token_id,
117
  repetition_penalty=1.2, # Reduce repetition
118
  )
@@ -181,8 +195,8 @@ def main():
181
  parser.add_argument(
182
  '--model-dir',
183
  type=str,
184
- default='models/legal_slm',
185
- help='Path to fine-tuned model directory'
186
  )
187
  parser.add_argument(
188
  '--question',
 
15
  class LegalSLM:
16
  """Wrapper class for the fine-tuned legal SLM."""
17
 
18
+ def __init__(self, model_dir: str = "."):
19
  """
20
  Initialize the Legal SLM.
21
 
 
34
  self.model.to(self.device)
35
  self.model.eval() # Set to evaluation mode
36
 
37
+ # Fix attention mask warning: GPT-2 uses same token for pad and eos
38
+ # Set pad_token_id explicitly and ensure it's handled correctly
39
+ if self.tokenizer.pad_token is None:
40
+ self.tokenizer.pad_token = self.tokenizer.eos_token
41
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
42
+
43
  print("Model loaded successfully!")
44
 
45
  def _sanitize_input(self, text: str, max_length: int = 1000) -> str:
 
105
  # Build prompt
106
  prompt = f"Based on Australian legal documents, answer the following.\n\nQuestion: {sanitized_question}\nAnswer:"
107
 
108
+ # Tokenize prompt with attention mask to fix the warning
109
+ # Using tokenizer() instead of encode() to get attention_mask automatically
110
+ tokenized = self.tokenizer(
111
+ prompt,
112
+ return_tensors='pt',
113
+ padding=False, # No padding needed for single input
114
+ truncation=False
115
+ )
116
+ input_ids = tokenized['input_ids'].to(self.device)
117
+ attention_mask = tokenized['attention_mask'].to(self.device)
118
 
119
  # Generate
120
  with torch.no_grad():
121
  outputs = self.model.generate(
122
+ input_ids,
123
+ attention_mask=attention_mask, # Pass attention mask to fix warning
124
+ max_length=input_ids.shape[1] + max_length,
125
  temperature=temperature,
126
  top_p=top_p,
127
  top_k=top_k,
128
  do_sample=True,
129
+ pad_token_id=self.tokenizer.pad_token_id,
130
  eos_token_id=self.tokenizer.eos_token_id,
131
  repetition_penalty=1.2, # Reduce repetition
132
  )
 
195
  parser.add_argument(
196
  '--model-dir',
197
  type=str,
198
+ default='.',
199
+ help='Path to fine-tuned model directory (default: current directory)'
200
  )
201
  parser.add_argument(
202
  '--question',