File size: 9,149 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 |
from onmt.transforms import register_transform
from .transform import Transform
from onmt.utils.logging import logger
import fasttext
import os
import urllib.request
import regex as re
@register_transform(name="clean")
class CleanTransform(Transform):
    """
    Filter out noisy training examples according to configurable rules.

    Rules (each tunable per corpus through the entries of ``opts.data``):
      * empty source/target, or source identical to target
      * same non-digit character repeated 4+ times / same word repeated
      * average token length outside [avg_tok_min, avg_tok_max]
      * src/tgt token-count ratio beyond src_tgt_ratio (either direction)
      * characters outside the accepted Unicode scripts (``scripts_ok``) or
        inside rejected ones (``scripts_nok``) -- mutually exclusive lists
      * fasttext-detected language not in the accepted ``langid`` list

    ``batch_apply`` drops any example that violates an enabled rule.
    """
    def __init__(self, opts):
        super().__init__(opts)
    @classmethod
    def add_options(cls, parser):
        """Register the command-line options controlling each cleaning rule."""
        group = parser.add_argument_group("Transform/Clean")
        group.add(
            "--src_eq_tgt",
            "-src_eq_tgt",
            action="store_true",
            help="Remove ex src==tgt",
        )
        group.add(
            "--same_char",
            "-same_char",
            action="store_true",
            help="Remove ex with same char more than 4 times",
        )
        group.add(
            "--same_word",
            "-same_word",
            action="store_true",
            help="Remove ex with same word more than 3 times",
        )
        group.add(
            "--scripts_ok",
            "-scripts_ok",
            default=["Latin", "Common"],
            nargs="*",
            type=str,
            help="list of unicodata scripts accepted",
        )
        group.add(
            "--scripts_nok",
            "-scripts_nok",
            default=[],
            nargs="*",
            type=str,
            help="list of unicodata scripts not accepted",
        )
        group.add(
            "--src_tgt_ratio",
            "-src_tgt_ratio",
            type=float,
            default=2,
            help="ratio between src and tgt",
        )
        group.add(
            "--avg_tok_min",
            "-avg_tok_min",
            type=float,
            default=3,
            help="average length of tokens min",
        )
        group.add(
            "--avg_tok_max",
            "-avg_tok_max",
            type=float,
            default=20,
            help="average length of tokens max",
        )
        group.add(
            "--langid",
            "-langid",
            default=[],
            nargs="*",
            type=str,
            help="list of languages accepted",
        )
    def _parse_opts(self):
        # Copy the global option values onto the instance. Per-corpus
        # overrides are resolved separately in warm_up() via get_opt_dict().
        self.src_eq_tgt = self.opts.src_eq_tgt
        self.same_char = self.opts.same_char
        self.same_word = self.opts.same_word
        self.scripts_ok = self.opts.scripts_ok
        self.scripts_nok = self.opts.scripts_nok
        self.src_tgt_ratio = self.opts.src_tgt_ratio
        self.avg_tok_min = self.opts.avg_tok_min
        self.avg_tok_max = self.opts.avg_tok_max
        self.langid = self.opts.langid
        # The accept-list and reject-list of scripts are mutually exclusive.
        assert (
            self.scripts_ok == [] or self.scripts_nok == []
        ), "Choose either scripts to be included or excluded"
    @staticmethod
    def _get_opt(corpus, opt, def_val):
        """Return `corpus`'s value for `opt` (falling back to `def_val`),
        or None if this corpus does not run the "clean" transform."""
        if "clean" in corpus["transforms"]:
            value = corpus.get(opt, def_val)
            clean = value
        else:
            clean = None
        return clean
    @classmethod
    def get_opt_dict(cls, opts, opt, def_val):
        """Map each corpus name in `opts.data` to its value of `opt`
        (default `def_val`); corpora without the clean transform are
        skipped, so they never appear as keys."""
        clean_dict = {}
        # normalize dict src/tgt for each dataset
        if hasattr(opts, "data"):
            for c_name, corpus in opts.data.items():
                clean = cls._get_opt(corpus, opt, def_val)
                if clean is not None:
                    logger.debug(f"Get {opt} for {c_name}: {clean}")
                    clean_dict[c_name] = clean
        return clean_dict
    def warm_up(self, vocabs=None):
        # NOTE(review): `vocabs` is deliberately not forwarded to the base
        # class -- this transform does not use them.
        super().warm_up(None)
        # Per-corpus settings. Note the boolean rules default to True here
        # (enabled for every corpus using "clean"), whereas the global CLI
        # flags in add_options() are store_true and default to False.
        self.src_eq_tgt_dict = self.get_opt_dict(self.opts, "src_eq_tgt", True)
        self.same_char_dict = self.get_opt_dict(self.opts, "same_char", True)
        self.same_word_dict = self.get_opt_dict(self.opts, "same_word", True)
        self.scripts_ok_dict = self.get_opt_dict(
            self.opts, "scripts_ok", ["Latin", "Common"]
        )
        self.scripts_nok_dict = self.get_opt_dict(self.opts, "scripts_nok", [])
        self.src_tgt_ratio_dict = self.get_opt_dict(self.opts, "src_tgt_ratio", 2)
        self.avg_tok_min_dict = self.get_opt_dict(self.opts, "avg_tok_min", 3)
        self.avg_tok_max_dict = self.get_opt_dict(self.opts, "avg_tok_max", 20)
        self.langid_dict = self.get_opt_dict(self.opts, "langid", [])
        # Cache the fasttext language-id model next to this source file;
        # download it on first use (requires network access that one time).
        fasttext_loc = f"{os.path.dirname(os.path.abspath(__file__))}/lid.176.ftz"
        if not os.path.exists(fasttext_loc):
            urllib.request.urlretrieve(
                "https://dl.fbaipublicfiles.com/"
                + "fasttext/supervised-models/lid.176.ftz",
                fasttext_loc,
            )
        self.id_func = fasttext.load_model(fasttext_loc)
    def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
        """Return `batch` minus every example that violates an enabled rule.

        `batch` is an iterable of ``(ex, transform, cid)`` tuples where `ex`
        maps "src"/"tgt" to token lists and `cid` is the corpus name keying
        the per-corpus settings built in warm_up(). Kept examples are
        re-emitted as ``(ex, self, cid)``.
        """
        def _id(string):
            # Top-1 fasttext language id, e.g. "en" (label prefix stripped).
            res = self.id_func.predict(string, k=1)
            res = res[0][0].replace("__label__", "")
            return res
        trf_batch = []
        for ex, _, cid in batch:
            # Script character classes are (re)built per example; each one is
            # only defined -- and only used below -- when the corresponding
            # per-corpus list is non-empty, so neither can be unbound at use.
            if self.scripts_ok_dict[cid]:
                # Matches any character NOT belonging to an accepted script.
                ok_regex = (
                    "[^"
                    + "".join(r"\p{%s}" % sc for sc in self.scripts_ok_dict[cid])
                    + "]"
                )
            if self.scripts_nok_dict[cid]:
                # Matches any character belonging to a rejected script.
                nok_regex = (
                    "["
                    + "".join(r"\p{%s}" % sc for sc in self.scripts_nok_dict[cid])
                    + "]"
                )
            src_str = " ".join(ex["src"])
            if len(src_str) == 0:
                # drop: empty source
                continue
            if self.same_char_dict[cid] and re.search(r"([^0-9])\1{3}", src_str):
                # drop: same non-digit char 4+ times in a row in src
                continue
            if self.same_word_dict[cid] and re.search(r"(\ .*|.*\ )\1{2}", src_str):
                # drop: same space-delimited segment repeated 3 times in src
                continue
            # Average token length = chars (incl. joining spaces) / tokens;
            # src is known non-empty here, so no division by zero.
            if len(src_str) / len(ex["src"]) < self.avg_tok_min_dict[cid]:
                # drop: average src token length below minimum
                continue
            if len(src_str) / len(ex["src"]) > self.avg_tok_max_dict[cid]:
                # drop: average src token length above maximum
                continue
            if self.scripts_ok_dict[cid] and re.search(ok_regex, src_str):
                # drop: src contains a char outside the accepted scripts
                continue
            if self.scripts_nok_dict[cid] and re.search(nok_regex, src_str):
                # drop: src contains a char from a rejected script
                continue
            if (
                self.langid_dict[cid] != []
                and _id(src_str) not in self.langid_dict[cid]
            ):
                # drop: detected src language not in the accepted list
                continue
            # Target-side checks only apply when a target exists (training).
            if ex["tgt"] is not None:
                tgt_str = " ".join(ex["tgt"])
                if self.src_eq_tgt_dict[cid] and src_str == tgt_str:
                    # drop: source identical to target
                    continue
                if len(tgt_str) == 0:
                    # drop: empty target
                    continue
                # Token-count ratio check (both directions); the +1 avoids
                # division by zero and softens the ratio for short segments.
                if (len(ex["src"]) + 1) / (
                    len(ex["tgt"]) + 1
                ) > self.src_tgt_ratio_dict[cid] or (len(ex["src"]) + 1) / (
                    len(ex["tgt"]) + 1
                ) < (
                    1 / self.src_tgt_ratio_dict[cid]
                ):
                    # drop: src/tgt length ratio out of bounds
                    continue
                if self.same_char_dict[cid] and re.search(r"([^0-9])\1{3}", tgt_str):
                    # drop: same non-digit char 4+ times in a row in tgt
                    continue
                if self.same_word_dict[cid] and re.search(r"(\ .*|.*\ )\1{2}", tgt_str):
                    # drop: same space-delimited segment repeated 3 times in tgt
                    continue
                # tgt is known non-empty here, so no division by zero.
                if len(tgt_str) / len(ex["tgt"]) < self.avg_tok_min_dict[cid]:
                    # drop: average tgt token length below minimum
                    continue
                if len(tgt_str) / len(ex["tgt"]) > self.avg_tok_max_dict[cid]:
                    # drop: average tgt token length above maximum
                    continue
                if self.scripts_ok_dict[cid] and re.search(ok_regex, tgt_str):
                    # drop: tgt contains a char outside the accepted scripts
                    continue
                if self.scripts_nok_dict[cid] and re.search(nok_regex, tgt_str):
                    # drop: tgt contains a char from a rejected script
                    continue
                if (
                    self.langid_dict[cid] != []
                    and _id(tgt_str) not in self.langid_dict[cid]
                ):
                    # drop: detected tgt language not in the accepted list
                    continue
            # keep: re-tag the surviving example with this transform instance
            trf_batch.append((ex, self, cid))
        return trf_batch
|