|
import re |
|
|
|
import numpy as np |
|
|
|
import fasttext |
|
|
|
import sentencepiece |
|
import kenlm |
|
|
|
import pathlib |
|
|
|
from languages_id import langs_id |
|
from parameters_filtering import parameters_filtering |
|
from normalization import normalization |
|
from stopwords import stopwords |
|
from badwords import badwords |
|
|
|
|
|
class LoadParameters:
    """Loaders for the per-language resources consumed by the filters.

    Apart from ``load_parameters``, every loader reads one column of the
    ``langs_id`` table (presumably a pandas DataFrame -- it is indexed with
    ``.loc`` / ``.iloc``) on the first row whose ``dataset_id`` equals the
    requested language, and returns ``None`` when that cell is falsy.
    """

    @staticmethod
    def load_parameters(lang_dataset_id):
        """Return the filtering parameters registered for a language.

        Falls back to the ``"default"`` entry of ``parameters_filtering``
        when ``lang_dataset_id`` has no entry of its own.
        """
        key = lang_dataset_id if lang_dataset_id in parameters_filtering else "default"
        return parameters_filtering[key]

    @staticmethod
    def load_stopwords(lang_dataset_id):
        """Return the stop words of the language as a set, or ``None``."""
        row_mask = langs_id["dataset_id"] == lang_dataset_id
        stopwords_key = langs_id.loc[row_mask, "stopwords_id"].iloc[0]
        if not stopwords_key:
            return None
        return set(stopwords[stopwords_key])

    @staticmethod
    def load_badwords(lang_dataset_id):
        """Return the bad words of the language as a set, or ``None``."""
        row_mask = langs_id["dataset_id"] == lang_dataset_id
        badwords_key = langs_id.loc[row_mask, "badwords_id"].iloc[0]
        if not badwords_key:
            return None
        return set(badwords[badwords_key])

    @staticmethod
    def load_model_lang_id(lang_dataset_id, path_fasttext_model):
        """Load the fastText model at ``path_fasttext_model``, or return ``None``.

        The model is only loaded when the language has a truthy
        ``fasttext_id`` in ``langs_id``.
        """
        row_mask = langs_id["dataset_id"] == lang_dataset_id
        fasttext_key = langs_id.loc[row_mask, "fasttext_id"].iloc[0]
        if not fasttext_key:
            return None
        return fasttext.load_model(path_fasttext_model)

    @staticmethod
    def load_sentencepiece_model(lang_dataset_id, path_sentencepiece_model):
        """Load the SentencePiece model at the given path, or return ``None``.

        The model is only loaded when the language has a truthy
        ``sentencepiece_id`` in ``langs_id``.
        """
        row_mask = langs_id["dataset_id"] == lang_dataset_id
        sentencepiece_key = langs_id.loc[row_mask, "sentencepiece_id"].iloc[0]
        if not sentencepiece_key:
            return None
        processor = sentencepiece.SentencePieceProcessor()
        processor.load(path_sentencepiece_model)
        return processor

    @staticmethod
    def load_kenlm_model(lang_dataset_id, path_kenlm_model):
        """Load the KenLM model at ``path_kenlm_model``, or return ``None``.

        The model is only loaded when the language has a truthy ``kenlm_id``
        in ``langs_id``.
        """
        row_mask = langs_id["dataset_id"] == lang_dataset_id
        kenlm_key = langs_id.loc[row_mask, "kenlm_id"].iloc[0]
        if not kenlm_key:
            return None
        return kenlm.Model(path_kenlm_model)
|
|
|
|
|
class ModifyingDocuments:
    """Stateless text transformations applied to a document before filtering.

    Every method is a static method operating on plain strings (or on lists
    of strings produced by the splitting helpers).
    """

    # Default characters that ``uniform_whitespace`` rewrites to an ASCII space.
    # NOTE(review): the literals previously stored here were mis-encoded
    # (they looked like UTF-8 bytes rendered as ISO-8859-7 text, and one of
    # them was split across two lines), so they are now spelled as escapes.
    # U+0020, U+00A0, U+2005, U+202F, U+3000 and U+FFFC could be recovered.
    # Five further entries could not be read back; they presumably were
    # spaces from the U+2000 block, so the whole U+2000..U+200A range of
    # space separators is included. The last entry looked like a C1 control
    # character (U+0080..U+009F) that could not be identified and is NOT
    # included -- TODO restore it from version control.
    _DEFAULT_WHITESPACE = frozenset(
        [
            "\u0020",  # space
            "\u00a0",  # no-break space
            "\u202f",  # narrow no-break space
            "\u3000",  # ideographic space
            "\ufffc",  # object replacement character
        ]
        # U+2000..U+200A: en quad .. hair space (includes U+2005).
        + [chr(code_point) for code_point in range(0x2000, 0x200B)]
    )

    @staticmethod
    def remove_empty_el_from_list(list_):
        """Return a new list holding only the truthy elements of ``list_``."""
        return [el for el in list_ if el]

    @staticmethod
    def remove_non_printing_characters(document, non_printing_characters_re):
        """Delete every match of ``non_printing_characters_re`` from ``document``."""
        return non_printing_characters_re.sub("", document)

    @staticmethod
    def uniform_whitespace(document, whitespace=_DEFAULT_WHITESPACE):
        """Replace each character found in ``whitespace`` by an ASCII space.

        Args:
            document: Text to normalise.
            whitespace: Iterable of single characters to replace. Defaults
                to an immutable set of Unicode space-like characters.

        Returns:
            The document with the same length, where every listed character
            has become ``" "``.
        """
        # The default is already a frozenset; only convert caller-supplied
        # iterables so membership tests stay O(1).
        if not isinstance(whitespace, (set, frozenset)):
            whitespace = set(whitespace)
        return "".join([" " if char in whitespace else char for char in document])

    @staticmethod
    def replace_digits_with_zeros(document, digits_re):
        """Replace every match of ``digits_re`` in ``document`` by ``"0"``."""
        return digits_re.sub("0", document)

    @staticmethod
    def replace_unicode_punctuation(document, unicode_punctuation):
        """Map each character through ``unicode_punctuation``.

        Characters missing from the mapping are kept unchanged.
        """
        return "".join(unicode_punctuation.get(c, c) for c in document)

    @staticmethod
    def normalization(
        document,
        remove_non_printing_characters,
        strip,
        lower_case,
        uniform_whitespace,
        replace_digits_with_zeros,
        replace_unicode_punctuation,
        non_printing_characters_re=normalization["non_printing_characters_re"],
        digits_re=normalization["digits_re"],
        unicode_punctuation=normalization["unicode_punctuation"],
    ):
        """Apply the enabled normalisation steps, in a fixed order.

        The order is: remove non-printing characters, strip, lower-case,
        uniform whitespace, digits to zeros, Unicode punctuation
        replacement. Each boolean flag switches the step of the same name.

        Returns:
            The normalised document. An empty document (possibly after the
            first two steps) is returned immediately.
        """
        if remove_non_printing_characters:
            document = ModifyingDocuments.remove_non_printing_characters(
                document, non_printing_characters_re
            )
        if strip:
            document = document.strip()
        if not document:
            # Nothing left to normalise.
            return document
        if lower_case:
            document = document.lower()
        if uniform_whitespace:
            document = ModifyingDocuments.uniform_whitespace(document)
        if replace_digits_with_zeros:
            document = ModifyingDocuments.replace_digits_with_zeros(document, digits_re)
        if replace_unicode_punctuation:
            document = ModifyingDocuments.replace_unicode_punctuation(
                document, unicode_punctuation
            )
        return document

    @staticmethod
    def tokenization(document, sentencepiece_model, join_on_whitespace):
        """Tokenise ``document`` with ``sentencepiece_model.encode_as_pieces``.

        Returns:
            The pieces joined by single spaces when ``join_on_whitespace``
            is true, otherwise whatever ``encode_as_pieces`` returned.
        """
        document_tokenized = sentencepiece_model.encode_as_pieces(document)
        if join_on_whitespace:
            document_tokenized = " ".join(document_tokenized)
        return document_tokenized

    @staticmethod
    def split_on_whitespace(
        document,
        new_line=False,
        tab=False,
    ):
        """Split on spaces, optionally also on newlines and tabs.

        Empty pieces are dropped, so runs of separators do not yield
        empty words.
        """
        separators = [" "]
        if new_line:
            separators.append("\n")
        if tab:
            separators.append("\t")
        pieces = re.split("|".join(separators), document)
        return ModifyingDocuments.remove_empty_el_from_list(pieces)

    @staticmethod
    def strip(document, strip_characters):
        """Strip leading and trailing characters found in ``strip_characters``.

        Unlike ``str.strip``, ``strip_characters`` may be any container
        supporting ``in``; the original note says a set is used because it
        is much faster when there are many characters (all the emojis).
        """
        if not document:
            return document
        beg_ind = 0
        end_ind = len(document)
        while beg_ind < end_ind and document[beg_ind] in strip_characters:
            beg_ind += 1
        while end_ind > beg_ind and document[end_ind - 1] in strip_characters:
            end_ind -= 1
        return document[beg_ind:end_ind]

    @staticmethod
    def get_words_from_document(
        document, sentencepiece_model_tok, lower_case, strip_characters
    ):
        """Extract the list of words of a document (not reversible).

        With a truthy ``sentencepiece_model_tok`` the document is fully
        normalised and tokenised into pieces; otherwise it is split on
        spaces, newlines and tabs. Words are then optionally lower-cased
        and stripped of ``strip_characters``; words left empty by the
        stripping are dropped. Useful to compute ratios such as the
        stop-word ratio.
        """
        if sentencepiece_model_tok:
            document_normalized = ModifyingDocuments.normalization(
                document=document,
                remove_non_printing_characters=True,
                strip=True,
                lower_case=True,
                uniform_whitespace=True,
                replace_digits_with_zeros=True,
                replace_unicode_punctuation=True,
            )
            words = ModifyingDocuments.tokenization(
                document_normalized, sentencepiece_model_tok, join_on_whitespace=False
            )
        else:
            words = ModifyingDocuments.split_on_whitespace(
                document, new_line=True, tab=True
            )
        if lower_case:
            words = [word.lower() for word in words]
        if strip_characters:
            words = [ModifyingDocuments.strip(word, strip_characters) for word in words]
            words = ModifyingDocuments.remove_empty_el_from_list(words)
        return words

    @staticmethod
    def words_augmentation(words, group_size, join_char):
        """Return every run of ``group_size`` consecutive words, joined.

        Intended especially for Chinese (no space between words) and
        Vietnamese (a space between syllables).
        """
        augmentation = [
            join_char.join(words[i : i + group_size])
            for i in range(len(words) - group_size + 1)
        ]
        return augmentation

    @staticmethod
    def split_on_newline_tab_whitespace(document):
        """Split on "\\n", then on "\\t", then on " ".

        Returns:
            A three-level nested list: lines, tab-separated sub-sentences,
            then words (empty words removed).
        """
        return [
            [
                ModifyingDocuments.split_on_whitespace(subsentence)
                for subsentence in sentence.split("\t")
            ]
            for sentence in document.split("\n")
        ]

    @staticmethod
    def merge_on_whitespace_tab_newline(sentences):
        """Invert ``split_on_newline_tab_whitespace``.

        Empty sub-sentences and empty sentences are skipped, which removes
        concatenated separators. Returns ``""`` when nothing is left.
        """
        sentences = [
            [" ".join(subsentence) for subsentence in sentence if subsentence]
            for sentence in sentences
        ]
        sentences = ["\t".join(sentence) for sentence in sentences if sentence]
        if not sentences:
            return ""
        return "\n".join(sentences)

    @staticmethod
    def _filter_words(document, should_keep):
        """Rebuild ``document`` keeping only words where ``should_keep(word)``."""
        sentences = ModifyingDocuments.split_on_newline_tab_whitespace(document)
        sentences = [
            [
                [word for word in subsentence if should_keep(word)]
                for subsentence in sentence
            ]
            for sentence in sentences
        ]
        return ModifyingDocuments.merge_on_whitespace_tab_newline(sentences)

    @staticmethod
    def should_keep_word_with_incorrect_substrings(
        word, strip_characters, incorrect_word_substrings
    ):
        """Return True when the stripped word contains no incorrect substring."""
        word = ModifyingDocuments.strip(word, strip_characters)
        return all(i_substr not in word for i_substr in incorrect_word_substrings)

    @staticmethod
    def remove_words_with_incorrect_substrings(
        document,
        strip_characters,
        incorrect_word_substrings,
    ):
        """Remove the words containing any of ``incorrect_word_substrings``."""

        def should_keep(word):
            return ModifyingDocuments.should_keep_word_with_incorrect_substrings(
                word, strip_characters, incorrect_word_substrings
            )

        return ModifyingDocuments._filter_words(document, should_keep)

    @staticmethod
    def should_keep_long_word(word, strip_characters, length_word_max_cutoff):
        """Decide whether a word is short enough to be kept.

        A word is kept when its raw length is at most
        ``length_word_max_cutoff``, or when it is non-empty and within the
        cutoff after being stripped of ``strip_characters``.
        """
        if len(word) <= length_word_max_cutoff:
            return True
        word = ModifyingDocuments.strip(word, strip_characters)
        # A word made only of strip characters is rejected, even though its
        # stripped length (0) would pass the cutoff.
        return bool(word) and len(word) <= length_word_max_cutoff

    # @staticmethod was missing here while every sibling method has it;
    # without it the method only worked when accessed through the class.
    @staticmethod
    def remove_long_words(
        document,
        strip_characters,
        length_word_max_cutoff,
    ):
        """Remove the words rejected by ``should_keep_long_word``."""

        def should_keep(word):
            return ModifyingDocuments.should_keep_long_word(
                word, strip_characters, length_word_max_cutoff
            )

        return ModifyingDocuments._filter_words(document, should_keep)

    @staticmethod
    def modifying_documents(
        document,
        cond_uniform_whitespace,
        cond_replace_unicode_punctuation,
        cond_remove_words_with_incorrect_substrings,
        strip_characters,
        incorrect_word_substrings,
        cond_remove_long_words,
        length_word_max_cutoff,
    ):
        """Run the document-modification pipeline.

        The document is always stripped; whitespace unification, Unicode
        punctuation replacement, removal of words with incorrect
        substrings and removal of long words are each controlled by their
        ``cond_*`` flag.
        """
        document = ModifyingDocuments.normalization(
            document=document,
            remove_non_printing_characters=False,
            strip=True,
            lower_case=False,
            uniform_whitespace=cond_uniform_whitespace,
            replace_digits_with_zeros=False,
            replace_unicode_punctuation=cond_replace_unicode_punctuation,
        )
        if cond_remove_words_with_incorrect_substrings:
            document = ModifyingDocuments.remove_words_with_incorrect_substrings(
                document,
                strip_characters,
                incorrect_word_substrings,
            )
        if cond_remove_long_words:
            document = ModifyingDocuments.remove_long_words(
                document,
                strip_characters,
                length_word_max_cutoff,
            )
        return document
|
|
|
|
|
class FunctionDatasetModifyingDocuments:
    """Callable that rewrites the ``"text"`` field of a dataset example.

    The filtering parameters of the language are looked up once, at
    construction time.
    """

    # Entries of ``self.param`` forwarded, under the same names, to
    # ``ModifyingDocuments.modifying_documents``.
    _FORWARDED_PARAMS = (
        "cond_uniform_whitespace",
        "cond_replace_unicode_punctuation",
        "cond_remove_words_with_incorrect_substrings",
        "strip_characters",
        "incorrect_word_substrings",
        "cond_remove_long_words",
        "length_word_max_cutoff",
    )

    def __init__(self, lang_dataset_id):
        self.lang_dataset_id = lang_dataset_id
        self.param = LoadParameters.load_parameters(lang_dataset_id)

    def __call__(self, example):
        """Modify ``example["text"]`` in place and return the example."""
        text = example["text"]
        options = {name: self.param[name] for name in self._FORWARDED_PARAMS}
        example["text"] = ModifyingDocuments.modifying_documents(
            document=text, **options
        )
        return example

    def __reduce__(self):
        """Pickle as a constructor call so ``param`` is reloaded on unpickling."""
        constructor_args = (self.lang_dataset_id,)
        return (self.__class__, constructor_args)
|
|
|
|
|
class Filtering:
    """Document-level quality checks.

    Each ``compute_*`` method returns a score for a document; the matching
    ``check_*`` method compares that score with a cutoff and returns True
    when the document should be kept. ``filtering`` chains the enabled
    checks.
    """

    @staticmethod
    def check_number_words(
        document,
        sentencepiece_model_tok,
        strip_characters,
        number_words_min_cutoff,
        number_words_max_cutoff,
    ):
        """Return True when the word count lies within both cutoffs (inclusive)."""
        words = ModifyingDocuments.get_words_from_document(
            document,
            sentencepiece_model_tok,
            lower_case=False,
            strip_characters=strip_characters,
        )
        return number_words_min_cutoff <= len(words) <= number_words_max_cutoff

    @staticmethod
    def compute_repetitions_ratio(document, repetitions_length):
        """Measure how much of the document is made of repeated n-grams.

        Slices of length ``repetitions_length`` are counted; with ``k``
        distinct slices, the result is the share of all occurrences taken
        by the ``int(sqrt(k))`` most frequent ones.

        Returns:
            The ratio, or 0 when the document yields no slice.
        """
        freq_ngrams = {}
        for start in range(len(document) - repetitions_length + 1):
            ngram = document[start : start + repetitions_length]
            freq_ngrams[ngram] = freq_ngrams.get(ngram, 0) + 1
        if not freq_ngrams:
            return 0
        counts = sorted(freq_ngrams.values(), reverse=True)
        num_rep_ngrams = int(np.sqrt(len(counts)))
        return sum(counts[:num_rep_ngrams]) / sum(counts)

    @staticmethod
    def check_repetitions_removal(
        document,
        repetitions_length,
        repetitions_max_cutoff,
    ):
        """Return True when the repetitions ratio does not exceed the cutoff."""
        repetitions_ratio = Filtering.compute_repetitions_ratio(
            document, repetitions_length
        )
        return repetitions_ratio <= repetitions_max_cutoff

    @staticmethod
    def compute_special_characters_ratio(document, special_characters):
        """Return the fraction of characters found in ``special_characters``.

        An empty document has a ratio of 0 (it used to raise
        ``ZeroDivisionError``).
        """
        if not document:
            return 0
        num_special = sum(1 for char in document if char in special_characters)
        return num_special / len(document)

    @staticmethod
    def check_special_characters(
        document,
        special_characters,
        special_characters_max_cutoff,
    ):
        """Return True when the special-characters ratio is within the cutoff."""
        special_characters_ratio = Filtering.compute_special_characters_ratio(
            document, special_characters
        )
        return special_characters_ratio <= special_characters_max_cutoff

    @staticmethod
    def _compute_words_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        vocabulary,
    ):
        """Return the ratio of document words found in ``vocabulary``, capped at 1.0.

        Words are lower-cased and stripped. When ``cond_words_augmentation``
        is true, joined groups of consecutive words are also looked up, but
        the denominator stays the number of plain words -- hence the cap.
        Returns 0 for a document without words.
        """
        words = ModifyingDocuments.get_words_from_document(
            document,
            sentencepiece_model_tok,
            lower_case=True,
            strip_characters=strip_characters,
        )
        if not words:
            return 0
        candidates = list(words)
        if cond_words_augmentation:
            for group_size in words_augmentation_group_sizes:
                candidates.extend(
                    ModifyingDocuments.words_augmentation(
                        words, group_size, words_augmentation_join_char
                    )
                )
        num_matches = sum(1 for word in candidates if word in vocabulary)
        return min(num_matches / len(words), 1.0)

    @staticmethod
    def compute_stopwords_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        stopwords,
    ):
        """Return the ratio of words (and augmented groups) that are stop words."""
        return Filtering._compute_words_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            stopwords,
        )

    @staticmethod
    def check_stopwords(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        stopwords,
        stopwords_min_cutoff,
    ):
        """Return True when the stop-word ratio reaches the minimum cutoff.

        The check passes automatically when ``stopwords`` is falsy.
        """
        if not stopwords:
            return True
        stopwords_ratio = Filtering.compute_stopwords_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            stopwords,
        )
        return stopwords_ratio >= stopwords_min_cutoff

    @staticmethod
    def compute_badwords_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        badwords,
    ):
        """Return the ratio of words (and augmented groups) that are bad words.

        The leftover debugging loop that printed every matching augmented
        word to stdout has been removed.
        """
        return Filtering._compute_words_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            badwords,
        )

    @staticmethod
    def check_badwords(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        badwords,
        badwords_max_cutoff,
    ):
        """Return True when the bad-word ratio does not exceed the cutoff.

        The check passes automatically when ``badwords`` is falsy.
        """
        if not badwords:
            return True
        badwords_ratio = Filtering.compute_badwords_ratio(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            badwords,
        )
        return badwords_ratio <= badwords_max_cutoff

    @staticmethod
    def compute_lang_id_pred_score(document, model_lang_id):
        """Predict the language of ``document`` with ``model_lang_id``.

        The document is lower-cased and its newlines replaced by spaces
        before prediction.

        Returns:
            A ``(dataset_id, score)`` pair. The id is looked up in
            ``langs_id`` from the predicted label (``"__label__"`` prefix
            removed) and is ``"unknown"`` when the label has no row.
        """
        document = document.lower().replace("\n", " ")
        pred = model_lang_id.predict(document)
        lang_pred_fasttext_id = pred[0][0].replace("__label__", "")
        score_pred = pred[1][0]
        lang_pred_dataset_id = langs_id.loc[
            langs_id["fasttext_id"] == lang_pred_fasttext_id, "dataset_id"
        ]
        if len(lang_pred_dataset_id) > 0:
            lang_pred_dataset_id = lang_pred_dataset_id.iloc[0]
        else:
            lang_pred_dataset_id = "unknown"
        return lang_pred_dataset_id, score_pred

    @staticmethod
    def check_lang_id(
        document,
        lang_dataset_id,
        model_lang_id,
        lang_id_min_cutoff,
    ):
        """Return True when the predicted language matches with enough score.

        The check passes automatically when ``model_lang_id`` is falsy.
        """
        if not model_lang_id:
            return True
        lang_pred_dataset_id, score_pred = Filtering.compute_lang_id_pred_score(
            document, model_lang_id
        )
        return (lang_pred_dataset_id == lang_dataset_id) and (
            score_pred >= lang_id_min_cutoff
        )

    @staticmethod
    def compute_perplexity_score(document, sentencepiece_model, kenlm_model):
        """Return the perplexity of the document, rounded to one decimal.

        The document is fully normalised, tokenised with
        ``sentencepiece_model`` (pieces joined by spaces), then each line
        is scored with ``kenlm_model.score``. The perplexity is
        ``10 ** (-total_log_score / total_length)``.
        """
        document = ModifyingDocuments.normalization(
            document=document,
            remove_non_printing_characters=True,
            strip=True,
            lower_case=True,
            uniform_whitespace=True,
            replace_digits_with_zeros=True,
            replace_unicode_punctuation=True,
        )
        document = ModifyingDocuments.tokenization(
            document, sentencepiece_model, join_on_whitespace=True
        )
        doc_log_score, doc_length = 0, 0
        for line in document.split("\n"):
            doc_log_score += kenlm_model.score(line)
            # One extra unit per line -- presumably the end-of-sentence
            # token scored by KenLM; confirm against the KenLM docs.
            doc_length += len(line.split()) + 1
        pp_score = 10.0 ** (-doc_log_score / doc_length)
        return round(pp_score, 1)

    @staticmethod
    def check_perplexity(
        document,
        sentencepiece_model,
        kenlm_model,
        perplexity_max_cutoff,
    ):
        """Return True when the perplexity does not exceed the cutoff.

        The check passes automatically when ``kenlm_model`` is falsy.
        """
        if not kenlm_model:
            return True
        score = Filtering.compute_perplexity_score(
            document, sentencepiece_model, kenlm_model
        )
        return score <= perplexity_max_cutoff

    @staticmethod
    def filtering(
        document,
        cond_check_number_words,
        sentencepiece_model_tok,
        strip_characters,
        number_words_min_cutoff,
        number_words_max_cutoff,
        cond_check_repetitions_removal,
        repetitions_length,
        repetitions_max_cutoff,
        cond_check_special_characters,
        special_characters,
        special_characters_max_cutoff,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        cond_check_stopwords,
        stopwords,
        stopwords_min_cutoff,
        cond_check_badwords,
        badwords,
        badwords_max_cutoff,
        cond_check_lang_id,
        lang_dataset_id,
        model_lang_id,
        lang_id_min_cutoff,
        cond_check_perplexity,
        sentencepiece_model,
        kenlm_model,
        perplexity_max_cutoff,
    ):
        """Return True when the document passes every enabled check.

        Checks run in this order and stop at the first failure: number of
        words, repetitions, special characters, stop words, bad words,
        language identification, perplexity. Each one is skipped when its
        ``cond_check_*`` flag is false.
        """
        if cond_check_number_words and not Filtering.check_number_words(
            document,
            sentencepiece_model_tok,
            strip_characters,
            number_words_min_cutoff,
            number_words_max_cutoff,
        ):
            return False
        if cond_check_repetitions_removal and not Filtering.check_repetitions_removal(
            document,
            repetitions_length,
            repetitions_max_cutoff,
        ):
            return False
        if cond_check_special_characters and not Filtering.check_special_characters(
            document,
            special_characters,
            special_characters_max_cutoff,
        ):
            return False
        if cond_check_stopwords and not Filtering.check_stopwords(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            stopwords,
            stopwords_min_cutoff,
        ):
            return False
        if cond_check_badwords and not Filtering.check_badwords(
            document,
            sentencepiece_model_tok,
            strip_characters,
            cond_words_augmentation,
            words_augmentation_group_sizes,
            words_augmentation_join_char,
            badwords,
            badwords_max_cutoff,
        ):
            return False
        if cond_check_lang_id and not Filtering.check_lang_id(
            document,
            lang_dataset_id,
            model_lang_id,
            lang_id_min_cutoff,
        ):
            return False
        if cond_check_perplexity and not Filtering.check_perplexity(
            document,
            sentencepiece_model,
            kenlm_model,
            perplexity_max_cutoff,
        ):
            return False
        return True
|
|
|
|
|
class FunctionDatasetFiltering:
    """Callable deciding whether a dataset example should be kept.

    Parameters, word lists and models of the language are loaded once, at
    construction time, through ``LoadParameters``.
    """

    def __init__(
        self,
        lang_dataset_id,
        path_fasttext_model,
        path_sentencepiece_model,
        path_kenlm_model,
    ):
        # Keep the raw constructor arguments: ``__reduce__`` needs them.
        self.lang_dataset_id = lang_dataset_id
        self.path_fasttext_model = path_fasttext_model
        self.path_sentencepiece_model = path_sentencepiece_model
        self.path_kenlm_model = path_kenlm_model

        load = LoadParameters
        self.param = load.load_parameters(lang_dataset_id)
        self.stopwords = load.load_stopwords(lang_dataset_id)
        self.badwords = load.load_badwords(lang_dataset_id)
        self.model_lang_id = load.load_model_lang_id(
            lang_dataset_id, path_fasttext_model
        )
        self.sentencepiece_model = load.load_sentencepiece_model(
            lang_dataset_id, path_sentencepiece_model
        )
        # The tokenizer is only used for word extraction when the
        # "tokenization" parameter is enabled.
        if self.param["tokenization"]:
            self.sentencepiece_model_tok = self.sentencepiece_model
        else:
            self.sentencepiece_model_tok = None
        self.kenlm_model = load.load_kenlm_model(lang_dataset_id, path_kenlm_model)

    def __call__(self, example):
        """Return True when ``example["text"]`` passes ``Filtering.filtering``."""
        p = self.param
        return Filtering.filtering(
            document=example["text"],
            cond_check_number_words=p["cond_check_number_words"],
            sentencepiece_model_tok=self.sentencepiece_model_tok,
            strip_characters=p["strip_characters"],
            number_words_min_cutoff=p["number_words_min_cutoff"],
            number_words_max_cutoff=p["number_words_max_cutoff"],
            # NOTE(review): this key lacks the "cond_" prefix used by the
            # other switches -- confirm it matches parameters_filtering.
            cond_check_repetitions_removal=p["check_repetitions_removal"],
            repetitions_length=p["repetitions_length"],
            repetitions_max_cutoff=p["repetitions_max_cutoff"],
            cond_check_special_characters=p["cond_check_special_characters"],
            special_characters=p["special_characters"],
            special_characters_max_cutoff=p["special_characters_max_cutoff"],
            cond_words_augmentation=p["cond_words_augmentation"],
            words_augmentation_group_sizes=p["words_augmentation_group_sizes"],
            words_augmentation_join_char=p["words_augmentation_join_char"],
            cond_check_stopwords=p["cond_check_stopwords"],
            stopwords=self.stopwords,
            stopwords_min_cutoff=p["stopwords_min_cutoff"],
            cond_check_badwords=p["cond_check_badwords"],
            badwords=self.badwords,
            badwords_max_cutoff=p["badwords_max_cutoff"],
            cond_check_lang_id=p["cond_check_lang_id"],
            lang_dataset_id=self.lang_dataset_id,
            model_lang_id=self.model_lang_id,
            lang_id_min_cutoff=p["lang_id_min_cutoff"],
            cond_check_perplexity=p["cond_check_perplexity"],
            sentencepiece_model=self.sentencepiece_model,
            kenlm_model=self.kenlm_model,
            perplexity_max_cutoff=p["perplexity_max_cutoff"],
        )

    def __reduce__(self):
        """Pickle as a constructor call, so resources are reloaded from paths."""
        constructor_args = (
            self.lang_dataset_id,
            self.path_fasttext_model,
            self.path_sentencepiece_model,
            self.path_kenlm_model,
        )
        return (self.__class__, constructor_args)
|
|
|
|
|
class DatasetFiltering:
    """Drive the modification, filtering and saving of one language dataset.

    ``dataset`` must provide ``map``, ``filter`` and ``save_to_disk``
    (presumably a Hugging Face ``datasets.Dataset`` -- confirm with callers).
    """

    def __init__(
        self,
        dataset,
        lang_dataset_id,
        path_fasttext_model,
        path_sentencepiece_model,
        path_kenlm_model,
        num_proc,
        path_dir_save_dataset,
    ):
        self.ds = dataset
        self.lang_dataset_id = lang_dataset_id
        # Model locations, forwarded to FunctionDatasetFiltering.
        self.path_fasttext_model = path_fasttext_model
        self.path_sentencepiece_model = path_sentencepiece_model
        self.path_kenlm_model = path_kenlm_model
        # Passed as ``num_proc`` to ``map`` and ``filter``.
        self.num_proc = num_proc
        self.path_dir_save_dataset = path_dir_save_dataset

    def modifying_documents(self):
        """Replace ``self.ds`` by the dataset with modified documents."""
        modify_example = FunctionDatasetModifyingDocuments(self.lang_dataset_id)
        self.ds = self.ds.map(modify_example, num_proc=self.num_proc)

    def filtering(self):
        """Replace ``self.ds`` by the dataset restricted to kept examples."""
        keep_example = FunctionDatasetFiltering(
            self.lang_dataset_id,
            self.path_fasttext_model,
            self.path_sentencepiece_model,
            self.path_kenlm_model,
        )
        self.ds = self.ds.filter(keep_example, num_proc=self.num_proc)

    def save_dataset(self):
        """Save ``self.ds`` under ``<path_dir_save_dataset>/<lang_dataset_id>``.

        Both the root directory and the language sub-directory are created
        when missing.
        """
        root_dir = pathlib.Path(self.path_dir_save_dataset)
        root_dir.mkdir(parents=True, exist_ok=True)
        lang_dir = pathlib.PurePath(self.path_dir_save_dataset, self.lang_dataset_id)
        pathlib.Path(lang_dir).mkdir(parents=True, exist_ok=True)
        self.ds.save_to_disk(lang_dir)
|
|