|
from enum import Enum |
|
from typing import Dict, List, Optional, Sequence |
|
|
|
import torch |
|
from fairseq.data import Dictionary |
|
|
|
|
|
class EncoderLangtok(Enum):
    """Which language token to prepend to the start of the source sentence.

    ``src`` prepends the source-language token; ``tgt`` prepends the
    target-language token.
    """

    # Prepend the source-language token.
    src = "src"
    # Prepend the target-language token.
    tgt = "tgt"
|
|
|
|
|
class LangTokSpec(Enum):
    """Kinds of language token that can be generated for a language.

    The member's ``.value`` is what ``get_lang_tok`` receives as ``spec``.
    """

    # Plain language token with no suffix appended.
    main = "main"
    # Token whose language part gets a ``_dae`` suffix (get_lang_tok matches
    # any spec ending in "dae").
    mono_dae = "mono_dae"
|
|
|
|
|
class LangTokStyle(Enum):
    """Surface format used to render a language token.

    ``multilingual`` renders a language as ``__lang__`` and ``mbart`` renders
    it as ``[lang]`` (templates defined in ``get_lang_tok``).
    """

    # Double-underscore wrapping: __en__
    multilingual = "multilingual"
    # Square-bracket wrapping: [en]
    mbart = "mbart"
|
|
|
|
|
@torch.jit.export
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    """Render the special token string for ``lang``.

    Args:
        lang: language code to wrap.
        lang_tok_style: a ``LangTokStyle`` value; ``"mbart"`` produces
            ``[lang]`` and ``"multilingual"`` produces ``__lang__``.
        spec: language-token spec. A spec ending in ``"dae"`` appends
            ``_dae`` to the language, and one ending in ``"mined"`` appends
            ``_mined``; anything else leaves the language unchanged.

    Returns:
        The formatted token, e.g. ``"__en__"`` or ``"[en_dae]"``.

    Raises:
        KeyError: if ``lang_tok_style`` is not a known style.
    """
    # Templates are built inside the function rather than at module level;
    # presumably this keeps the function TorchScript-compatible given the
    # torch.jit.export decorator — confirm before hoisting.
    templates: Dict[str, str] = {
        LangTokStyle.mbart.value: "[{}]",
        LangTokStyle.multilingual.value: "__{}__",
    }

    # Checked in order; the first suffix that the spec ends with wins.
    for suffix in ["dae", "mined"]:
        if spec.endswith(suffix):
            lang = f"{lang}_{suffix}"
            break

    return templates[lang_tok_style].format(lang)
|
|
|
|
|
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """Register language tokens (and possibly ``<mask>``) in ``dictionary``.

    Args:
        dictionary: dictionary mutated in place through ``add_symbol``.
        language_list: language codes to create tokens for.
        lang_tok_style: a ``LangTokStyle`` value selecting the token format.
        langtoks_specs: token specs; one token per (spec, language) pair is
            added, iterating specs in the outer loop and languages in the
            inner loop.
        extra_data: optional mapping; if it contains a ``"mono_dae"`` key,
            the ``<mask>`` symbol is added.

    ``<mask>`` is also added whenever ``lang_tok_style`` is the mbart style.
    """
    # Lazy generator: each get_lang_tok call is immediately followed by its
    # add_symbol call, in spec-major order. The call order presumably
    # determines symbol indices in the dictionary, so keep it stable.
    lang_tokens = (
        get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec)
        for spec in langtoks_specs
        for language in language_list
    )
    for token in lang_tokens:
        dictionary.add_symbol(token)

    # Add <mask> for the mbart style; otherwise only when extra_data asks
    # for mono_dae. extra_data is consulted only if the style check fails.
    needs_mask = lang_tok_style == LangTokStyle.mbart.value
    if not needs_mask and extra_data is not None:
        needs_mask = LangTokSpec.mono_dae.value in extra_data
    if needs_mask:
        dictionary.add_symbol("<mask>")
|
|