import logging from transformers import Pipeline import numpy as np import torch import nltk nltk.download("averaged_perceptron_tagger") nltk.download("averaged_perceptron_tagger_eng") nltk.download("stopwords") from nltk.chunk import conlltags2tree from nltk import pos_tag from nltk.tree import Tree import torch.nn.functional as F import re, string stop_words = set(nltk.corpus.stopwords.words("english")) DEBUG = False punctuation = ( string.punctuation + "«»—…“”" + "—." + "–" + "’" + "‘" + "´" + "•" + "°" + "»" + "“" + "”" + "–" + "—" + "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉" ) # List of additional "strange" punctuation marks # additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉" WHITESPACE_RULES = { "fr": { "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"], "pct_no_ws_after": ["(", "[", "{"], "pct_no_ws_before_after": ["'", "-"], "pct_number": [".", ","], }, "de": { "pct_no_ws_before": [ ".", ",", ")", "]", "}", "°", "...", "?", "!", ":", ";", ".-", "%", ], "pct_no_ws_after": ["(", "[", "{"], "pct_no_ws_before_after": ["'", "-"], "pct_number": [".", ","], }, "other": { "pct_no_ws_before": [ ".", ",", ")", "]", "}", "°", "...", "?", "!", ":", ";", ".-", "%", ], "pct_no_ws_after": ["(", "[", "{"], "pct_no_ws_before_after": ["'", "-"], "pct_number": [".", ","], }, } def tokenize(text: str, language: str = "other") -> list[str]: """Apply whitespace rules to the given text and language, separating it into tokens. Args: text (str): The input text to separate into a list of tokens. language (str): Language of the text. Returns: list[str]: List of tokens with punctuation as separate tokens. """ # text = add_spaces_around_punctuation(text) if not text: return [] if language not in WHITESPACE_RULES: # Default behavior for languages without specific rules: # tokenize using standard whitespace splitting language = "other" wsrules = WHITESPACE_RULES[language] tokenized_text = [] current_token = "" for char in text: if char in wsrules["pct_no_ws_before_after"]: if current_token: tokenized_text.append(current_token) tokenized_text.append(char) current_token = "" elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]: if current_token: tokenized_text.append(current_token) tokenized_text.append(char) current_token = "" elif char.isspace(): if current_token: tokenized_text.append(current_token) current_token = "" else: current_token += char if current_token: tokenized_text.append(current_token) return tokenized_text def normalize_text(text): # Remove spaces and tabs for the search but keep newline characters return re.sub(r"[ \t]+", "", text) def find_entity_indices(article_text, search_text): # Normalize texts by removing spaces and tabs normalized_article = normalize_text(article_text) normalized_search = normalize_text(search_text) # Initialize a list to hold all start and end indices indices = [] # Find all occurrences of the search text in the normalized article text start_index = 0 while True: start_index = normalized_article.find(normalized_search, start_index) if start_index == -1: break # Calculate the actual start and end indices in the original article text original_chars = 0 original_start_index = 0 for i in range(start_index): while article_text[original_start_index] in (" ", "\t"): original_start_index += 1 if article_text[original_start_index] not in (" ", "\t", "\n"): original_chars += 1 original_start_index += 1 original_end_index = original_start_index search_chars = 0 while search_chars < len(normalized_search): if article_text[original_end_index] not in (" ", "\t", "\n"): search_chars += 1 original_end_index += 1 # Increment to include the last character # Append the found indices to the list if article_text[original_start_index] == " ": original_start_index += 1 indices.append((original_start_index, original_end_index)) # Move start_index to the next position to continue searching start_index += 1 return indices def get_entities(tokens, tags, confidences, text): tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags] pos_tags = [pos for token, pos in pos_tag(tokens)] for i in range(1, len(tags)): # If a 'B-' tag is followed by another 'B-' without an 'O' in between, change the second to 'I-' if tags[i].startswith("B-") and tags[i - 1].startswith("I-"): tags[i] = "I-" + tags[i][2:] # Change 'B-' to 'I-' for the same entity type conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)] ne_tree = conlltags2tree(conlltags) entities = [] idx: int = 0 already_done = [] for subtree in ne_tree: # skipping 'O' tags if isinstance(subtree, Tree): original_label = subtree.label() original_string = " ".join([token for token, pos in subtree.leaves()]) for indices in find_entity_indices(text, original_string): entity_start_position = indices[0] entity_end_position = indices[1] if ( "_".join( [original_label, original_string, str(entity_start_position)] ) in already_done ): continue else: already_done.append( "_".join( [ original_label, original_string, str(entity_start_position), ] ) ) if len(text[entity_start_position:entity_end_position].strip()) < len( text[entity_start_position:entity_end_position] ): entity_start_position = ( entity_start_position + len(text[entity_start_position:entity_end_position]) - len(text[entity_start_position:entity_end_position].strip()) ) entities.append( { "type": original_label, "confidence_ner": round( np.average(confidences[idx : idx + len(subtree)]) * 100, 2 ), "index": (idx, idx + len(subtree)), "surface": text[ entity_start_position:entity_end_position ], # original_string, "lOffset": entity_start_position, "rOffset": entity_end_position, } ) idx += len(subtree) # Update the current character position # We add the length of the original string + 1 (for the space) else: token, pos = subtree # If it's not a named entity, we still need to update the character # position idx += 1 return entities def realign( text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map ): preds_list, words_list, confidence_list = [], [], [] word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids() for idx, word in enumerate(text_sentence): beginning_index = word_ids.index(idx) try: preds_list.append(reverted_label_map[out_label_preds[beginning_index]]) confidence_list.append(max(softmax_scores[beginning_index])) except Exception as ex: # the sentence was longer then max_length preds_list.append("O") confidence_list.append(0.0) words_list.append(word) return words_list, preds_list, confidence_list def add_spaces_around_punctuation(text): # Add a space before and after all punctuation all_punctuation = string.punctuation + punctuation return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text) def attach_comp_to_closest(entities): # Define valid entity types that can receive a "comp.function" or "comp.name" attachment valid_entity_types = {"org", "pers", "org.ent", "pers.ind"} # Separate "comp.function" and "comp.name" entities from other entities comp_entities = [ent for ent in entities if ent["type"].startswith("comp")] other_entities = [ent for ent in entities if not ent["type"].startswith("comp")] for comp_entity in comp_entities: closest_entity = None min_distance = float("inf") # Find the closest non-"comp" entity that is valid for attaching for other_entity in other_entities: # Calculate distance between the comp entity and the other entity if comp_entity["lOffset"] > other_entity["rOffset"]: distance = comp_entity["lOffset"] - other_entity["rOffset"] elif comp_entity["rOffset"] < other_entity["lOffset"]: distance = other_entity["lOffset"] - comp_entity["rOffset"] else: distance = 0 # They overlap or touch # Ensure the entity type is valid and check for minimal distance if ( distance < min_distance and other_entity["type"].split(".")[0] in valid_entity_types ): min_distance = distance closest_entity = other_entity # Attach the "comp.function" or "comp.name" if a valid entity is found if closest_entity: suffix = comp_entity["type"].split(".")[ -1 ] # Extract the suffix (e.g., 'name', 'function') closest_entity[suffix] = comp_entity["surface"] # Attach the text return other_entities def conflicting_context(comp_entity, target_entity): """ Determines if there is a conflict between the comp_entity and the target entity. Prevents incorrect name and function attachments by using a rule-based approach. """ # Case 1: Check for correct function attachment to person or organization entities if comp_entity["type"].startswith("comp.function"): if not ("pers" in target_entity["type"] or "org" in target_entity["type"]): return True # Conflict: Function should only attach to persons or organizations # Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations) if "loc" in target_entity["type"]: return True # Conflict: comp.* entities should not attach to locations or similar types return False # No conflict def extract_name_from_text(text, partial_name): """ Extracts the full name from the entity's text based on the partial name. This function assumes that the full name starts with capitalized letters and does not include any words that come after the partial name. """ # Split the text and partial name into words words = tokenize(text) partial_words = partial_name.split() if DEBUG: print("text:", text) if DEBUG: print("partial_name:", partial_name) # Find the position of the partial name in the word list for i, word in enumerate(words): if DEBUG: print(words, "---", words[i : i + len(partial_words)]) if words[i : i + len(partial_words)] == partial_words: # Initialize full name with the partial name full_name = partial_words[:] if DEBUG: print("full_name:", full_name) # Check previous words and only add capitalized words (skip lowercase words) j = i - 1 while j >= 0 and words[j][0].isupper(): full_name.insert(0, words[j]) j -= 1 if DEBUG: print("full_name:", full_name) # Return only the full name up to the partial name (ignore words after the name) return " ".join(full_name).strip() # Join the words to form the full name # If not found, return the original text (as a fallback) return text.strip() def repair_names_in_entities(entities): """ This function repairs the names in the entities by extracting the full name from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached. """ for entity in entities: if "name" in entity and "pers" in entity["type"]: name = entity["name"] text = entity["surface"] # Check if the attached name is part of the entity's text if name in text: # Extract the full name from the text by splitting around the attached name full_name = extract_name_from_text(entity["surface"], name) entity["name"] = ( full_name # Replace the partial name with the full name ) # if "name" not in entity: # entity["name"] = entity["surface"] return entities def clean_coarse_entities(entities): """ This function removes entities that are not useful for the NEL process. """ # Define a set of entity types that are considered useful for NEL useful_types = { "pers", # Person "loc", # Location "org", # Organization "date", # Product "time", # Time } # Filter out entities that are not in the useful_types set unless they are comp.* entities cleaned_entities = [ entity for entity in entities if entity["type"] in useful_types or "comp" in entity["type"] ] return cleaned_entities def postprocess_entities(entities): # Step 1: Filter entities with the same text, keeping the one with the most dots in the 'entity' field entity_map = {} # Loop over the entities and prioritize the one with the most dots for entity in entities: entity_text = entity["surface"] num_dots = entity["type"].count(".") # If the entity text is new, or this entity has more dots, update the map if ( entity_text not in entity_map or entity_map[entity_text]["type"].count(".") < num_dots ): entity_map[entity_text] = entity # Collect the filtered entities from the map filtered_entities = list(entity_map.values()) # Step 2: Attach "comp.function" entities to the closest other entities filtered_entities = attach_comp_to_closest(filtered_entities) if DEBUG: print("After attach_comp_to_closest:", filtered_entities, "\n") filtered_entities = repair_names_in_entities(filtered_entities) if DEBUG: print("After repair_names_in_entities:", filtered_entities, "\n") # Step 3: Remove entities that are not useful for NEL # filtered_entities = clean_coarse_entities(filtered_entities) # filtered_entities = remove_blacklisted_entities(filtered_entities) return filtered_entities def remove_included_entities(entities): # Loop through entities and remove those whose text is included in another with the same label final_entities = [] for i, entity in enumerate(entities): is_included = False for other_entity in entities: if entity["surface"] != other_entity["surface"]: if "comp" in other_entity["type"]: # Check if entity's text is a substring of another entity's text if entity["surface"] in other_entity["surface"]: is_included = True break elif ( entity["type"].split(".")[0] in other_entity["type"].split(".")[0] or other_entity["type"].split(".")[0] in entity["type"].split(".")[0] ): if entity["surface"] in other_entity["surface"]: is_included = True if not is_included: final_entities.append(entity) return final_entities def refine_entities_with_coarse(all_entities, coarse_entities): """ Looks through all entities and refines them based on the coarse entities. If a surface match is found in the coarse entities and the types match, the entity's confidence_ner and type are updated based on the coarse entity. """ # Create a dictionary for coarse entities based on surface and type for quick lookup coarse_lookup = {} for coarse_entity in coarse_entities: key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0]) coarse_lookup[key] = coarse_entity # Iterate through all entities and compare with the coarse entities for entity in all_entities: key = ( entity["surface"], entity["type"].split(".")[0], ) # Use the coarse type for comparison if key in coarse_lookup: coarse_entity = coarse_lookup[key] # If a match is found, update the confidence_ner and type in the entity if entity["confidence_ner"] < coarse_entity["confidence_ner"]: entity["confidence_ner"] = coarse_entity["confidence_ner"] entity["type"] = coarse_entity[ "type" ] # Update the type if the confidence is higher # No need to append to refined_entities, we're modifying in place for entity in all_entities: entity["type"] = entity["type"].split(".")[0] return all_entities def remove_trailing_stopwords(entities): """ This function removes stopwords and punctuation from both the beginning and end of each entity's text and repairs the lOffset and rOffset accordingly. """ if DEBUG: print(f"Initial entities: {len(entities)}") new_entities = [] for entity in entities: if "comp" not in entity["type"]: entity_text = entity["surface"] original_len = len(entity_text) # Initial offsets lOffset = entity.get("lOffset", 0) rOffset = entity.get("rOffset", original_len) # Remove stopwords and punctuation from the beginning i = 0 while entity_text and ( entity_text.split()[0].lower() in stop_words or entity_text[0] in punctuation ): if entity_text.split()[0].lower() in stop_words: stopword_len = ( len(entity_text.split()[0]) + 1 ) # Adjust length for stopword and following space entity_text = entity_text[stopword_len:] # Remove leading stopword lOffset += stopword_len # Adjust the left offset if DEBUG: print( f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']}" ) elif entity_text[0] in punctuation: entity_text = entity_text[1:] # Remove leading punctuation lOffset += 1 # Adjust the left offset if DEBUG: print( f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']}" ) i += 1 i = 0 # Remove stopwords and punctuation from the end iteration = 0 max_iterations = len(entity_text) # Prevent infinite loops while entity_text and iteration < max_iterations: # Check if the last word is a stopword or the last character is punctuation last_word = entity_text.split()[-1] if entity_text.split() else "" last_char = entity_text[-1] if last_word.lower() in stop_words: # Remove trailing stopword and adjust rOffset stopword_len = len(last_word) + 1 # Include space before stopword entity_text = entity_text[:-stopword_len].rstrip() rOffset -= stopword_len if DEBUG: print( f"Removed trailing stopword from entity: {entity_text} (rOffset={rOffset})" ) elif last_char in punctuation: # Remove trailing punctuation and adjust rOffset entity_text = entity_text[:-1].rstrip() rOffset -= 1 if DEBUG: print( f"Removed trailing punctuation from entity: {entity_text} (rOffset={rOffset})" ) else: # Exit loop if neither stopwords nor punctuation are found break iteration += 1 # print(f"ITERATION: {iteration} [{entity['surface']}] for {entity_text}") if len(entity_text.strip()) == 1: entities.remove(entity) if DEBUG: print(f"Skipping entity: {entity_text}") continue # Skip certain entities based on rules if entity_text in string.punctuation: if DEBUG: print(f"Skipping entity: {entity_text}") entities.remove(entity) continue # check now if its in stopwords if entity_text.lower() in stop_words: if DEBUG: print(f"Skipping entity: {entity_text}") entities.remove(entity) continue # check now if the entire entity is a list of stopwords: if all([word.lower() in stop_words for word in entity_text.split()]): if DEBUG: print(f"Skipping entity: {entity_text}") entities.remove(entity) continue # Check if the entire entity is made up of stopwords characters if all( [char.lower() in stop_words for char in entity_text if char.isalpha()] ): if DEBUG: print( f"Skipping entity: {entity_text} (all characters are stopwords)" ) entities.remove(entity) continue # check now if all entity is in a list of punctuation if all([word in string.punctuation for word in entity_text.split()]): if DEBUG: print( f"Skipping entity: {entity_text} (all characters are punctuation)" ) entities.remove(entity) continue if all( [ char.lower() in string.punctuation for char in entity_text if char.isalpha() ] ): if DEBUG: print( f"Skipping entity: {entity_text} (all characters are punctuation)" ) entities.remove(entity) continue # if it's a number and "time" no in it, then continue if entity_text.isdigit() and "time" not in entity["type"]: if DEBUG: print(f"Skipping entity: {entity_text}") entities.remove(entity) continue if entity_text.startswith(" "): entity_text = entity_text[1:] # update lOffset, rOffset lOffset += 1 if entity_text.endswith(" "): entity_text = entity_text[:-1] # update lOffset, rOffset rOffset -= 1 # Update the entity surface and offsets entity["surface"] = entity_text entity["lOffset"] = lOffset entity["rOffset"] = rOffset # Remove the entity if the surface is empty after cleaning if len(entity["surface"].strip()) == 0: if DEBUG: print(f"Deleted entity: {entity['surface']}") entities.remove(entity) else: new_entities.append(entity) if DEBUG: print(f"Remained entities: {len(new_entities)}") return new_entities class MultitaskTokenClassificationPipeline(Pipeline): def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} if "text" in kwargs: preprocess_kwargs["text"] = kwargs["text"] self.label_map = self.model.config.label_map self.id2label = { task: {id_: label for label, id_ in labels.items()} for task, labels in self.label_map.items() } return preprocess_kwargs, {}, {} def preprocess(self, text, **kwargs): tokenized_inputs = self.tokenizer( text, padding="max_length", truncation=True, max_length=512 ) text_sentence = tokenize(add_spaces_around_punctuation(text)) return tokenized_inputs, text_sentence, text def _forward(self, inputs): inputs, text_sentences, text = inputs input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to( self.model.device ) attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to( self.model.device ) with torch.no_grad(): outputs = self.model(input_ids, attention_mask) return outputs, text_sentences, text def is_within(self, entity1, entity2): """Check if entity1 is fully within the bounds of entity2.""" return ( entity1["lOffset"] >= entity2["lOffset"] and entity1["rOffset"] <= entity2["rOffset"] ) def postprocess(self, outputs, **kwargs): """ Postprocess the outputs of the model :param outputs: :param kwargs: :return: """ tokens_result, text_sentence, text = outputs predictions = {} confidence_scores = {} for task, logits in tokens_result.logits.items(): predictions[task] = torch.argmax(logits, dim=-1).tolist()[0] confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0] entities = {} for task in predictions.keys(): words_list, preds_list, confidence_list = realign( text_sentence, predictions[task], confidence_scores[task], self.tokenizer, self.id2label[task], ) entities[task] = get_entities(words_list, preds_list, confidence_list, text) # add titles to comp entities # from pprint import pprint # print("Before:") # pprint(entities) all_entities = [] coarse_entities = [] for key in entities: if key in ["NE-COARSE-LIT"]: coarse_entities = entities[key] all_entities.extend(entities[key]) if DEBUG: print(all_entities) # print("After remove_included_entities:") all_entities = remove_included_entities(all_entities) if DEBUG: print("After remove_included_entities:", all_entities) all_entities = remove_trailing_stopwords(all_entities) if DEBUG: print("After remove_trailing_stopwords:", all_entities) all_entities = postprocess_entities(all_entities) if DEBUG: print("After postprocess_entities:", all_entities) all_entities = refine_entities_with_coarse(all_entities, coarse_entities) if DEBUG: print("After refine_entities_with_coarse:", all_entities) # print("After attach_comp_to_closest:") # pprint(all_entities) # print("\n") return all_entities