import json
import re
from itertools import chain
from pathlib import Path
from typing import List, Optional

from transformers import SpeechT5Tokenizer
from transformers.models.speecht5.tokenization_speecht5 import (
    PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES,
)


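def _g2p_with_np(text: str, np_list: str) -> List[str]:
    """Convert ``text`` to phonemes, keeping non-phoneme characters as tokens.

    ``text`` is split on every character in ``np_list``. Each such character
    becomes a token of its own; every other segment is converted to a list of
    phonemes with pyopenjtalk.
    """
    # Imported lazily so this module can be imported without pyopenjtalk.
    from pyopenjtalk import g2p

    # The capturing group makes re.split keep the separator characters.
    np_pattern = re.compile(f"([{re.escape(np_list)}])")

    return list(
        chain.from_iterable(
            (segment,)
            if segment in np_list
            else g2p(segment, kana=False, join=False)
            for segment in np_pattern.split(text)
            if segment  # re.split yields empty strings between adjacent separators
        )
    )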


NP_CHARCTERS = " !\"#$%&'()=~|`{+*}<>?_-^\\@[;:],./ !”#$%&’()=~|`{+*}<>?_ー^¥@「;:」、。・`"


class SpeechT5OpenjtalkTokenizer(SpeechT5Tokenizer):
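    # The vocabulary is a JSON file (see __init__ and save_vocabulary), not a
    # SentencePiece model, so from_pretrained must look for "vocab.json".
    vocab_files_names = {"vocab_file": "vocab.json"}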
    pretrained_vocab_files_map = {}
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
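    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        non_phenome_characters: str = NP_CHARCTERS,
        **kwargs,
    ):
        # Build the vocabulary before calling the parent constructor so that the
        # mappings are per-instance (not a shared class attribute), id2label
        # always exists, and both are ready if the parent queries the vocab.
        self.label2id = {}
        if isinstance(vocab_file, str) and vocab_file.endswith(".json"):
            with open(vocab_file, encoding="utf-8") as f:
                self.label2id = json.load(f)
        self.id2label = {v: k for k, v in self.label2id.items()}
        self.non_phenome_characters = non_phenome_characters

        try:
            # SpeechT5Tokenizer expects a SentencePiece model. With
            # vocab_file=None its SentencePiece loading is expected to raise
            # TypeError, which is ignored because this tokenizer uses the JSON
            # vocabulary loaded above instead.
            super().__init__(
                vocab_file=None,
                bos_token=bos_token,
                eos_token=eos_token,
                unk_token=unk_token,
                pad_token=pad_token,
                **kwargs,
            )
        except TypeError:
            pass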

    @property
    def bos_token_id(self) -> Optional[int]:
        return super().bos_token_id

    @property
    def vocab_size(self) -> int:
        return len(self.label2id)

    def get_vocab(self) -> dict:
        # Return a copy so callers cannot mutate the tokenizer's own mapping.
        return dict(self.label2id)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> tuple:
        # Follow the transformers convention: the file is "vocab.json", or
        # "<prefix>-vocab.json" when a filename prefix is given.
        file_name = (f"{filename_prefix}-" if filename_prefix else "") + "vocab.json"
        vocab_path = Path(save_directory) / file_name
        vocab_path.parent.mkdir(parents=True, exist_ok=True)
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(self.label2id, f, ensure_ascii=False, indent=2)
        return (str(vocab_path),)

    def _tokenize(self, text: str) -> List[str]:
        return _g2p_with_np(text, self.non_phenome_characters)

    def _convert_token_to_id(self, token):
        return self.label2id.get(token, self.label2id.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.id2label.get(index, self.unk_token)