Oopstom's picture
Upload 313 files
c668e80 verified
raw
history blame
9.15 kB
from onmt.transforms import register_transform
from .transform import Transform
from onmt.utils.logging import logger
import fasttext
import os
import urllib.request
import regex as re
@register_transform(name="clean")
class CleanTransform(Transform):
    """Drop examples that fail a set of per-corpus cleaning rules.

    For every example in a batch the source side (and the target side when
    present) is checked against the rules below; an example failing any rule
    is removed from the batch:

    * empty text,
    * the same non-digit character four times in a row (``same_char``),
    * a repeated word/segment pattern (``same_word``),
    * average token length outside ``[avg_tok_min, avg_tok_max]``,
    * characters outside the accepted Unicode scripts (``scripts_ok``) or
      inside the rejected ones (``scripts_nok``),
    * detected language not in ``langid`` (fastText language identification),
    * source identical to target (``src_eq_tgt``),
    * source/target token-count ratio beyond ``src_tgt_ratio``.

    The rule settings are looked up per corpus id in dictionaries built by
    :meth:`warm_up`.
    """

    # Location of the compressed fastText language-identification model.
    _LID_MODEL_URL = (
        "https://dl.fbaipublicfiles.com/" + "fasttext/supervised-models/lid.176.ftz"
    )
    # Four identical consecutive characters, digits excluded.
    _SAME_CHAR_PATTERN = r"([^0-9])\1{3}"
    # A space-delimited segment immediately repeated two more times.
    _SAME_WORD_PATTERN = r"(\ .*|.*\ )\1{2}"

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Register the options of the clean transform on `parser`."""
        group = parser.add_argument_group("Transform/Clean")
        group.add(
            "--src_eq_tgt",
            "-src_eq_tgt",
            action="store_true",
            help="Remove ex src==tgt",
        )
        group.add(
            "--same_char",
            "-same_char",
            action="store_true",
            help="Remove ex with same char more than 4 times",
        )
        group.add(
            "--same_word",
            "-same_word",
            action="store_true",
            help="Remove ex with same word more than 3 times",
        )
        group.add(
            "--scripts_ok",
            "-scripts_ok",
            default=["Latin", "Common"],
            nargs="*",
            type=str,
            help="list of unicodata scripts accepted",
        )
        group.add(
            "--scripts_nok",
            "-scripts_nok",
            default=[],
            nargs="*",
            type=str,
            help="list of unicodata scripts not accepted",
        )
        group.add(
            "--src_tgt_ratio",
            "-src_tgt_ratio",
            type=float,
            default=2,
            help="ratio between src and tgt",
        )
        group.add(
            "--avg_tok_min",
            "-avg_tok_min",
            type=float,
            default=3,
            help="average length of tokens min",
        )
        group.add(
            "--avg_tok_max",
            "-avg_tok_max",
            type=float,
            default=20,
            help="average length of tokens max",
        )
        group.add(
            "--langid",
            "-langid",
            default=[],
            nargs="*",
            type=str,
            help="list of languages accepted",
        )

    def _parse_opts(self):
        """Copy the global clean options from ``self.opts`` onto the instance.

        Raises:
            AssertionError: if both ``scripts_ok`` and ``scripts_nok`` are
                non-empty (only one of the two lists may be used).
        """
        self.src_eq_tgt = self.opts.src_eq_tgt
        self.same_char = self.opts.same_char
        self.same_word = self.opts.same_word
        self.scripts_ok = self.opts.scripts_ok
        self.scripts_nok = self.opts.scripts_nok
        self.src_tgt_ratio = self.opts.src_tgt_ratio
        self.avg_tok_min = self.opts.avg_tok_min
        self.avg_tok_max = self.opts.avg_tok_max
        self.langid = self.opts.langid
        assert (
            self.scripts_ok == [] or self.scripts_nok == []
        ), "Choose either scripts to be included or excluded"

    @staticmethod
    def _get_opt(corpus, opt, def_val):
        """Return `opt` for `corpus`, or None when it does not use "clean".

        Args:
            corpus: mapping describing one corpus; must hold a "transforms"
                entry.
            opt: name of the option to read from `corpus`.
            def_val: value returned when `corpus` has no `opt` entry.
        """
        if "clean" not in corpus["transforms"]:
            return None
        return corpus.get(opt, def_val)

    @classmethod
    def get_opt_dict(cls, opts, opt, def_val):
        """Collect the value of `opt` for every corpus that uses "clean".

        Returns:
            dict mapping corpus name to the option value. Corpora for which
            the value resolves to None are left out; the dict is empty when
            `opts` has no ``data`` attribute.
        """
        clean_dict = {}
        if hasattr(opts, "data"):
            for c_name, corpus in opts.data.items():
                clean = cls._get_opt(corpus, opt, def_val)
                if clean is not None:
                    logger.debug(f"Get {opt} for {c_name}: {clean}")
                    clean_dict[c_name] = clean
        return clean_dict

    @classmethod
    def _load_langid_model(cls):
        """Load the fastText language-id model, downloading it if missing.

        The model is stored next to this source file. The download goes to a
        process-specific temporary name and is moved into place only once
        complete, so an interrupted download never leaves a truncated model
        at the final path.
        """
        fasttext_loc = f"{os.path.dirname(os.path.abspath(__file__))}/lid.176.ftz"
        if not os.path.exists(fasttext_loc):
            tmp_loc = f"{fasttext_loc}.{os.getpid()}.tmp"
            try:
                urllib.request.urlretrieve(cls._LID_MODEL_URL, tmp_loc)
                os.replace(tmp_loc, fasttext_loc)
            finally:
                # Only present here when the download or the move failed.
                if os.path.exists(tmp_loc):
                    os.remove(tmp_loc)
        return fasttext.load_model(fasttext_loc)

    def warm_up(self, vocabs=None):
        """Build the per-corpus rule dictionaries and prepare language id.

        NOTE: the fallback values passed to ``get_opt_dict`` are literals; the
        global option values stored by ``_parse_opts`` are not consulted here.
        """
        super().warm_up(None)
        self.src_eq_tgt_dict = self.get_opt_dict(self.opts, "src_eq_tgt", True)
        self.same_char_dict = self.get_opt_dict(self.opts, "same_char", True)
        self.same_word_dict = self.get_opt_dict(self.opts, "same_word", True)
        self.scripts_ok_dict = self.get_opt_dict(
            self.opts, "scripts_ok", ["Latin", "Common"]
        )
        self.scripts_nok_dict = self.get_opt_dict(self.opts, "scripts_nok", [])
        self.src_tgt_ratio_dict = self.get_opt_dict(self.opts, "src_tgt_ratio", 2)
        self.avg_tok_min_dict = self.get_opt_dict(self.opts, "avg_tok_min", 3)
        self.avg_tok_max_dict = self.get_opt_dict(self.opts, "avg_tok_max", 20)
        self.langid_dict = self.get_opt_dict(self.opts, "langid", [])
        # The model (and the network access it may need) is only required
        # when at least one corpus actually filters on language.
        self.id_func = None
        if any(self.langid_dict.values()):
            self.id_func = self._load_langid_model()

    def _identify_language(self, text):
        """Return the top fastText language label for `text`, prefix removed."""
        if getattr(self, "id_func", None) is None:
            self.id_func = self._load_langid_model()
        # fastText's predict rejects input containing a newline.
        res = self.id_func.predict(text.replace("\n", " "), k=1)
        return res[0][0].replace("__label__", "")

    def _build_script_regexes(self, cid):
        """Compile the accepted/rejected script patterns for corpus `cid`.

        Returns:
            ``(ok_regex, nok_regex)``; an entry is None when the matching
            script list is empty. ``ok_regex`` matches any character outside
            the accepted scripts, ``nok_regex`` any character inside the
            rejected ones.
        """
        ok_regex = None
        nok_regex = None
        ok_scripts = self.scripts_ok_dict[cid]
        nok_scripts = self.scripts_nok_dict[cid]
        if ok_scripts:
            ok_regex = re.compile(
                "[^" + "".join(r"\p{%s}" % sc for sc in ok_scripts) + "]"
            )
        if nok_scripts:
            nok_regex = re.compile(
                "[" + "".join(r"\p{%s}" % sc for sc in nok_scripts) + "]"
            )
        return ok_regex, nok_regex

    def _side_is_clean(self, text, tokens, cid, ok_regex, nok_regex):
        """Return True when one side of an example passes every rule.

        Args:
            text: the tokens joined with single spaces.
            tokens: the token list `text` was built from.
            cid: corpus id used to look up the rule settings.
            ok_regex: compiled pattern from ``_build_script_regexes`` or None.
            nok_regex: compiled pattern from ``_build_script_regexes`` or None.
        """
        if not text:
            return False
        if self.same_char_dict[cid] and re.search(self._SAME_CHAR_PATTERN, text):
            return False
        if self.same_word_dict[cid] and re.search(self._SAME_WORD_PATTERN, text):
            return False
        # `text` is non-empty here, so `tokens` holds at least one item.
        avg_tok_len = len(text) / len(tokens)
        if avg_tok_len < self.avg_tok_min_dict[cid]:
            return False
        if avg_tok_len > self.avg_tok_max_dict[cid]:
            return False
        if ok_regex is not None and ok_regex.search(text):
            return False
        if nok_regex is not None and nok_regex.search(text):
            return False
        accepted_langs = self.langid_dict[cid]
        if accepted_langs and self._identify_language(text) not in accepted_langs:
            return False
        return True

    def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
        """Return the examples of `batch` that pass all cleaning rules.

        Args:
            batch: iterable of ``(example, transform, corpus_id)`` triples,
                where ``example["src"]`` is a token list and
                ``example["tgt"]`` is a token list or None.
            is_train: unused here.
            stats: unused here.

        Returns:
            list of ``(example, self, corpus_id)`` triples for the examples
            that were kept.
        """
        script_regexes = {}
        trf_batch = []
        for ex, _, cid in batch:
            # Script patterns depend only on the corpus: build once per cid.
            if cid not in script_regexes:
                script_regexes[cid] = self._build_script_regexes(cid)
            ok_regex, nok_regex = script_regexes[cid]

            src_str = " ".join(ex["src"])
            if not self._side_is_clean(src_str, ex["src"], cid, ok_regex, nok_regex):
                continue

            if ex["tgt"] is not None:
                tgt_str = " ".join(ex["tgt"])
                if self.src_eq_tgt_dict[cid] and src_str == tgt_str:
                    continue
                # +1 on both counts keeps the ratio defined for empty sides.
                ratio = (len(ex["src"]) + 1) / (len(ex["tgt"]) + 1)
                max_ratio = self.src_tgt_ratio_dict[cid]
                if ratio > max_ratio or ratio < (1 / max_ratio):
                    continue
                if not self._side_is_clean(
                    tgt_str, ex["tgt"], cid, ok_regex, nok_regex
                ):
                    continue

            trf_batch.append((ex, self, cid))
        return trf_batch