import logging
import re
import string

import nltk
import numpy as np
import torch
import torch.nn.functional as F
from nltk import pos_tag
from nltk.chunk import conlltags2tree
from nltk.tree import Tree
from transformers import Pipeline

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")
nltk.download("stopwords")

stop_words = set(nltk.corpus.stopwords.words("english"))
DEBUG = False
# Punctuation characters used as token boundaries and for trimming entities:
# ASCII punctuation plus common typographic marks.
punctuation = string.punctuation + "«»—…“”–’‘´•°„―‣◦§¶†‡‰′″〈〉"
WHITESPACE_RULES = {
"fr": {
"pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
"pct_no_ws_after": ["(", "[", "{"],
"pct_no_ws_before_after": ["'", "-"],
"pct_number": [".", ","],
},
"de": {
"pct_no_ws_before": [
".",
",",
")",
"]",
"}",
"°",
"...",
"?",
"!",
":",
";",
".-",
"%",
],
"pct_no_ws_after": ["(", "[", "{"],
"pct_no_ws_before_after": ["'", "-"],
"pct_number": [".", ","],
},
"other": {
"pct_no_ws_before": [
".",
",",
")",
"]",
"}",
"°",
"...",
"?",
"!",
":",
";",
".-",
"%",
],
"pct_no_ws_after": ["(", "[", "{"],
"pct_no_ws_before_after": ["'", "-"],
"pct_number": [".", ","],
},
}
def tokenize(text: str, language: str = "other") -> list[str]:
"""Apply whitespace rules to the given text and language, separating it into tokens.
Args:
text (str): The input text to separate into a list of tokens.
language (str): Language of the text.
Returns:
list[str]: List of tokens with punctuation as separate tokens.
"""
# text = add_spaces_around_punctuation(text)
if not text:
return []
    if language not in WHITESPACE_RULES:
        # Languages without specific rules fall back to the generic "other" rules
        language = "other"
wsrules = WHITESPACE_RULES[language]
tokenized_text = []
current_token = ""
for char in text:
if char in wsrules["pct_no_ws_before_after"]:
if current_token:
tokenized_text.append(current_token)
tokenized_text.append(char)
current_token = ""
elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
if current_token:
tokenized_text.append(current_token)
tokenized_text.append(char)
current_token = ""
elif char.isspace():
if current_token:
tokenized_text.append(current_token)
current_token = ""
else:
current_token += char
if current_token:
tokenized_text.append(current_token)
return tokenized_text
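# Illustrative example of tokenize() with the default "other" rules (assuming
# the rule lists above); punctuation from the pct_no_ws_* lists becomes its own
# token:
#     tokenize("Hello, world!")  ->  ["Hello", ",", "world", "!"]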
def normalize_text(text):
# Remove spaces and tabs for the search but keep newline characters
return re.sub(r"[ \t]+", "", text)
def find_entity_indices(article_text, search_text):
    """Find all (start, end) character offsets of `search_text` in `article_text`.

    Matching ignores spaces and tabs, but the returned offsets refer to the
    original, unnormalized article text.
    """
    # Normalize texts by removing spaces and tabs
normalized_article = normalize_text(article_text)
normalized_search = normalize_text(search_text)
# Initialize a list to hold all start and end indices
indices = []
# Find all occurrences of the search text in the normalized article text
start_index = 0
while True:
start_index = normalized_article.find(normalized_search, start_index)
if start_index == -1:
break
# Calculate the actual start and end indices in the original article text
original_chars = 0
original_start_index = 0
for i in range(start_index):
while article_text[original_start_index] in (" ", "\t"):
original_start_index += 1
if article_text[original_start_index] not in (" ", "\t", "\n"):
original_chars += 1
original_start_index += 1
original_end_index = original_start_index
search_chars = 0
while search_chars < len(normalized_search):
if article_text[original_end_index] not in (" ", "\t", "\n"):
search_chars += 1
original_end_index += 1 # Increment to include the last character
# Append the found indices to the list
if article_text[original_start_index] == " ":
original_start_index += 1
indices.append((original_start_index, original_end_index))
# Move start_index to the next position to continue searching
start_index += 1
return indices
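# Illustrative example of find_entity_indices(): matching ignores spaces and
# tabs, but the returned offsets index into the original text:
#     find_entity_indices("Hello world", "world")  ->  [(6, 11)]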
def get_entities(tokens, tags, confidences, text):
tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
pos_tags = [pos for token, pos in pos_tag(tokens)]
    for i in range(1, len(tags)):
        # A 'B-' tag that directly follows an 'I-' tag is treated as a
        # continuation of the running span, so it is converted to 'I-'
        if tags[i].startswith("B-") and tags[i - 1].startswith("I-"):
            tags[i] = "I-" + tags[i][2:]
conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
ne_tree = conlltags2tree(conlltags)
entities = []
idx: int = 0
already_done = []
for subtree in ne_tree:
# skipping 'O' tags
if isinstance(subtree, Tree):
original_label = subtree.label()
original_string = " ".join([token for token, pos in subtree.leaves()])
for indices in find_entity_indices(text, original_string):
entity_start_position = indices[0]
entity_end_position = indices[1]
if (
"_".join(
[original_label, original_string, str(entity_start_position)]
)
in already_done
):
continue
else:
already_done.append(
"_".join(
[
original_label,
original_string,
str(entity_start_position),
]
)
)
if len(text[entity_start_position:entity_end_position].strip()) < len(
text[entity_start_position:entity_end_position]
):
entity_start_position = (
entity_start_position
+ len(text[entity_start_position:entity_end_position])
- len(text[entity_start_position:entity_end_position].strip())
)
entities.append(
{
"type": original_label,
"confidence_ner": round(
np.average(confidences[idx : idx + len(subtree)]) * 100, 2
),
"index": (idx, idx + len(subtree)),
"surface": text[
entity_start_position:entity_end_position
], # original_string,
"lOffset": entity_start_position,
"rOffset": entity_end_position,
}
)
            # Advance the token index past this entity's tokens
            idx += len(subtree)
        else:
            # Not a named entity: advance the token index by a single token
            token, pos = subtree
            idx += 1
return entities
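# For reference, each dict produced by get_entities() has this shape (the values
# below are purely illustrative):
#     {
#         "type": "pers",               # label of the entity subtree
#         "confidence_ner": 97.53,      # average softmax confidence, in percent
#         "index": (3, 5),              # (start, end) token indices
#         "surface": "Albert Einstein", # span as it appears in the original text
#         "lOffset": 14,                # start character offset
#         "rOffset": 29,                # end character offset
#     }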
def realign(
text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
preds_list, words_list, confidence_list = [], [], []
word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
for idx, word in enumerate(text_sentence):
beginning_index = word_ids.index(idx)
        try:
            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
            confidence_list.append(max(softmax_scores[beginning_index]))
        except Exception:  # the sentence was longer than max_length
            preds_list.append("O")
            confidence_list.append(0.0)
words_list.append(word)
return words_list, preds_list, confidence_list
def add_spaces_around_punctuation(text):
# Add a space before and after all punctuation
all_punctuation = string.punctuation + punctuation
return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text)
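# Illustrative example of add_spaces_around_punctuation(); every punctuation
# character gains surrounding spaces so tokenize() can split it off:
#     add_spaces_around_punctuation("Hello,world")  ->  "Hello , world"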
def attach_comp_to_closest(entities):
# Define valid entity types that can receive a "comp.function" or "comp.name" attachment
valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}
# Separate "comp.function" and "comp.name" entities from other entities
comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]
for comp_entity in comp_entities:
closest_entity = None
min_distance = float("inf")
# Find the closest non-"comp" entity that is valid for attaching
for other_entity in other_entities:
# Calculate distance between the comp entity and the other entity
if comp_entity["lOffset"] > other_entity["rOffset"]:
distance = comp_entity["lOffset"] - other_entity["rOffset"]
elif comp_entity["rOffset"] < other_entity["lOffset"]:
distance = other_entity["lOffset"] - comp_entity["rOffset"]
else:
distance = 0 # They overlap or touch
# Ensure the entity type is valid and check for minimal distance
if (
distance < min_distance
and other_entity["type"].split(".")[0] in valid_entity_types
):
min_distance = distance
closest_entity = other_entity
# Attach the "comp.function" or "comp.name" if a valid entity is found
if closest_entity:
suffix = comp_entity["type"].split(".")[
-1
] # Extract the suffix (e.g., 'name', 'function')
closest_entity[suffix] = comp_entity["surface"] # Attach the text
return other_entities
def conflicting_context(comp_entity, target_entity):
"""
Determines if there is a conflict between the comp_entity and the target entity.
Prevents incorrect name and function attachments by using a rule-based approach.
"""
# Case 1: Check for correct function attachment to person or organization entities
if comp_entity["type"].startswith("comp.function"):
if not ("pers" in target_entity["type"] or "org" in target_entity["type"]):
return True # Conflict: Function should only attach to persons or organizations
# Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations)
if "loc" in target_entity["type"]:
return True # Conflict: comp.* entities should not attach to locations or similar types
return False # No conflict
def extract_name_from_text(text, partial_name):
"""
Extracts the full name from the entity's text based on the partial name.
This function assumes that the full name starts with capitalized letters and does not
include any words that come after the partial name.
"""
# Split the text and partial name into words
words = tokenize(text)
partial_words = partial_name.split()
if DEBUG:
print("text:", text)
if DEBUG:
print("partial_name:", partial_name)
# Find the position of the partial name in the word list
for i, word in enumerate(words):
if DEBUG:
print(words, "---", words[i : i + len(partial_words)])
if words[i : i + len(partial_words)] == partial_words:
# Initialize full name with the partial name
full_name = partial_words[:]
if DEBUG:
print("full_name:", full_name)
# Check previous words and only add capitalized words (skip lowercase words)
j = i - 1
while j >= 0 and words[j][0].isupper():
full_name.insert(0, words[j])
j -= 1
if DEBUG:
print("full_name:", full_name)
# Return only the full name up to the partial name (ignore words after the name)
return " ".join(full_name).strip() # Join the words to form the full name
# If not found, return the original text (as a fallback)
return text.strip()
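# Illustrative example of extract_name_from_text(): capitalized words preceding
# the partial name are prepended, and a lowercase word stops the expansion:
#     extract_name_from_text("George Washington spoke", "Washington")
#     ->  "George Washington"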
def repair_names_in_entities(entities):
"""
This function repairs the names in the entities by extracting the full name
from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached.
"""
for entity in entities:
if "name" in entity and "pers" in entity["type"]:
name = entity["name"]
text = entity["surface"]
            # If the attached partial name appears in the entity's surface text,
            # expand it to the full capitalized name found there
            if name in text:
full_name = extract_name_from_text(entity["surface"], name)
entity["name"] = (
full_name # Replace the partial name with the full name
)
# if "name" not in entity:
# entity["name"] = entity["surface"]
return entities
def clean_coarse_entities(entities):
"""
This function removes entities that are not useful for the NEL process.
"""
# Define a set of entity types that are considered useful for NEL
    useful_types = {
        "pers",  # Person
        "loc",  # Location
        "org",  # Organization
        "date",  # Date
        "time",  # Time
    }
# Filter out entities that are not in the useful_types set unless they are comp.* entities
cleaned_entities = [
entity
for entity in entities
if entity["type"] in useful_types or "comp" in entity["type"]
]
return cleaned_entities
def postprocess_entities(entities):
    # Step 1: Deduplicate entities sharing the same surface text, keeping the one
    # with the most fine-grained 'type' (the most dots)
entity_map = {}
# Loop over the entities and prioritize the one with the most dots
for entity in entities:
entity_text = entity["surface"]
num_dots = entity["type"].count(".")
# If the entity text is new, or this entity has more dots, update the map
if (
entity_text not in entity_map
or entity_map[entity_text]["type"].count(".") < num_dots
):
entity_map[entity_text] = entity
# Collect the filtered entities from the map
filtered_entities = list(entity_map.values())
# Step 2: Attach "comp.function" entities to the closest other entities
filtered_entities = attach_comp_to_closest(filtered_entities)
if DEBUG:
print("After attach_comp_to_closest:", filtered_entities, "\n")
filtered_entities = repair_names_in_entities(filtered_entities)
if DEBUG:
print("After repair_names_in_entities:", filtered_entities, "\n")
# Step 3: Remove entities that are not useful for NEL
# filtered_entities = clean_coarse_entities(filtered_entities)
# filtered_entities = remove_blacklisted_entities(filtered_entities)
return filtered_entities
def remove_included_entities(entities):
# Loop through entities and remove those whose text is included in another with the same label
final_entities = []
for i, entity in enumerate(entities):
is_included = False
for other_entity in entities:
if entity["surface"] != other_entity["surface"]:
if "comp" in other_entity["type"]:
# Check if entity's text is a substring of another entity's text
if entity["surface"] in other_entity["surface"]:
is_included = True
break
elif (
entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
or other_entity["type"].split(".")[0]
in entity["type"].split(".")[0]
):
if entity["surface"] in other_entity["surface"]:
is_included = True
if not is_included:
final_entities.append(entity)
return final_entities
def refine_entities_with_coarse(all_entities, coarse_entities):
"""
Looks through all entities and refines them based on the coarse entities.
If a surface match is found in the coarse entities and the types match,
the entity's confidence_ner and type are updated based on the coarse entity.
"""
# Create a dictionary for coarse entities based on surface and type for quick lookup
coarse_lookup = {}
for coarse_entity in coarse_entities:
key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0])
coarse_lookup[key] = coarse_entity
# Iterate through all entities and compare with the coarse entities
for entity in all_entities:
key = (
entity["surface"],
entity["type"].split(".")[0],
) # Use the coarse type for comparison
if key in coarse_lookup:
coarse_entity = coarse_lookup[key]
# If a match is found, update the confidence_ner and type in the entity
if entity["confidence_ner"] < coarse_entity["confidence_ner"]:
entity["confidence_ner"] = coarse_entity["confidence_ner"]
entity["type"] = coarse_entity[
"type"
] # Update the type if the confidence is higher
    # Entities are updated in place; finally collapse every type to its coarse prefix
    for entity in all_entities:
        entity["type"] = entity["type"].split(".")[0]
return all_entities
def remove_trailing_stopwords(entities):
"""
This function removes stopwords and punctuation from both the beginning and end of each entity's text
and repairs the lOffset and rOffset accordingly.
"""
if DEBUG:
print(f"Initial entities: {len(entities)}")
new_entities = []
for entity in entities:
if "comp" not in entity["type"]:
entity_text = entity["surface"]
original_len = len(entity_text)
# Initial offsets
lOffset = entity.get("lOffset", 0)
rOffset = entity.get("rOffset", original_len)
            # Remove stopwords and punctuation from the beginning
            while entity_text:
                first_word = entity_text.split()[0] if entity_text.split() else ""
                if first_word.lower() in stop_words:
                    # Drop the leading stopword and the space that follows it
                    stopword_len = len(first_word) + 1
                    entity_text = entity_text[stopword_len:]
                    lOffset += stopword_len  # Adjust the left offset
                    if DEBUG:
                        print(
                            f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
                        )
                elif entity_text[0] in punctuation:
                    entity_text = entity_text[1:]  # Remove leading punctuation
                    lOffset += 1  # Adjust the left offset
                    if DEBUG:
                        print(
                            f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
                        )
                else:
                    # Neither a leading stopword nor leading punctuation: stop trimming
                    break
# Remove stopwords and punctuation from the end
iteration = 0
max_iterations = len(entity_text) # Prevent infinite loops
while entity_text and iteration < max_iterations:
# Check if the last word is a stopword or the last character is punctuation
last_word = entity_text.split()[-1] if entity_text.split() else ""
last_char = entity_text[-1]
if last_word.lower() in stop_words:
# Remove trailing stopword and adjust rOffset
stopword_len = len(last_word) + 1 # Include space before stopword
entity_text = entity_text[:-stopword_len].rstrip()
rOffset -= stopword_len
if DEBUG:
print(
f"Removed trailing stopword from entity: {entity_text} (rOffset={rOffset})"
)
elif last_char in punctuation:
# Remove trailing punctuation and adjust rOffset
entity_text = entity_text[:-1].rstrip()
rOffset -= 1
if DEBUG:
print(
f"Removed trailing punctuation from entity: {entity_text} (rOffset={rOffset})"
)
else:
# Exit loop if neither stopwords nor punctuation are found
break
iteration += 1
# print(f"ITERATION: {iteration} [{entity['surface']}] for {entity_text}")
if len(entity_text.strip()) == 1:
entities.remove(entity)
if DEBUG:
print(f"Skipping entity: {entity_text}")
continue
# Skip certain entities based on rules
if entity_text in string.punctuation:
if DEBUG:
print(f"Skipping entity: {entity_text}")
entities.remove(entity)
continue
# check now if its in stopwords
if entity_text.lower() in stop_words:
if DEBUG:
print(f"Skipping entity: {entity_text}")
entities.remove(entity)
continue
# check now if the entire entity is a list of stopwords:
if all([word.lower() in stop_words for word in entity_text.split()]):
if DEBUG:
print(f"Skipping entity: {entity_text}")
entities.remove(entity)
continue
# Check if the entire entity is made up of stopwords characters
if all(
[char.lower() in stop_words for char in entity_text if char.isalpha()]
):
if DEBUG:
print(
f"Skipping entity: {entity_text} (all characters are stopwords)"
)
entities.remove(entity)
continue
# check now if all entity is in a list of punctuation
if all([word in string.punctuation for word in entity_text.split()]):
if DEBUG:
print(
f"Skipping entity: {entity_text} (all characters are punctuation)"
)
entities.remove(entity)
continue
if all(
[
char.lower() in string.punctuation
for char in entity_text
if char.isalpha()
]
):
if DEBUG:
print(
f"Skipping entity: {entity_text} (all characters are punctuation)"
)
entities.remove(entity)
continue
# if it's a number and "time" no in it, then continue
if entity_text.isdigit() and "time" not in entity["type"]:
if DEBUG:
print(f"Skipping entity: {entity_text}")
entities.remove(entity)
continue
if entity_text.startswith(" "):
entity_text = entity_text[1:]
# update lOffset, rOffset
lOffset += 1
if entity_text.endswith(" "):
entity_text = entity_text[:-1]
# update lOffset, rOffset
rOffset -= 1
# Update the entity surface and offsets
entity["surface"] = entity_text
entity["lOffset"] = lOffset
entity["rOffset"] = rOffset
# Remove the entity if the surface is empty after cleaning
if len(entity["surface"].strip()) == 0:
if DEBUG:
print(f"Deleted entity: {entity['surface']}")
entities.remove(entity)
else:
new_entities.append(entity)
if DEBUG:
print(f"Remained entities: {len(new_entities)}")
return new_entities
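# Illustrative example of remove_trailing_stopwords(): leading stopwords and
# trailing punctuation are trimmed and the offsets shifted accordingly, e.g.
# a surface of "the Banque de France," becomes "Banque de France".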
class MultitaskTokenClassificationPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "text" in kwargs:
preprocess_kwargs["text"] = kwargs["text"]
self.label_map = self.model.config.label_map
self.id2label = {
task: {id_: label for label, id_ in labels.items()}
for task, labels in self.label_map.items()
}
return preprocess_kwargs, {}, {}
def preprocess(self, text, **kwargs):
tokenized_inputs = self.tokenizer(
text, padding="max_length", truncation=True, max_length=512
)
text_sentence = tokenize(add_spaces_around_punctuation(text))
return tokenized_inputs, text_sentence, text
def _forward(self, inputs):
inputs, text_sentences, text = inputs
input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
self.model.device
)
attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
self.model.device
)
with torch.no_grad():
outputs = self.model(input_ids, attention_mask)
return outputs, text_sentences, text
def is_within(self, entity1, entity2):
"""Check if entity1 is fully within the bounds of entity2."""
return (
entity1["lOffset"] >= entity2["lOffset"]
and entity1["rOffset"] <= entity2["rOffset"]
)
    def postprocess(self, outputs, **kwargs):
        """
        Postprocess the model outputs into a list of entity dictionaries.

        :param outputs: tuple of (model outputs, tokenized sentence, original text)
        :param kwargs: unused
        :return: list of entity dicts with type, confidence, offsets and surface
        """
tokens_result, text_sentence, text = outputs
predictions = {}
confidence_scores = {}
for task, logits in tokens_result.logits.items():
predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
entities = {}
for task in predictions.keys():
words_list, preds_list, confidence_list = realign(
text_sentence,
predictions[task],
confidence_scores[task],
self.tokenizer,
self.id2label[task],
)
entities[task] = get_entities(words_list, preds_list, confidence_list, text)
# add titles to comp entities
# from pprint import pprint
# print("Before:")
# pprint(entities)
all_entities = []
coarse_entities = []
for key in entities:
if key in ["NE-COARSE-LIT"]:
coarse_entities = entities[key]
all_entities.extend(entities[key])
if DEBUG:
print(all_entities)
# print("After remove_included_entities:")
all_entities = remove_included_entities(all_entities)
if DEBUG:
print("After remove_included_entities:", all_entities)
all_entities = remove_trailing_stopwords(all_entities)
if DEBUG:
print("After remove_trailing_stopwords:", all_entities)
all_entities = postprocess_entities(all_entities)
if DEBUG:
print("After postprocess_entities:", all_entities)
all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
if DEBUG:
print("After refine_entities_with_coarse:", all_entities)
# print("After attach_comp_to_closest:")
# pprint(all_entities)
# print("\n")
return all_entities
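# Minimal usage sketch (illustrative only). It assumes `model` is the multitask
# token-classification model this pipeline was written for and `tokenizer` is
# its matching tokenizer; how they are loaded depends on the actual checkpoint:
#
#     nlp = MultitaskTokenClassificationPipeline(model=model, tokenizer=tokenizer)
#     entities = nlp("Albert Einstein est né à Ulm en 1879.")
#     # -> list of dicts with keys: type, confidence_ner, index, surface,
#     #    lOffset, rOffset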