|
import logging |
|
from transformers import Pipeline |
|
import numpy as np |
|
import torch |
|
import nltk |
|
|
|
nltk.download("averaged_perceptron_tagger") |
|
nltk.download("averaged_perceptron_tagger_eng") |
|
nltk.download("stopwords") |
|
from nltk.chunk import conlltags2tree |
|
from nltk import pos_tag |
|
from nltk.tree import Tree |
|
import torch.nn.functional as F |
|
import re, string |
|
|
|
stop_words = set(nltk.corpus.stopwords.words("english")) |
|
DEBUG = False |
|
# Punctuation used for token/entity clean-up: ASCII punctuation plus common
# typographic marks (curly quotes, dashes, bullets, etc.), deduplicated.
punctuation = (
    string.punctuation + "«»“”„‘’´°•–—―…‣◦§¶†‡‰′″〈〉"
)

|
# Language-specific whitespace rules: which punctuation marks attach to the
# preceding token, the following token, or both. French omits "?", "!", ":" and ";"
# because French typography keeps a space before them.
WHITESPACE_RULES = {
    "fr": {
        "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
        "pct_no_ws_after": ["(", "[", "{"],
        "pct_no_ws_before_after": ["'", "-"],
        "pct_number": [".", ","],
    },
    "de": {
        "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", "?", "!", ":", ";", ".-", "%"],
        "pct_no_ws_after": ["(", "[", "{"],
        "pct_no_ws_before_after": ["'", "-"],
        "pct_number": [".", ","],
    },
    "other": {
        "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", "?", "!", ":", ";", ".-", "%"],
        "pct_no_ws_after": ["(", "[", "{"],
        "pct_no_ws_before_after": ["'", "-"],
        "pct_number": [".", ","],
    },
}
|
|
|
|
|
def tokenize(text: str, language: str = "other") -> list[str]: |
|
"""Apply whitespace rules to the given text and language, separating it into tokens. |
|
|
|
Args: |
|
text (str): The input text to separate into a list of tokens. |
|
language (str): Language of the text. |
|
|
|
Returns: |
|
list[str]: List of tokens with punctuation as separate tokens. |
|
""" |
|
|
|
if not text: |
|
return [] |
|
|
|
if language not in WHITESPACE_RULES: |
|
|
|
|
|
language = "other" |
|
|
|
wsrules = WHITESPACE_RULES[language] |
|
tokenized_text = [] |
|
current_token = "" |
|
|
|
for char in text: |
|
if char in wsrules["pct_no_ws_before_after"]: |
|
if current_token: |
|
tokenized_text.append(current_token) |
|
tokenized_text.append(char) |
|
current_token = "" |
|
elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]: |
|
if current_token: |
|
tokenized_text.append(current_token) |
|
tokenized_text.append(char) |
|
current_token = "" |
|
elif char.isspace(): |
|
if current_token: |
|
tokenized_text.append(current_token) |
|
current_token = "" |
|
else: |
|
current_token += char |
|
|
|
if current_token: |
|
tokenized_text.append(current_token) |
|
|
|
return tokenized_text |
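# Illustrative examples of the rule-based tokenizer above (not executed anywhere):
#     tokenize("Hello, world!")            -> ["Hello", ",", "world", "!"]
#     tokenize("(Zürich)", language="de")  -> ["(", "Zürich", ")"]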
|
|
|
|
|
def normalize_text(text):
    """Remove spaces and tabs (but not newlines) so texts can be compared independent of spacing."""
    return re.sub(r"[ \t]+", "", text)
|
|
|
|
|
def find_entity_indices(article_text, search_text):
    """Find all (start, end) character offsets of `search_text` in `article_text`,
    ignoring differences in spaces/tabs, and map them back onto the original text."""
    normalized_article = normalize_text(article_text)
    normalized_search = normalize_text(search_text)
|
|
|
|
|
indices = [] |
|
|
|
|
|
start_index = 0 |
|
while True: |
|
start_index = normalized_article.find(normalized_search, start_index) |
|
if start_index == -1: |
|
break |
|
|
|
|
|
original_chars = 0 |
|
original_start_index = 0 |
|
for i in range(start_index): |
|
while article_text[original_start_index] in (" ", "\t"): |
|
original_start_index += 1 |
|
if article_text[original_start_index] not in (" ", "\t", "\n"): |
|
original_chars += 1 |
|
original_start_index += 1 |
|
|
|
original_end_index = original_start_index |
|
search_chars = 0 |
|
while search_chars < len(normalized_search): |
|
if article_text[original_end_index] not in (" ", "\t", "\n"): |
|
search_chars += 1 |
|
original_end_index += 1 |
|
|
|
|
|
if article_text[original_start_index] == " ": |
|
original_start_index += 1 |
|
indices.append((original_start_index, original_end_index)) |
|
|
|
|
|
start_index += 1 |
|
|
|
return indices |
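# Illustrative example: whitespace differences between the search string and the
# article are ignored, while the returned offsets index into the original text:
#     find_entity_indices("foo bar baz", "bar")  -> [(4, 7)]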
|
|
|
|
|
def get_entities(tokens, tags, confidences, text):
    """Convert token-level IOB tags into entity dicts carrying the entity type,
    NER confidence, token index span, surface form and character offsets in `text`."""
    # Normalise IOBES-style tags (S-/E-) to plain IOB (B-/I-).
    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
    pos_tags = [pos for token, pos in pos_tag(tokens)]
|
|
|
for i in range(1, len(tags)): |
|
|
|
if tags[i].startswith("B-") and tags[i - 1].startswith("I-"): |
|
tags[i] = "I-" + tags[i][2:] |
|
|
|
conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)] |
|
ne_tree = conlltags2tree(conlltags) |
|
|
|
entities = [] |
|
idx: int = 0 |
|
already_done = [] |
|
for subtree in ne_tree: |
|
|
|
if isinstance(subtree, Tree): |
|
original_label = subtree.label() |
|
original_string = " ".join([token for token, pos in subtree.leaves()]) |
|
|
|
for indices in find_entity_indices(text, original_string): |
|
entity_start_position = indices[0] |
|
entity_end_position = indices[1] |
|
                entity_key = "_".join(
                    [original_label, original_string, str(entity_start_position)]
                )
                if entity_key in already_done:
                    continue
                already_done.append(entity_key)
|
if len(text[entity_start_position:entity_end_position].strip()) < len( |
|
text[entity_start_position:entity_end_position] |
|
): |
|
entity_start_position = ( |
|
entity_start_position |
|
+ len(text[entity_start_position:entity_end_position]) |
|
- len(text[entity_start_position:entity_end_position].strip()) |
|
) |
|
|
|
entities.append( |
|
{ |
|
"type": original_label, |
|
"confidence_ner": round( |
|
np.average(confidences[idx : idx + len(subtree)]) * 100, 2 |
|
), |
|
"index": (idx, idx + len(subtree)), |
|
"surface": text[ |
|
entity_start_position:entity_end_position |
|
], |
|
"lOffset": entity_start_position, |
|
"rOffset": entity_end_position, |
|
} |
|
) |
|
|
|
idx += len(subtree) |
|
|
|
|
|
|
|
else: |
|
token, pos = subtree |
|
|
|
|
|
idx += 1 |
|
|
|
return entities |
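# Sketch of the expected output shape (illustrative values):
#     get_entities(["John", "Smith", "visited", "Paris"],
#                  ["B-pers", "I-pers", "O", "B-loc"],
#                  [0.99, 0.98, 0.99, 0.97],
#                  "John Smith visited Paris")
#     -> [{"type": "pers", "confidence_ner": 98.5, "index": (0, 2),
#          "surface": "John Smith", "lOffset": 0, "rOffset": 10},
#         {"type": "loc", "surface": "Paris", ...}]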
|
|
|
|
|
def realign(
    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
    """Map subword-level predictions back to word level, keeping the label and
    softmax confidence of each word's first subword; words without a prediction
    (e.g. lost to truncation) fall back to "O" with confidence 0.0."""
    preds_list, words_list, confidence_list = [], [], []
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
    for idx, word in enumerate(text_sentence):
        try:
            beginning_index = word_ids.index(idx)
            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
            confidence_list.append(max(softmax_scores[beginning_index]))
        except Exception:
            preds_list.append("O")
            confidence_list.append(0.0)
|
words_list.append(word) |
|
|
|
return words_list, preds_list, confidence_list |
|
|
|
|
|
def add_spaces_around_punctuation(text):
    """Insert spaces around every punctuation mark so punctuation ends up as separate tokens."""
    # `punctuation` already includes string.punctuation, so it is used directly.
    return re.sub(r"([{}])".format(re.escape(punctuation)), r" \1 ", text)
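# Illustrative example: add_spaces_around_punctuation("Bern, Schweiz") -> "Bern ,  Schweiz"
# (the doubled space is harmless, since `tokenize` collapses whitespace).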
|
|
|
|
|
def attach_comp_to_closest(entities):
    """Attach each comp.* component entity (e.g. a name or function fragment) to the
    closest person/organisation entity and drop the comp entities from the result."""
    valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}

    comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
    other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]
|
|
|
for comp_entity in comp_entities: |
|
closest_entity = None |
|
min_distance = float("inf") |
|
|
|
|
|
for other_entity in other_entities: |
|
|
|
if comp_entity["lOffset"] > other_entity["rOffset"]: |
|
distance = comp_entity["lOffset"] - other_entity["rOffset"] |
|
elif comp_entity["rOffset"] < other_entity["lOffset"]: |
|
distance = other_entity["lOffset"] - comp_entity["rOffset"] |
|
else: |
|
distance = 0 |
|
|
|
|
|
if ( |
|
distance < min_distance |
|
and other_entity["type"].split(".")[0] in valid_entity_types |
|
): |
|
min_distance = distance |
|
closest_entity = other_entity |
|
|
|
|
|
        if closest_entity:
            # e.g. "comp.name" -> key "name", "comp.function" -> key "function".
            suffix = comp_entity["type"].split(".")[-1]
            closest_entity[suffix] = comp_entity["surface"]
|
|
|
return other_entities |
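# Illustrative example (hypothetical entities): a comp.function entity with surface
# "president" that sits right next to a pers entity "John Smith" is attached as
# {"function": "president"} on the pers entity, and the comp entity itself is dropped.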
|
|
|
|
|
def conflicting_context(comp_entity, target_entity): |
|
""" |
|
Determines if there is a conflict between the comp_entity and the target entity. |
|
Prevents incorrect name and function attachments by using a rule-based approach. |
|
""" |
|
|
|
if comp_entity["type"].startswith("comp.function"): |
|
if not ("pers" in target_entity["type"] or "org" in target_entity["type"]): |
|
return True |
|
|
|
|
|
if "loc" in target_entity["type"]: |
|
return True |
|
|
|
return False |
|
|
|
|
|
def extract_name_from_text(text, partial_name): |
|
""" |
|
Extracts the full name from the entity's text based on the partial name. |
|
This function assumes that the full name starts with capitalized letters and does not |
|
include any words that come after the partial name. |
|
""" |
|
|
|
words = tokenize(text) |
|
partial_words = partial_name.split() |
|
|
|
    if DEBUG:
        print("text:", text)
        print("partial_name:", partial_name)
|
|
|
|
|
for i, word in enumerate(words): |
|
if DEBUG: |
|
print(words, "---", words[i : i + len(partial_words)]) |
|
if words[i : i + len(partial_words)] == partial_words: |
|
|
|
full_name = partial_words[:] |
|
|
|
if DEBUG: |
|
print("full_name:", full_name) |
|
|
|
|
|
j = i - 1 |
|
while j >= 0 and words[j][0].isupper(): |
|
full_name.insert(0, words[j]) |
|
j -= 1 |
|
if DEBUG: |
|
print("full_name:", full_name) |
|
|
|
|
|
return " ".join(full_name).strip() |
|
|
|
|
|
return text.strip() |
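# Illustrative example:
#     extract_name_from_text("met George Washington yesterday", "Washington")
#     -> "George Washington"
# (preceding capitalised words are absorbed; the lowercase "met" stops the expansion)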
|
|
|
|
|
def repair_names_in_entities(entities): |
|
""" |
|
This function repairs the names in the entities by extracting the full name |
|
from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached. |
|
""" |
|
for entity in entities: |
|
if "name" in entity and "pers" in entity["type"]: |
|
name = entity["name"] |
|
text = entity["surface"] |
|
|
|
|
|
if name in text: |
|
|
|
full_name = extract_name_from_text(entity["surface"], name) |
|
entity["name"] = ( |
|
full_name |
|
) |
|
|
|
|
|
|
|
return entities |
|
|
|
|
|
def clean_coarse_entities(entities): |
|
""" |
|
This function removes entities that are not useful for the NEL process. |
|
""" |
|
|
|
useful_types = { |
|
"pers", |
|
"loc", |
|
"org", |
|
"date", |
|
"time", |
|
} |
|
|
|
|
|
cleaned_entities = [ |
|
entity |
|
for entity in entities |
|
if entity["type"] in useful_types or "comp" in entity["type"] |
|
] |
|
|
|
return cleaned_entities |
|
|
|
|
|
def postprocess_entities(entities):
    """Deduplicate entities by surface form (keeping the most fine-grained type),
    then attach comp.* components to their closest entity and repair partial names."""
    entity_map = {}
|
|
|
|
|
for entity in entities: |
|
entity_text = entity["surface"] |
|
num_dots = entity["type"].count(".") |
|
|
|
|
|
if ( |
|
entity_text not in entity_map |
|
or entity_map[entity_text]["type"].count(".") < num_dots |
|
): |
|
entity_map[entity_text] = entity |
|
|
|
|
|
filtered_entities = list(entity_map.values()) |
|
|
|
|
|
filtered_entities = attach_comp_to_closest(filtered_entities) |
|
if DEBUG: |
|
print("After attach_comp_to_closest:", filtered_entities, "\n") |
|
filtered_entities = repair_names_in_entities(filtered_entities) |
|
if DEBUG: |
|
print("After repair_names_in_entities:", filtered_entities, "\n") |
|
|
|
|
|
|
|
|
|
|
|
|
|
return filtered_entities |
|
|
|
|
|
def remove_included_entities(entities):
    """Drop entities whose surface form is contained in another entity's surface when
    the other entity is a comp.* entity or shares a compatible coarse type."""
    final_entities = []
    for entity in entities:
|
is_included = False |
|
for other_entity in entities: |
|
if entity["surface"] != other_entity["surface"]: |
|
if "comp" in other_entity["type"]: |
|
|
|
if entity["surface"] in other_entity["surface"]: |
|
is_included = True |
|
break |
|
elif ( |
|
entity["type"].split(".")[0] in other_entity["type"].split(".")[0] |
|
or other_entity["type"].split(".")[0] |
|
in entity["type"].split(".")[0] |
|
): |
|
if entity["surface"] in other_entity["surface"]: |
|
is_included = True |
|
if not is_included: |
|
final_entities.append(entity) |
|
return final_entities |
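# Illustrative example (hypothetical entities): if both "Banque Nationale Suisse" (org)
# and "Banque" (org) were extracted, the shorter surface contained in the longer one of
# the same coarse type is dropped.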
|
|
|
|
|
def refine_entities_with_coarse(all_entities, coarse_entities): |
|
""" |
|
Looks through all entities and refines them based on the coarse entities. |
|
If a surface match is found in the coarse entities and the types match, |
|
the entity's confidence_ner and type are updated based on the coarse entity. |
|
""" |
|
|
|
coarse_lookup = {} |
|
for coarse_entity in coarse_entities: |
|
key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0]) |
|
coarse_lookup[key] = coarse_entity |
|
|
|
|
|
for entity in all_entities: |
|
key = ( |
|
entity["surface"], |
|
entity["type"].split(".")[0], |
|
) |
|
|
|
if key in coarse_lookup: |
|
coarse_entity = coarse_lookup[key] |
|
|
|
if entity["confidence_ner"] < coarse_entity["confidence_ner"]: |
|
entity["confidence_ner"] = coarse_entity["confidence_ner"] |
|
entity["type"] = coarse_entity[ |
|
"type" |
|
] |
|
|
|
|
|
    # Finally, collapse fine-grained types to their coarse prefix (e.g. "pers.ind" -> "pers").
    for entity in all_entities:
        entity["type"] = entity["type"].split(".")[0]
|
return all_entities |
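# Illustrative example (hypothetical values): if the fine-grained task tagged "Paris"
# as loc.adm.town with confidence 72.0 and the coarse task tagged the same surface as
# loc with confidence 91.0, the entity keeps the higher confidence and its type is
# finally collapsed to "loc".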
|
|
|
|
|
def remove_trailing_stopwords(entities): |
|
""" |
|
This function removes stopwords and punctuation from both the beginning and end of each entity's text |
|
and repairs the lOffset and rOffset accordingly. |
|
""" |
|
if DEBUG: |
|
print(f"Initial entities: {len(entities)}") |
|
new_entities = [] |
|
for entity in entities: |
|
if "comp" not in entity["type"]: |
|
entity_text = entity["surface"] |
|
original_len = len(entity_text) |
|
|
|
|
|
lOffset = entity.get("lOffset", 0) |
|
rOffset = entity.get("rOffset", original_len) |
|
|
|
|
|
            # Strip leading stopwords and punctuation, shifting lOffset accordingly.
            while entity_text and (
                (entity_text.split() and entity_text.split()[0].lower() in stop_words)
                or entity_text[0] in punctuation
            ):
                if entity_text.split() and entity_text.split()[0].lower() in stop_words:
                    stopword_len = len(entity_text.split()[0]) + 1  # word plus following space
                    entity_text = entity_text[stopword_len:]
                    lOffset += stopword_len
                    if DEBUG:
                        print(
                            f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
                        )
                elif entity_text[0] in punctuation:
                    entity_text = entity_text[1:]
                    lOffset += 1
                    if DEBUG:
                        print(
                            f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
                        )
|
|
|
iteration = 0 |
|
max_iterations = len(entity_text) |
|
|
|
while entity_text and iteration < max_iterations: |
|
|
|
last_word = entity_text.split()[-1] if entity_text.split() else "" |
|
last_char = entity_text[-1] |
|
|
|
if last_word.lower() in stop_words: |
|
|
|
stopword_len = len(last_word) + 1 |
|
entity_text = entity_text[:-stopword_len].rstrip() |
|
rOffset -= stopword_len |
|
if DEBUG: |
|
print( |
|
f"Removed trailing stopword from entity: {entity_text} (rOffset={rOffset})" |
|
) |
|
|
|
elif last_char in punctuation: |
|
|
|
entity_text = entity_text[:-1].rstrip() |
|
rOffset -= 1 |
|
if DEBUG: |
|
print( |
|
f"Removed trailing punctuation from entity: {entity_text} (rOffset={rOffset})" |
|
) |
|
else: |
|
|
|
break |
|
|
|
iteration += 1 |
|
|
|
|
|
            # Skip degenerate entities; note that we only decide whether to keep them
            # in `new_entities` and never mutate `entities` while iterating over it.
            if len(entity_text.strip()) == 1:
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                continue

            if entity_text in string.punctuation:
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                continue

            if entity_text.lower() in stop_words:
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                continue

            if all(word.lower() in stop_words for word in entity_text.split()):
                if DEBUG:
                    print(f"Skipping entity: {entity_text} (all words are stopwords)")
                continue

            if all(char.lower() in stop_words for char in entity_text if char.isalpha()):
                if DEBUG:
                    print(f"Skipping entity: {entity_text} (all characters are stopwords)")
                continue

            if all(word in string.punctuation for word in entity_text.split()):
                if DEBUG:
                    print(f"Skipping entity: {entity_text} (all tokens are punctuation)")
                continue

            # Effectively: the entity contains no alphabetic characters at all.
            if all(char.lower() in string.punctuation for char in entity_text if char.isalpha()):
                if DEBUG:
                    print(f"Skipping entity: {entity_text} (no alphabetic characters)")
                continue

            # Bare numbers are only kept for time expressions.
            if entity_text.isdigit() and "time" not in entity["type"]:
                if DEBUG:
                    print(f"Skipping entity: {entity_text}")
                continue
|
|
|
if entity_text.startswith(" "): |
|
entity_text = entity_text[1:] |
|
|
|
lOffset += 1 |
|
if entity_text.endswith(" "): |
|
entity_text = entity_text[:-1] |
|
|
|
rOffset -= 1 |
|
|
|
|
|
entity["surface"] = entity_text |
|
entity["lOffset"] = lOffset |
|
entity["rOffset"] = rOffset |
|
|
|
|
|
if len(entity["surface"].strip()) == 0: |
|
if DEBUG: |
|
print(f"Deleted entity: {entity['surface']}") |
|
entities.remove(entity) |
|
else: |
|
new_entities.append(entity) |
|
|
|
if DEBUG: |
|
print(f"Remained entities: {len(new_entities)}") |
|
return new_entities |
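# Illustrative example (hypothetical entity): {"surface": "the Alps,", "lOffset": 10,
# "rOffset": 19, ...} is trimmed to {"surface": "Alps", "lOffset": 14, "rOffset": 18, ...}
# once the leading stopword and the trailing comma are removed.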
|
|
|
class MultitaskTokenClassificationPipeline(Pipeline): |
|
|
|
def _sanitize_parameters(self, **kwargs): |
|
preprocess_kwargs = {} |
|
if "text" in kwargs: |
|
preprocess_kwargs["text"] = kwargs["text"] |
|
self.label_map = self.model.config.label_map |
|
self.id2label = { |
|
task: {id_: label for label, id_ in labels.items()} |
|
for task, labels in self.label_map.items() |
|
} |
|
return preprocess_kwargs, {}, {} |
|
|
|
def preprocess(self, text, **kwargs): |
|
|
|
tokenized_inputs = self.tokenizer( |
|
text, padding="max_length", truncation=True, max_length=512 |
|
) |
|
|
|
text_sentence = tokenize(add_spaces_around_punctuation(text)) |
|
return tokenized_inputs, text_sentence, text |
|
|
|
def _forward(self, inputs): |
|
inputs, text_sentences, text = inputs |
|
input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to( |
|
self.model.device |
|
) |
|
attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to( |
|
self.model.device |
|
) |
|
with torch.no_grad(): |
|
outputs = self.model(input_ids, attention_mask) |
|
return outputs, text_sentences, text |
|
|
|
def is_within(self, entity1, entity2): |
|
"""Check if entity1 is fully within the bounds of entity2.""" |
|
return ( |
|
entity1["lOffset"] >= entity2["lOffset"] |
|
and entity1["rOffset"] <= entity2["rOffset"] |
|
) |
|
|
|
    def postprocess(self, outputs, **kwargs):
        """Turn the model outputs into a flat list of entity dictionaries.

        Per task: argmax/softmax over the logits, realign subword predictions to
        words, extract entities, then merge and clean the entities across tasks.
        """
        tokens_result, text_sentence, text = outputs
|
|
|
predictions = {} |
|
confidence_scores = {} |
|
for task, logits in tokens_result.logits.items(): |
|
predictions[task] = torch.argmax(logits, dim=-1).tolist()[0] |
|
confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0] |
|
|
|
entities = {} |
|
for task in predictions.keys(): |
|
words_list, preds_list, confidence_list = realign( |
|
text_sentence, |
|
predictions[task], |
|
confidence_scores[task], |
|
self.tokenizer, |
|
self.id2label[task], |
|
) |
|
|
|
            entities[task] = get_entities(words_list, preds_list, confidence_list, text)
|
all_entities = [] |
|
coarse_entities = [] |
|
for key in entities: |
|
if key in ["NE-COARSE-LIT"]: |
|
coarse_entities = entities[key] |
|
all_entities.extend(entities[key]) |
|
|
|
if DEBUG: |
|
print(all_entities) |
|
|
|
all_entities = remove_included_entities(all_entities) |
|
if DEBUG: |
|
print("After remove_included_entities:", all_entities) |
|
all_entities = remove_trailing_stopwords(all_entities) |
|
if DEBUG: |
|
print("After remove_trailing_stopwords:", all_entities) |
|
all_entities = postprocess_entities(all_entities) |
|
if DEBUG: |
|
print("After postprocess_entities:", all_entities) |
|
all_entities = refine_entities_with_coarse(all_entities, coarse_entities) |
|
if DEBUG: |
|
print("After refine_entities_with_coarse:", all_entities) |
|
|
|
|
|
|
|
return all_entities |
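

# Illustrative usage sketch (the checkpoint path is a placeholder; this assumes a
# multitask token-classification model whose config exposes `label_map` and whose
# forward pass returns per-task logits, as this pipeline expects):
#
#     from transformers import AutoModelForTokenClassification, AutoTokenizer
#
#     model = AutoModelForTokenClassification.from_pretrained(
#         "path/to/multitask-ner-model", trust_remote_code=True
#     )
#     tokenizer = AutoTokenizer.from_pretrained("path/to/multitask-ner-model")
#     nlp = MultitaskTokenClassificationPipeline(model=model, tokenizer=tokenizer)
#     entities = nlp("Apple was founded by Steve Jobs in Cupertino.")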
|
|