real-jiakai committed on
Commit
df823ae
1 Parent(s): e158287

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +59 -13
README.md CHANGED
@@ -113,25 +113,71 @@ Limitations:
113
  ## How to use
114
 
115
  ```python
 
116
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Load model & tokenizer.
# Fixed: use the actual Hub repo id instead of the "your-username/..."
# placeholder so the example is copy-paste runnable (matches the repo
# this README belongs to).
model_name = "real-jiakai/bert-base-uncased-finetuned-squadv2"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example usage
context = "The Apollo program was designed to land humans on the Moon and bring them safely back to Earth."
question = "What was the goal of the Apollo program?"

# Tokenize the (question, context) pair as a single model input.
inputs = tokenizer(
    question,
    context,
    add_special_tokens=True,
    return_tensors="pt"
)

# Get model predictions (start/end logits over the token positions).
outputs = model(**inputs)
137
  ```
 
113
  ## How to use
114
 
115
  ```python
116
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Load model & tokenizer.
# NOTE(review): this downloads the fine-tuned checkpoint from the
# Hugging Face Hub on first use — network access required.
model_name = "real-jiakai/bert-base-uncased-finetuned-squadv2"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
123
 
124
def get_answer_v2(question, context, threshold=0.0):
    """Answer `question` from `context`, SQuAD-v2 style.

    Returns the extracted answer string, or a fixed "cannot be answered"
    message when the model's no-answer (null) score beats the best span
    score by more than `threshold`.

    Args:
        question: Question string.
        context: Passage to extract the answer from.
        threshold: Margin by which the null score must exceed the best
            answer score before "no answer" is returned. Larger values
            make the model answer more often.
    """
    # Tokenize input with maximum sequence length of 384 (the window
    # used during fine-tuning).
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        max_length=384,
        truncation=True
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Work on clones so masking below does not mutate the model outputs
    # in place (the original snippet modified outputs.start_logits).
    start_logits = outputs.start_logits[0].clone()
    end_logits = outputs.end_logits[0].clone()

    # Calculate null score (score for predicting no answer); by SQuAD v2
    # convention the [CLS] position carries the no-answer score.
    null_score = start_logits[0].item() + end_logits[0].item()

    # Restrict the span search to the context segment. For a BERT pair
    # encoding, token_type_id == 1 marks the second (context) segment;
    # everything else — [CLS], the question tokens, and the first [SEP]
    # — is masked out. (The original code masked only [CLS], so the
    # argmax could pick a "span" inside the question text.)
    context_mask = inputs["token_type_ids"][0] == 1
    start_logits[~context_mask] = float('-inf')
    end_logits[~context_mask] = float('-inf')

    start_idx = torch.argmax(start_logits)
    end_idx = torch.argmax(end_logits)

    # The two argmaxes are independent; guard against an inverted span.
    if end_idx < start_idx:
        end_idx = start_idx

    answer_score = start_logits[start_idx].item() + end_logits[end_idx].item()

    # If null score is higher (beyond threshold), return "no answer".
    if null_score - answer_score > threshold:
        return "Question cannot be answered based on the given context."

    # Otherwise, return the extracted answer text.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    answer = tokenizer.convert_tokens_to_string(tokens[start_idx:end_idx + 1])

    # Defensive check: an empty or special-token-only span (e.g. the
    # trailing [SEP], which shares the context segment id) also counts
    # as unanswerable.
    if not answer.strip() or answer.strip() in ['[CLS]', '[SEP]']:
        return "Question cannot be answered based on the given context."

    return answer.strip()
169
+
170
# Example usage: run the model over a few questions, including one
# that has no answer in the passage.
context = "The Apollo program was designed to land humans on the Moon and bring them safely back to Earth."

demo_questions = (
    "What was the goal of the Apollo program?",
    "Who was the first person to walk on Mars?",  # unanswerable question
    "What was the Apollo program designed to do?",
)

separator = "-" * 50
for q in demo_questions:
    answer = get_answer_v2(q, context, threshold=1.0)
    print(f"Question: {q}")
    print(f"Answer: {answer}")
    print(separator)
 
183
  ```