# Web-page residue preserved from the original paste (not part of the module):
# TomatoCocotree / 上传 ("upload") / 6a62ffb
from enum import Enum
from typing import Dict, List, Optional, Sequence
import torch
from fairseq.data import Dictionary
class EncoderLangtok(Enum):
    """
    Which language token is prepended to the beginning of the source
    sentence: the source language's token or the target language's token.
    """

    # Prepend the source-language token.
    src = "src"
    # Prepend the target-language token.
    tgt = "tgt"
class LangTokSpec(Enum):
    """
    Names of the language-token specs.

    The string values are what gets passed around as ``spec``; a spec whose
    value ends in "dae" makes ``get_lang_tok`` add a ``_dae`` suffix to the
    language name before formatting the token.
    """

    # Default spec: the language name is used without any suffix.
    main = "main"
    # Spec ending in "dae" (presumably monolingual denoising auto-encoding
    # data -- confirm against the task code that consumes it).
    mono_dae = "mono_dae"
class LangTokStyle(Enum):
    """
    Supported surface forms for a language token.

    The string values are the keys ``get_lang_tok`` uses to pick the token
    template.
    """

    # Rendered by ``get_lang_tok`` as "__{lang}__".
    multilingual = "multilingual"
    # Rendered by ``get_lang_tok`` as "[{lang}]".
    mbart = "mbart"
@torch.jit.export
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    """
    Build the language-token string for ``lang``.

    Args:
        lang: language name to embed in the token.
        lang_tok_style: a ``LangTokStyle`` value selecting the template
            ("[{}]" for mbart, "__{}__" for multilingual).
        spec: token spec; if it ends in "dae" the language gets a ``_dae``
            suffix, if it ends in "mined" it gets ``_mined``, otherwise the
            language is used unchanged.

    Returns:
        The template for ``lang_tok_style`` filled in with the (possibly
        suffixed) language name.

    Raises:
        KeyError: in eager Python, when ``lang_tok_style`` is not one of
            the known styles.
    """
    # The template table lives inside the function on purpose: the function
    # has to stay TorchScript-compatible, which rules out a module-level
    # table.
    templates: Dict[str, str] = {
        LangTokStyle.multilingual.value: "__{}__",
        LangTokStyle.mbart.value: "[{}]",
    }

    # Work out the spec-dependent suffix ("dae" is checked before "mined").
    suffix = ""
    if spec.endswith("dae"):
        suffix = "_dae"
    elif spec.endswith("mined"):
        suffix = "_mined"

    template = templates[lang_tok_style]
    return template.format(lang + suffix)
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """
    Add language tokens (and, when required, "<mask>") to ``dictionary``.

    Args:
        dictionary: object whose ``add_symbol`` is called for every new
            symbol; it is modified in place.
        language_list: languages to create tokens for.
        lang_tok_style: token style forwarded to ``get_lang_tok``.
        langtoks_specs: specs to create tokens for; one token is added per
            (spec, language) pair.
        extra_data: optional mapping; if it contains the ``mono_dae`` spec
            value as a key, "<mask>" is added as well.

    Returns:
        None.
    """
    # Lazy generator: tokens are produced and inserted one at a time, spec
    # by spec and, within a spec, in ``language_list`` order.
    lang_tokens = (
        get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec)
        for spec in langtoks_specs
        for language in language_list
    )
    for token in lang_tokens:
        dictionary.add_symbol(token)

    # "<mask>" is needed for the mbart style, or when extra_data carries a
    # mono_dae entry; extra_data is only inspected when the style check fails.
    needs_mask = lang_tok_style == LangTokStyle.mbart.value
    if not needs_mask and extra_data is not None:
        needs_mask = LangTokSpec.mono_dae.value in extra_data
    if needs_mask:
        dictionary.add_symbol("<mask>")