|
from onmt.transforms import register_transform |
|
from .transform import Transform |
|
from onmt.utils.logging import logger |
|
import fasttext |
|
import os |
|
import urllib.request |
|
import regex as re |
|
|
|
|
|
@register_transform(name="clean")
class CleanTransform(Transform):
    """Drop examples that fail a set of per-corpus cleaning rules.

    Each example's tokens are joined with single spaces and the resulting
    string is checked against the filters below. An example is removed from
    the batch as soon as one filter rejects its source side or, when a
    target is present, its target side:

    * empty text;
    * the same non-digit character four times in a row (``same_char``);
    * a space-delimited chunk repeated three times in a row (``same_word``);
    * average characters-per-token outside ``[avg_tok_min, avg_tok_max]``;
    * a character outside ``scripts_ok`` or inside ``scripts_nok``;
    * a detected language not listed in ``langid``;
    * source identical to target (``src_eq_tgt``);
    * a source/target token-count ratio beyond ``src_tgt_ratio``.

    Filter settings are read per corpus from ``opts.data`` in ``warm_up``.
    """

    # Location of the fastText language-identification model.
    _LANGID_MODEL_URL = (
        "https://dl.fbaipublicfiles.com/"
        + "fasttext/supervised-models/lid.176.ftz"
    )
    # Compiled once: the patterns never change between examples.
    _SAME_CHAR_RE = re.compile(r"([^0-9])\1{3}")
    _SAME_WORD_RE = re.compile(r"(\ .*|.*\ )\1{2}")

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Register the command-line options of the clean transform."""

        group = parser.add_argument_group("Transform/Clean")
        group.add(
            "--src_eq_tgt",
            "-src_eq_tgt",
            action="store_true",
            help="Remove ex src==tgt",
        )
        group.add(
            "--same_char",
            "-same_char",
            action="store_true",
            help="Remove ex with same char more than 4 times",
        )
        group.add(
            "--same_word",
            "-same_word",
            action="store_true",
            help="Remove ex with same word more than 3 times",
        )
        group.add(
            "--scripts_ok",
            "-scripts_ok",
            default=["Latin", "Common"],
            nargs="*",
            type=str,
            help="list of unicodata scripts accepted",
        )
        group.add(
            "--scripts_nok",
            "-scripts_nok",
            default=[],
            nargs="*",
            type=str,
            help="list of unicodata scripts not accepted",
        )

        group.add(
            "--src_tgt_ratio",
            "-src_tgt_ratio",
            type=float,
            default=2,
            help="ratio between src and tgt",
        )
        group.add(
            "--avg_tok_min",
            "-avg_tok_min",
            type=float,
            default=3,
            help="average length of tokens min",
        )
        group.add(
            "--avg_tok_max",
            "-avg_tok_max",
            type=float,
            default=20,
            help="average length of tokens max",
        )
        group.add(
            "--langid",
            "-langid",
            default=[],
            nargs="*",
            type=str,
            help="list of languages accepted",
        )

    def _parse_opts(self):
        """Copy the clean options onto the instance and validate them.

        Raises:
            AssertionError: if both ``scripts_ok`` and ``scripts_nok`` are
                non-empty. ``scripts_ok`` defaults to a non-empty list, so
                it has to be emptied explicitly to use ``scripts_nok``.
        """
        self.src_eq_tgt = self.opts.src_eq_tgt
        self.same_char = self.opts.same_char
        self.same_word = self.opts.same_word
        self.scripts_ok = self.opts.scripts_ok
        self.scripts_nok = self.opts.scripts_nok
        self.src_tgt_ratio = self.opts.src_tgt_ratio
        self.avg_tok_min = self.opts.avg_tok_min
        self.avg_tok_max = self.opts.avg_tok_max
        self.langid = self.opts.langid
        assert (
            self.scripts_ok == [] or self.scripts_nok == []
        ), "Choose either scripts to be included or excluded"

    @staticmethod
    def _get_opt(corpus, opt, def_val):
        """Return the value of `opt` for `corpus`.

        Returns ``corpus[opt]`` (or `def_val` when the key is absent) if the
        corpus lists the ``clean`` transform, and ``None`` otherwise.
        """
        if "clean" in corpus["transforms"]:
            return corpus.get(opt, def_val)
        return None

    @classmethod
    def get_opt_dict(cls, opts, opt, def_val):
        """Map each corpus name in ``opts.data`` to its value of `opt`.

        Corpora for which ``_get_opt`` yields ``None`` are left out. An
        empty dict is returned when `opts` has no ``data`` attribute.
        """
        clean_dict = {}

        if hasattr(opts, "data"):
            for c_name, corpus in opts.data.items():
                clean = cls._get_opt(corpus, opt, def_val)
                if clean is not None:
                    logger.debug(f"Get {opt} for {c_name}: {clean}")
                    clean_dict[c_name] = clean

        return clean_dict

    @staticmethod
    def _langid_model_path():
        """Return the path where the language-id model is stored."""
        return os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "lid.176.ftz"
        )

    def _load_langid_model(self):
        """Download the language-id model if missing, then load it."""
        model_path = self._langid_model_path()
        if not os.path.exists(model_path):
            # Download next to the final location and rename afterwards, so
            # an interrupted transfer never leaves a truncated file that a
            # later run would mistake for a complete model.
            tmp_path = f"{model_path}.{os.getpid()}.tmp"
            try:
                urllib.request.urlretrieve(self._LANGID_MODEL_URL, tmp_path)
                os.replace(tmp_path, model_path)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        self.id_func = fasttext.load_model(model_path)

    def warm_up(self, vocabs=None):
        """Collect the per-corpus filter settings and load the langid model.

        The language-identification model is only fetched and loaded when
        at least one corpus configures a non-empty ``langid`` list.
        """
        super().warm_up(None)
        opts = self.opts
        # NOTE(review): the fallbacks below are hard-coded rather than taken
        # from the attributes set in _parse_opts - confirm this is intended.
        self.src_eq_tgt_dict = self.get_opt_dict(opts, "src_eq_tgt", True)
        self.same_char_dict = self.get_opt_dict(opts, "same_char", True)
        self.same_word_dict = self.get_opt_dict(opts, "same_word", True)
        self.scripts_ok_dict = self.get_opt_dict(
            opts, "scripts_ok", ["Latin", "Common"]
        )
        self.scripts_nok_dict = self.get_opt_dict(opts, "scripts_nok", [])
        self.src_tgt_ratio_dict = self.get_opt_dict(opts, "src_tgt_ratio", 2)
        self.avg_tok_min_dict = self.get_opt_dict(opts, "avg_tok_min", 3)
        self.avg_tok_max_dict = self.get_opt_dict(opts, "avg_tok_max", 20)
        self.langid_dict = self.get_opt_dict(opts, "langid", [])

        self.id_func = None
        if any(langs != [] for langs in self.langid_dict.values()):
            self._load_langid_model()

    @staticmethod
    def _script_pattern(scripts, negate):
        """Compile a character class over the Unicode `scripts`.

        With ``negate=True`` the pattern matches any character *outside*
        the listed scripts, otherwise any character inside them. Returns
        ``None`` when `scripts` is empty, i.e. the filter is disabled.
        """
        if not scripts:
            return None
        opening = "[^" if negate else "["
        return re.compile(
            opening + "".join(r"\p{%s}" % sc for sc in scripts) + "]"
        )

    def _detect_lang(self, text):
        """Return the top language label predicted for `text`."""
        if getattr(self, "id_func", None) is None:
            # Fallback for settings changed after warm_up.
            self._load_langid_model()
        res = self.id_func.predict(text, k=1)
        return res[0][0].replace("__label__", "")

    def _side_is_clean(self, tokens, text, cid, ok_pattern, nok_pattern):
        """Tell whether one side of an example passes the text filters.

        Args:
            tokens: token list of the side being checked.
            text: `tokens` joined with single spaces.
            cid: corpus id used to look up the filter settings.
            ok_pattern: compiled pattern matching characters outside the
                accepted scripts, or ``None`` if that filter is off.
            nok_pattern: compiled pattern matching characters of rejected
                scripts, or ``None`` if that filter is off.

        Returns:
            ``False`` as soon as one filter rejects the text, else ``True``.
        """
        if len(text) == 0:
            return False
        if self.same_char_dict[cid] and self._SAME_CHAR_RE.search(text):
            return False
        if self.same_word_dict[cid] and self._SAME_WORD_RE.search(text):
            return False

        avg_tok_len = len(text) / len(tokens)
        if avg_tok_len < self.avg_tok_min_dict[cid]:
            return False
        if avg_tok_len > self.avg_tok_max_dict[cid]:
            return False

        if ok_pattern is not None and ok_pattern.search(text):
            return False
        if nok_pattern is not None and nok_pattern.search(text):
            return False

        # Language identification is the costliest check, so it runs last.
        accepted = self.langid_dict[cid]
        if accepted != [] and self._detect_lang(text) not in accepted:
            return False
        return True

    def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
        """Return the examples of `batch` that pass every cleaning rule.

        Args:
            batch: iterable of ``(example, transform, corpus_id)`` triples;
                each example is indexed with ``"src"`` and ``"tgt"``, the
                latter possibly ``None``.
            is_train, stats, **kwargs: accepted but not used here.

        Returns:
            A list of ``(example, self, corpus_id)`` triples for the
            examples that were kept, in their original order.
        """
        trf_batch = []
        # Script patterns depend only on the corpus id: build them once per
        # id instead of once per example.
        patterns = {}

        for ex, _, cid in batch:
            if cid not in patterns:
                patterns[cid] = (
                    self._script_pattern(self.scripts_ok_dict[cid], True),
                    self._script_pattern(self.scripts_nok_dict[cid], False),
                )
            ok_pattern, nok_pattern = patterns[cid]

            src_tokens = ex["src"]
            src_str = " ".join(src_tokens)
            if not self._side_is_clean(
                src_tokens, src_str, cid, ok_pattern, nok_pattern
            ):
                continue

            tgt_tokens = ex["tgt"]
            if tgt_tokens is not None:
                tgt_str = " ".join(tgt_tokens)
                if self.src_eq_tgt_dict[cid] and src_str == tgt_str:
                    continue

                # Token-count ratio, smoothed by one on both sides.
                len_ratio = (len(src_tokens) + 1) / (len(tgt_tokens) + 1)
                max_ratio = self.src_tgt_ratio_dict[cid]
                if len_ratio > max_ratio or len_ratio < (1 / max_ratio):
                    continue

                if not self._side_is_clean(
                    tgt_tokens, tgt_str, cid, ok_pattern, nok_pattern
                ):
                    continue

            trf_batch.append((ex, self, cid))

        return trf_batch
|
|