|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from onmt.utils.logging import logger |
|
import re |
|
import regex |
|
|
|
from itertools import chain |
|
|
|
from onmt.transforms import register_transform |
|
from .transform import Transform |
|
|
|
|
|
class MosesPunctNormalizer: |
|
""" |
|
This is a Python port of the Moses punctuation normalizer |
|
""" |
|
|
|
EXTRA_WHITESPACE = [ |
|
(r"\r", r""), |
|
(r"\(", r" ("), |
|
(r"\)", r") "), |
|
(r" +", r" "), |
|
(r"\) ([.!:?;,])", r")\g<1>"), |
|
(r"\( ", r"("), |
|
(r" \)", r")"), |
|
(r"(\d) %", r"\g<1>%"), |
|
(r" :", r":"), |
|
(r" ;", r";"), |
|
] |
|
|
|
NORMALIZE_UNICODE_IF_NOT_PENN = [(r"`", r"'"), (r"''", r' " ')] |
|
|
|
NORMALIZE_UNICODE = [ |
|
("„", r'"'), |
|
("“", r'"'), |
|
("”", r'"'), |
|
("–", r"-"), |
|
("—", r" - "), |
|
(r" +", r" "), |
|
("´", r"'"), |
|
("([a-zA-Z])‘([a-zA-Z])", r"\g<1>'\g<2>"), |
|
("([a-zA-Z])’([a-zA-Z])", r"\g<1>'\g<2>"), |
|
("‘", r"'"), |
|
("‚", r"'"), |
|
("’", r"'"), |
|
(r"''", r'"'), |
|
("´´", r'"'), |
|
("…", r"..."), |
|
] |
|
|
|
FRENCH_QUOTES = [ |
|
("\u00A0«\u00A0", r'"'), |
|
("«\u00A0", r'"'), |
|
("«", r'"'), |
|
("\u00A0»\u00A0", r'"'), |
|
("\u00A0»", r'"'), |
|
("»", r'"'), |
|
] |
|
|
|
HANDLE_PSEUDO_SPACES = [ |
|
("\u00A0%", r"%"), |
|
("nº\u00A0", "nº "), |
|
("\u00A0:", r":"), |
|
("\u00A0ºC", " ºC"), |
|
("\u00A0cm", r" cm"), |
|
("\u00A0\\?", "?"), |
|
("\u00A0\\!", "!"), |
|
("\u00A0;", r";"), |
|
(",\u00A0", r", "), |
|
(r" +", r" "), |
|
] |
|
|
|
EN_QUOTATION_FOLLOWED_BY_COMMA = [(r'"([,.]+)', r'\g<1>"')] |
|
|
|
DE_ES_FR_QUOTATION_FOLLOWED_BY_COMMA = [ |
|
(r',"', r'",'), |
|
(r'(\.+)"(\s*[^<])', r'"\g<1>\g<2>'), |
|
|
|
] |
|
|
|
DE_ES_CZ_CS_FR = [ |
|
("(\\d)\u00A0(\\d)", r"\g<1>,\g<2>"), |
|
] |
|
|
|
OTHER = [ |
|
("(\\d)\u00A0(\\d)", r"\g<1>.\g<2>"), |
|
] |
|
|
|
|
|
|
|
|
|
REPLACE_UNICODE_PUNCTUATION = [ |
|
(",", ","), |
|
(r"。\s*", ". "), |
|
("、", ","), |
|
("”", '"'), |
|
("“", '"'), |
|
("∶", ":"), |
|
(":", ":"), |
|
("?", "?"), |
|
("《", '"'), |
|
("》", '"'), |
|
(")", ")"), |
|
("!", "!"), |
|
("(", "("), |
|
(";", ";"), |
|
("」", '"'), |
|
("「", '"'), |
|
("0", "0"), |
|
("1", "1"), |
|
("2", "2"), |
|
("3", "3"), |
|
("4", "4"), |
|
("5", "5"), |
|
("6", "6"), |
|
("7", "7"), |
|
("8", "8"), |
|
("9", "9"), |
|
(r".\s*", ". "), |
|
("~", "~"), |
|
("’", "'"), |
|
("…", "..."), |
|
("━", "-"), |
|
("〈", "<"), |
|
("〉", ">"), |
|
("【", "["), |
|
("】", "]"), |
|
("%", "%"), |
|
] |
|
|
|
def __init__(self): |
|
""" """ |
|
|
|
def normalize( |
|
self, |
|
text, |
|
lang="en", |
|
penn=True, |
|
norm_quote_commas=True, |
|
norm_numbers=True, |
|
pre_replace_unicode_punct=False, |
|
post_remove_control_chars=False, |
|
): |
|
""" |
|
Returns a string with normalized punctuation. |
|
:param language: The two-letter language code. |
|
:type lang: str |
|
:param penn: Normalize Penn Treebank style quotations. |
|
:type penn: bool |
|
:param norm_quote_commas: Normalize quotations and commas |
|
:type norm_quote_commas: bool |
|
:param norm_numbers: Normalize numbers |
|
:type norm_numbers: bool |
|
""" |
|
self.substitutions = [ |
|
self.EXTRA_WHITESPACE, |
|
self.NORMALIZE_UNICODE, |
|
self.FRENCH_QUOTES, |
|
self.HANDLE_PSEUDO_SPACES, |
|
] |
|
|
|
if penn: |
|
|
|
self.substitutions.insert(1, self.NORMALIZE_UNICODE_IF_NOT_PENN) |
|
|
|
if norm_quote_commas: |
|
if lang == "en": |
|
self.substitutions.append(self.EN_QUOTATION_FOLLOWED_BY_COMMA) |
|
elif lang in ["de", "es", "fr"]: |
|
self.substitutions.append(self.DE_ES_FR_QUOTATION_FOLLOWED_BY_COMMA) |
|
|
|
if norm_numbers: |
|
if lang in ["de", "es", "cz", "cs", "fr"]: |
|
self.substitutions.append(self.DE_ES_CZ_CS_FR) |
|
else: |
|
self.substitutions.append(self.OTHER) |
|
|
|
self.substitutions = list(chain(*self.substitutions)) |
|
|
|
|
|
if pre_replace_unicode_punct: |
|
text = self.replace_unicode_punct(text) |
|
|
|
|
|
for regexp, substitution in self.substitutions: |
|
text = re.sub(regexp, substitution, str(text)) |
|
|
|
|
|
if post_remove_control_chars: |
|
text = self.remove_control_chars(text) |
|
|
|
return text.strip() |
|
|
|
def replace_unicode_punct(self, text): |
|
for regexp, substitution in self.REPLACE_UNICODE_PUNCTUATION: |
|
text = re.sub(regexp, substitution, str(text)) |
|
return text |
|
|
|
def remove_control_chars(self, text): |
|
return regex.sub(r"\p{C}", "", text) |
|
|
|
|
|
@register_transform(name="normalize") |
|
class NormalizeTransform(Transform): |
|
""" |
|
Normalize source and target based on Moses script. |
|
""" |
|
|
|
def __init__(self, opts): |
|
super().__init__(opts) |
|
|
|
@classmethod |
|
def add_options(cls, parser): |
|
"""Add an option for the corpus ratio to apply this transform.""" |
|
|
|
group = parser.add_argument_group("Transform/Normalize") |
|
group.add( |
|
"--src_lang", "-src_lang", type=str, default="", help="Source language code" |
|
) |
|
group.add( |
|
"--tgt_lang", "-tgt_lang", type=str, default="", help="Target language code" |
|
) |
|
group.add("--penn", "-penn", type=bool, default=True, help="Penn substitution") |
|
group.add( |
|
"--norm_quote_commas", |
|
"-norm_quote_commas", |
|
type=bool, |
|
default=True, |
|
help="Normalize quotations and commas", |
|
) |
|
group.add( |
|
"--norm_numbers", |
|
"-norm_numbers", |
|
type=bool, |
|
default=True, |
|
help="Normalize numbers", |
|
) |
|
group.add( |
|
"--pre_replace_unicode_punct", |
|
"-pre_replace_unicode_punct", |
|
type=bool, |
|
default=False, |
|
help="Replace unicode punct", |
|
) |
|
group.add( |
|
"--post_remove_control_chars", |
|
"-post_remove_control_chars", |
|
type=bool, |
|
default=False, |
|
help="Remove control chars", |
|
) |
|
|
|
def _parse_opts(self): |
|
self.src_lang = self.opts.src_lang |
|
self.tgt_lang = self.opts.tgt_lang |
|
self.penn = self.opts.penn |
|
self.norm_quote_commas = self.opts.norm_quote_commas |
|
self.norm_numbers = self.opts.norm_numbers |
|
self.pre_replace_unicode_punct = self.opts.pre_replace_unicode_punct |
|
self.post_remove_control_chars = self.opts.post_remove_control_chars |
|
|
|
@staticmethod |
|
def _get_opt(corpus, opt, def_val): |
|
"""Get opt string of a `corpus`.""" |
|
if "normalize" in corpus["transforms"]: |
|
value = corpus.get(opt, def_val) |
|
normalize = value |
|
else: |
|
normalize = None |
|
return normalize |
|
|
|
@classmethod |
|
def get_opt_dict(cls, opts, opt, def_val): |
|
"""Get normalize settings correspond to corpus in `opts`.""" |
|
normalize_dict = {} |
|
|
|
if hasattr(opts, "data"): |
|
for c_name, corpus in opts.data.items(): |
|
normalize = cls._get_opt(corpus, opt, def_val) |
|
if normalize is not None: |
|
logger.debug(f"Get {opt} for {c_name}: {normalize}") |
|
normalize_dict[c_name] = normalize |
|
|
|
return normalize_dict |
|
|
|
def warm_up(self, vocabs=None): |
|
"""Set options for each dataset.""" |
|
super().warm_up(None) |
|
self.src_lang_dict = self.get_opt_dict(self.opts, "src_lang", "") |
|
self.tgt_lang_dict = self.get_opt_dict(self.opts, "tgt_lang", "") |
|
self.penn_dict = self.get_opt_dict(self.opts, "penn", True) |
|
self.norm_quote_commas_dict = self.get_opt_dict( |
|
self.opts, "norm_quote_commas", True |
|
) |
|
self.norm_numbers_dict = self.get_opt_dict(self.opts, "norm_numbers", True) |
|
self.pre_dict = self.get_opt_dict(self.opts, "pre_replace_unicode_punct", False) |
|
self.post_dict = self.get_opt_dict( |
|
self.opts, "post_remove_control_chars", False |
|
) |
|
self.src_lang_dict["infer"] = self.src_lang |
|
self.tgt_lang_dict["infer"] = self.tgt_lang |
|
self.penn_dict["infer"] = self.penn |
|
self.norm_quote_commas_dict["infer"] = self.norm_quote_commas |
|
self.norm_numbers_dict["infer"] = self.norm_numbers |
|
self.pre_dict["infer"] = self.pre_replace_unicode_punct |
|
self.post_dict["infer"] = self.post_remove_control_chars |
|
|
|
self.mpn = MosesPunctNormalizer() |
|
|
|
def apply(self, example, is_train=False, stats=None, **kwargs): |
|
"""Normalize source and target examples.""" |
|
corpus_name = kwargs.get("corpus_name", None) |
|
if corpus_name is None: |
|
raise ValueError("corpus_name is required.") |
|
|
|
src_str = self.mpn.normalize( |
|
" ".join(example["src"]), |
|
self.src_lang_dict[corpus_name], |
|
self.penn_dict[corpus_name], |
|
self.norm_quote_commas_dict[corpus_name], |
|
self.norm_numbers_dict[corpus_name], |
|
self.pre_dict[corpus_name], |
|
self.post_dict[corpus_name], |
|
) |
|
example["src"] = src_str.split() |
|
|
|
if example["tgt"] is not None: |
|
tgt_str = self.mpn.normalize( |
|
" ".join(example["tgt"]), |
|
self.tgt_lang_dict[corpus_name], |
|
self.penn_dict[corpus_name], |
|
self.norm_quote_commas_dict[corpus_name], |
|
self.norm_numbers_dict[corpus_name], |
|
self.pre_dict[corpus_name], |
|
self.post_dict[corpus_name], |
|
) |
|
example["tgt"] = tgt_str.split() |
|
|
|
return example |
|
|