from transformers import Pipeline
import nltk
import requests
import torch
nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")
NEL_MODEL = "nel-mgenre-multilingual"
def get_wikipedia_page_props(input_str: str):
"""
Retrieves the QID for a given Wikipedia page name from the specified language Wikipedia.
If the request fails, it falls back to using the OpenRefine Wikidata API.
Args:
input_str (str): The input string in the format "page_name >> language".
Returns:
str: The QID or "NIL" if the QID is not found.
"""
# print(f"Input string: {input_str}")
if ">>" not in input_str:
page_name = input_str
language = "en"
        print(
            f'">>" was not found in "{input_str}", falling back to page name: {page_name}, language: {language}'
        )
    else:
        # Split the input string into page name and language
        try:
            page_name, language = input_str.split(">>")
            page_name = page_name.strip()
            language = language.strip()
        except ValueError:
            page_name = input_str
            language = "en"
            print(
                f'Could not split "{input_str}" on ">>", falling back to page name: {page_name}, language: {language}'
            )
wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
wikipedia_params = {
"action": "query",
"prop": "pageprops",
"format": "json",
"titles": page_name,
}
qid = "NIL"
try:
# Attempt to fetch from Wikipedia API
response = requests.get(wikipedia_url, params=wikipedia_params)
response.raise_for_status()
data = response.json()
if "pages" in data["query"]:
page_id = list(data["query"]["pages"].keys())[0]
if "pageprops" in data["query"]["pages"][page_id]:
page_props = data["query"]["pages"][page_id]["pageprops"]
if "wikibase_item" in page_props:
# print(page_props["wikibase_item"], language)
return page_props["wikibase_item"], language
else:
return qid, language
else:
return qid, language
else:
return qid, language
    except Exception:
        # Network errors, malformed JSON, or an unexpected response shape all fall back to "NIL"
        return qid, language
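
# Illustrative call (a sketch only: the returned QID depends on the live
# Wikipedia/Wikidata APIs at query time, so the value shown is an assumption):
#
#     qid, lang = get_wikipedia_page_props("Albert Einstein >> en")
#     # expected: qid == "Q937", lang == "en"
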
def get_wikipedia_title(qid, language="en"):
    """
    Resolve a Wikidata QID to the Wikipedia page title and URL in the requested
    language via the wbgetentities API.

    Returns:
        tuple: (title, url), or ("NIL", "None") if the request fails or the
        sitelink does not exist.
    """
    url = "https://www.wikidata.org/w/api.php"
params = {
"action": "wbgetentities",
"format": "json",
"ids": qid,
"props": "sitelinks/urls",
"sitefilter": f"{language}wiki",
}
response = requests.get(url, params=params)
try:
response.raise_for_status() # Raise an HTTPError if the response was not 2xx
data = response.json()
except requests.exceptions.RequestException as e:
print(f"HTTP error: {e}")
return "NIL", "None"
except ValueError as e: # Catch JSON decode errors
print(f"Invalid JSON response: {response.text}")
return "NIL", "None"
try:
title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
return title, url
except KeyError:
return "NIL", "None"
class NelPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "text" in kwargs:
preprocess_kwargs["text"] = kwargs["text"]
return preprocess_kwargs, {}, {}
def preprocess(self, text, **kwargs):
# Extract the entity between [START] and [END]
start_token = "[START]"
end_token = "[END]"
if start_token in text and end_token in text:
start_idx = text.index(start_token) + len(start_token)
end_idx = text.index(end_token)
enclosed_entity = text[start_idx:end_idx].strip()
lOffset = start_idx # left offset (start of the entity)
rOffset = end_idx # right offset (end of the entity)
else:
enclosed_entity = None
lOffset = None
rOffset = None
# Generate predictions using the model
outputs = self.model.generate(
**self.tokenizer([text], return_tensors="pt").to(self.device),
num_beams=1,
num_return_sequences=1,
max_new_tokens=30,
return_dict_in_generate=True,
output_scores=True,
)
        # Decode the prediction; the model emits a string of the form
        # "Page name >> language", which postprocess later resolves to a QID
        wikipedia_prediction = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )[0]
# Process the scores for each token
transition_scores = self.model.compute_transition_scores(
outputs.sequences, outputs.scores, normalize_logits=True
)
log_prob_sum = sum(transition_scores[0])
# Calculate the probability for the entire sequence by exponentiating the sum of log probabilities
sequence_confidence = torch.exp(log_prob_sum)
percentage = sequence_confidence.cpu().numpy() * 100.0
# print(wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage)
# Return the predictions along with the extracted entity, lOffset, and rOffset
return wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage
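
    # Worked example of the confidence computation above (illustrative numbers,
    # not real model output): if the per-token log-probabilities are
    # [-0.10, -0.20, -0.05], their sum is -0.35 and exp(-0.35) ~= 0.705,
    # which is reported as a confidence of ~70.5%.
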
    def _forward(self, inputs):
        # Generation already happens in preprocess, so the forward step is a pass-through
        return inputs
    def postprocess(self, outputs, **kwargs):
        """
        Resolve the decoded prediction to a Wikidata QID and Wikipedia page,
        and package the result as a list containing a single linking record.
        """
wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage = outputs
qid, language = get_wikipedia_page_props(wikipedia_prediction)
title, url = get_wikipedia_title(qid, language=language)
percentage = round(percentage, 2)
results = [
{
# "id": f"{lOffset}:{rOffset}:{enclosed_entity}:{NEL_MODEL}",
"surface": enclosed_entity,
"wkd_id": qid,
"wkpedia_pagename": title,
"wkpedia_url": url,
"type": "UNK",
"confidence_nel": percentage,
"lOffset": lOffset,
"rOffset": rOffset,
}
]
return results
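

if __name__ == "__main__":
    # Minimal usage sketch, under stated assumptions: the repository id below is
    # illustrative and not confirmed by this file; any seq2seq checkpoint
    # fine-tuned to generate mGENRE-style "Page name >> language" strings would
    # be used the same way.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_id = "impresso-project/nel-mgenre-multilingual"  # assumed model repository
    nel = NelPipeline(
        model=AutoModelForSeq2SeqLM.from_pretrained(model_id),
        tokenizer=AutoTokenizer.from_pretrained(model_id),
    )

    # The entity to link is wrapped in [START] ... [END], as expected by preprocess
    print(nel("The [START] Eiffel Tower [END] was completed in 1889."))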