# Zamanonymize3 / utils_demo.py
import logging
import re
import string
from flair.data import Sentence
from flair.models import SequenceTagger
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
entity_label_to_code_map = {
    '<PERSON>': 0,
    '<O>': 1,
    '<MISC>-<NRP>': 2,
    '<NUMBER>': 3,
    '<PER>-<LOCATION>': 4,
    '<LOC>': 5,
    # Miscellaneous: doesn't fall into the more common categories of
    # PERSON, LOCATION or ORGANIZATION
    '<MISC>': 6,
    '<DATE_TIME>': 7,
    '<LOCATION>': 8,
    '<PRONOUNS>': 9,
    '<IN_PAN>': 10,
    '<MISC>-<DATE_TIME>': 11,
    '<ORG>': 12,
    '<MISC>-<IN_PAN>': 13,
    '<MISC>-<LOCATION>': 14,
    '<PER>': 15,
    '<MISC>-<PERSON>': 16,
    '<LOC>-<PERSON>': 17,
    '<PHONE_NUMBER>': 18,
    '<LOC>-<DATE_TIME>': 19,
    '<LOC>-<NRP>': 20,
    '<NRP>': 21,
    '<ORG>-<PERSON>': 22,
    '<PER>-<NRP>': 23,
    '<ORG>-<LOCATION>': 24,
    '<PER>-<DATE_TIME>': 25,
    '<PER>-<IN_PAN>': 26,
    '<ORG>-<IN_PAN>': 27,
    '<ORG>-<NRP>': 28,
    '<US_DRIVER_LICENSE>': 29,
    '<KEY>': 30,  # assumed name: this entry is garbled in the source file
    '<EMAIL_ADDRESS>': 31,  # codes made unique: 31/32 were unused while 35/36 were assigned twice
    '<US_PASSPORT>': 32,
    '<US_BANK_NUMBER>': 33,
    '<IN_AADHAAR>': 34,
    '<CRYPTO>': 35,
    '<IP_ADDRESS>': 36,
    '<US_SSN>': 37,
    '<MISC>-<URL>': 38,
}
pronoun_list = [
    'I', 'i', 'me', 'my', 'mine', 'myself', 'you', 'your', 'yours', "I'm", "I am",
    'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', "i'm",
    'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them',
    'their', 'theirs', 'themselves', 'we', 'us', 'our', 'ours', 'ourselves',
    'Me', 'My', 'Mine', 'Myself', 'You', 'Your', 'Yours', 'Yourself', 'Yourselves',
    'He', 'Him', 'His', 'Himself', 'She', 'Her', 'Hers', 'Herself', 'It', 'Its', 'Itself',
    'They', 'Them', 'Their', 'Theirs', 'Themselves', 'We', 'Us', 'Our', 'Ours', 'Ourselves',
    "Lady", "Madam", "Mr.", "Mister", "Sir", "Miss", "Ms.", "Mrs.", "Mr",
]
privacy_category_codes = {'<PRIVATE>': 1, '<NON_PRIVATE>': 2, '<OTHER>': 3}
punctuation_list = list(string.punctuation)
punctuation_list.remove('%')
punctuation_list.remove('$')
punctuation_list = ''.join(punctuation_list)
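# `punctuation_list` is now a single string holding every punctuation character
# except '%' and '$' (kept because they carry meaning in amounts):
#   !"#&'()*+,-./:;<=>?@[\]^_`{|}~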
def get_word_boundaries(sentence):
""" Find the start and end positions of each word in a sentence."""
return [(match.start(), match.end()) for match in re.finditer(r'[^\s]+', sentence)]
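# Example (a small sketch; the output follows directly from the regex above):
#
#   >>> get_word_boundaries("Hello John Doe")
#   [(0, 5), (6, 10), (11, 14)]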
def fuse_ner_labels(flair_ner, presidio_ner, text_type="<PRIVATE>"):
    """Merge the NER labels from Flair and Presidio for a given text.

    Custom cases and predefined rules are taken into account when
    classifying each entity.
    """
merged_ner = []
# Sanity check
assert len(flair_ner) == len(presidio_ner)
    # n1 is the Flair label (a bare tag such as 'PER'); n2 is the Presidio
    # label (bracketed, such as '<PERSON>'), so the fused labels below match
    # the '<PER>-<LOCATION>'-style keys of `entity_label_to_code_map`.
    for i, ((w1, n1), (w2, n2)) in enumerate(zip(flair_ner, presidio_ner)):
assert w1 == w2
if w1.lower() in pronoun_list:
common_ner = "<PRONOUNS>"
# elif w1 in ['A+', 'A-', 'B+', 'B-', 'AB+', 'AB-', 'O+', 'O-']:
# common_ner = "<PRIVATE>"
elif n1 == "<O>" and n2 == "<O>":
if w1.lower() in ["am", "'m"] and (i - 1) >= 0 and presidio_ner[i - 1][0].lower() == 'i':
common_ner = "<PRONOUNS>"
elif bool(re.match(r'(?<!\S)[\$€]?(?:\d{1,3}(?:[ ,.]\d{3})*|\d+)(?:\.\d+)?%?', w1)):
common_ner = "<NUMBER>"
else:
common_ner = '<O>'
elif n1 in n2:
common_ner = n2
elif n1 == '<O>' and n2 != '<O>':
common_ner = n2
elif n2 == '<O>' and n1 != '<O>':
common_ner = f"<{n1}>"
else:
common_ner = f"<{n1}>-{n2}"
        common_binary_label = 0 if common_ner == "<O>" else 1
        if common_ner not in entity_label_to_code_map:
            # Unseen label: extend the map on the fly with the next free code.
            common_multi_label = len(entity_label_to_code_map)
            print("NOT in KEY", common_ner)
            entity_label_to_code_map[common_ner] = common_multi_label
        else:
            common_multi_label = entity_label_to_code_map[common_ner]
is_private = text_type if common_ner != '<O>' else '<OTHER>'
        merged_ner.append([
            w1, common_ner, is_private, privacy_category_codes[is_private],
            common_binary_label, common_multi_label,
        ])
return merged_ner
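# Illustrative call (a hedged sketch, not taken from the original demo): both
# arguments are whitespace-token-aligned lists of (word, label) pairs,
# typically the outputs of `apply_flair_model` and `apply_presidio_model` on
# the same preprocessed sentence:
#
#   flair_pred = [("John", "PER"), ("lives", "<O>"), ("here", "<O>")]
#   presidio_pred = [("John", "<PERSON>"), ("lives", "<O>"), ("here", "<O>")]
#   merged = fuse_ner_labels(flair_pred, presidio_pred)
#   # each row: [word, fused_label, privacy_tag, privacy_code,
#   #            binary_label, multiclass_code]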
analyzer = AnalyzerEngine()
anonymizer = AnonymizerEngine()
def apply_presidio_model(sentence, verbose=True):
    """Get Presidio predictions as a list of (word, label) pairs."""
    if verbose: print(f"{sentence=}")
    # The anonymized text is a string such as: "<PERSON> went to Pitier Hospital ..."
    anonymized_result = anonymizer.anonymize(text=sentence, analyzer_results=analyzer.analyze(text=sentence, language='en'))
    # `.text` is the public attribute of Presidio's result object; also
    # collapse any repeated whitespace.
    anonymized_text = ' '.join(anonymized_result.text.split())
    next_word_to_concat = None
if verbose: print(f"{anonymized_text=}")
if verbose: print(f"{anonymized_text.split('<')=}")
start_index, label = 0, []
previous_label = None
for i, before_split in enumerate(anonymized_text.split('<')):
if verbose:
print(f"\nSubseq_{i}: {before_split=}")
if i == 0:
assert len(before_split) == len(sentence[start_index: len(before_split)])
start_index = len(before_split)
label.extend([(s, '<O>') for s in before_split.split()])
else:
after_split = before_split.split(">")
if verbose:
print(f" -----> ", after_split)
print(sentence[start_index:])
print(sentence[start_index:].find(after_split[-1]))
start2_index = start_index + sentence[start_index:].find(after_split[-1])
end2_index = start2_index + len(after_split[-1])
if verbose:
print(f"Sanity check: '[{sentence[start2_index: end2_index]}]' VS '[{after_split[-1]}]'")
print(f"Hidden part: sentence[{start2_index}: {end2_index}] = {sentence[start2_index: end2_index]}")
assert sentence[start2_index: end2_index] == after_split[-1]
start2_index = start2_index if start2_index != start_index else len(sentence)
            for j, anonymized_word in enumerate(sentence[start_index: start2_index].split()):
                if next_word_to_concat is not None and j == 0:
                    label.append((f"{next_word_to_concat}{anonymized_word}", f"<{after_split[0]}>"))
                    next_word_to_concat = None
                else:
                    label.append((anonymized_word, f"<{after_split[0]}>"))
            previous_label = f"<{after_split[0]}>"
            if len(sentence[start2_index: end2_index]) >= 1 and after_split[-1][-1] != ' ' and i != len(anonymized_text.split('<')) - 1:
                if verbose: print("Is there a space after?", after_split, after_split[-1][-1], i, len(anonymized_text.split('<')))
                for anonymized_word in after_split[-1].split()[:-1]:
                    label.append((anonymized_word, "<O>"))
                next_word_to_concat = after_split[-1].split()[-1]
            elif len(sentence[start2_index: end2_index]) >= 1 and after_split[-1][0] != ' ' and i != len(anonymized_text.split('<')) - 1:
                if verbose: print("Is there a space before?", after_split, after_split[-1][0], i, len(anonymized_text.split('<')))
                label[-1] = (f"{label[-1][0]}{after_split[-1].split()[0]}", previous_label)
                for anonymized_word in after_split[-1].split()[1:]:
                    label.append((anonymized_word, "<O>"))
            else:
                for anonymized_word in after_split[-1].split():
                    label.append((anonymized_word, "<O>"))
start_index = end2_index
return label
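# Illustrative call (a sketch; the detected spans depend on Presidio's
# recognizers, so the labels below are plausible rather than guaranteed):
#
#   apply_presidio_model("John went to Pitier Hospital", verbose=False)
#   # -> [('John', '<PERSON>'), ('went', '<O>'), ('to', '<O>'),
#   #     ('Pitier', '<O>'), ('Hospital', '<O>')]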
def apply_flair_model(original_sentence):
    """Get Flair predictions as a list of (word, label) pairs."""
    logging.getLogger('flair').setLevel(logging.WARNING)
    # Note: the large NER model is reloaded on every call; cache the tagger
    # when calling this function repeatedly.
    tagger = SequenceTagger.load("flair/ner-english-large")
flair_sentence = Sentence(original_sentence)
tagger.predict(flair_sentence)
word_boundaries = get_word_boundaries(original_sentence)
    ner = [
        [i_token.form, b_token.get_label().value, i_token.get_label().score,
         i_token.start_position, i_token.end_position]
        for b_token in flair_sentence.get_spans("ner")
        for i_token in b_token
    ]
ner_labels, ner_index = [], 0
for start, end in word_boundaries:
word_from_text = original_sentence[start:end]
if ner_index < len(ner):
form, label, _, s, e = ner[ner_index]
if (s, e) == (start, end) and word_from_text == form:
ner_labels.append((word_from_text, label))
ner_index += 1
else:
ner_labels.append((word_from_text, "<O>"))
else:
ner_labels.append((word_from_text, "<O>"))
assert len(ner_labels) == len(word_boundaries)
return ner_labels
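# Illustrative call (a sketch; spans depend on the loaded model, so the tags
# below are plausible rather than guaranteed). Labels are Flair's bare tags
# (e.g. 'PER', 'LOC'), with '<O>' for words outside any detected span:
#
#   apply_flair_model("John lives in Paris")
#   # -> [('John', 'PER'), ('lives', '<O>'), ('in', '<O>'), ('Paris', 'LOC')]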
def preprocess_sentences(sentence, verbose=False):
"""Preprocess the sentence."""
# Removing Extra Newlines:
sentence = re.sub(r'\n+', ' ', sentence)
if verbose: print(sentence)
# Collapsing Multiple Spaces:
sentence = re.sub(' +', ' ', sentence)
if verbose: print(sentence)
# Handling Apostrophes in Possessives:
sentence = re.sub(r"'s\b", " s", sentence)
if verbose: print(sentence)
# Removing Spaces Before Punctuation:
sentence = re.sub(r'\s([,.!?;:])', r'\1', sentence)
if verbose: print(sentence)
# Pattern for Matching Leading or Trailing Punctuation:
pattern = r'(?<!\w)[{}]|[{}](?!\w)'.format(re.escape(punctuation_list), re.escape(punctuation_list))
sentence = re.sub(pattern, '', sentence)
if verbose: print(sentence)
return sentence
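# Hedged end-to-end sketch (not part of the original demo): how the helpers
# above are typically chained. It assumes both models yield the same
# whitespace tokenization, which the assert in `fuse_ner_labels` requires;
# the first run downloads the Flair model and is slow.
if __name__ == "__main__":
    raw_text = "My name is John Smith and I live in Paris.\n\n"
    clean_text = preprocess_sentences(raw_text)
    flair_pred = apply_flair_model(clean_text)
    presidio_pred = apply_presidio_model(clean_text, verbose=False)
    for word, ner, privacy, privacy_code, binary, multi in fuse_ner_labels(flair_pred, presidio_pred):
        print(f"{word:<12} {ner:<12} {privacy:<14} binary={binary} multi={multi}")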