File size: 3,566 Bytes
b7cd722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import re

import cn2an
import opencc

from text.symbols import punctuation, sh_symbols

# Module-wide OpenCC converter configured by the 'zaonhe' lexicon file.
# `_g2p` feeds normalized text through it and parses the converted string as a
# stream of phone characters, tone digits and punctuation.
# NOTE(review): the config path is relative, so it presumably has to resolve
# from the process working directory — confirm against how the app is launched.
converter = opencc.OpenCC('text/lexicon/zaonhe.json')


def number_to_shanghainese(text):
    """Spell out every Arabic number found in *text* as Chinese numerals.

    Each run of digits (optionally with a decimal part) is converted with
    ``cn2an.an2cn`` and then adjusted: a leading '一十' becomes '十', '二十'
    becomes '廿', and every remaining '二' becomes '两'.  Finally a '两' that
    directly follows '廿', or a '十' that is not preceded by one of 三..九,
    is turned back into '二'.

    Text outside the matched numbers is returned unchanged.
    """
    # Applied strictly in this order; later rules rely on the earlier ones.
    rewrites = (('一十', '十'), ('二十', '廿'), ('二', '两'))

    def _spell(match):
        reading = cn2an.an2cn(match.group())
        for standard, local in rewrites:
            reading = reading.replace(standard, local)
        return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', reading)

    return re.sub(r'\d+(?:\.?\d+)?', _spell, text)


# Punctuation normalization table used by `replace_punctuation`: each key is
# rewritten to its value before unsupported characters are stripped.
#
# The previous literal listed "(", ")" and "~" twice and mapped ",", "!" and
# "?" to themselves, which strongly suggests the fullwidth variants had been
# folded to ASCII at some point.  The fullwidth forms are spelled out below as
# \uXXXX escapes so they cannot be folded again by an editor or tool.
# Every key that existed before is still present with the same value.
rep_map = {
    ":": ",",
    "\uff1a": ",",  # fullwidth colon
    ";": ",",
    "\uff1b": ",",  # fullwidth semicolon
    ",": ",",
    "\uff0c": ",",  # fullwidth comma
    "。": ".",
    "!": "!",
    "\uff01": "!",  # fullwidth exclamation mark
    "?": "?",
    "\uff1f": "?",  # fullwidth question mark
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "\uff08": "'",  # fullwidth left parenthesis
    "\uff09": "'",  # fullwidth right parenthesis
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "\uff5e": "-",  # fullwidth tilde
    "「": "'",
    "」": "'",
}


def replace_punctuation(text):
    """Normalize punctuation in *text* and drop unsupported characters.

    Steps, in order:
      1. Replace '嗯' with '恩' and '呣' with '母'.
      2. Rewrite every key of ``rep_map`` found in the text to its mapped
         value.
      3. Remove every run of characters that is neither in the range
         U+4E00..U+9FA5 nor one of the marks in ``punctuation``.

    Returns the cleaned string.
    """
    text = text.replace("嗯", "恩").replace("呣", "母")

    pattern = re.compile("|".join(re.escape(p) for p in rep_map))
    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

    # Escape each mark before splicing it into the character class: an
    # unescaped '-', ']', '^' or '\' in `punctuation` would otherwise form a
    # range, close the class early, or negate/escape something unintended.
    allowed_marks = "".join(re.escape(p) for p in punctuation)
    replaced_text = re.sub(
        r"[^\u4e00-\u9fa5" + allowed_marks + r"]+", "", replaced_text
    )

    return replaced_text


def valid_tone_char(char):
    """Return True when *char* is allowed in the converted phone string.

    A character is accepted if 'SH_' + char is listed in ``sh_symbols``, if it
    is one of the ``punctuation`` marks, if it is a digit, or if it is
    whitespace.
    """
    if ("SH_" + char) in sh_symbols:
        return True
    if char in punctuation:
        return True
    return char.isdigit() or char.isspace()


def g2p(text):
    """Convert *text* to phones, tones and word2ph with boundary padding.

    Delegates to ``_g2p`` and then wraps each sequence with one leading and
    one trailing entry: the phone '_', tone 0, and a word2ph count of 1.

    Returns a ``(phones, tones, word2ph)`` tuple of lists.
    """
    inner_phones, inner_tones, inner_word2ph = _g2p(text)
    return (
        ["_"] + inner_phones + ["_"],
        [0] + inner_tones + [0],
        [1] + inner_word2ph + [1],
    )


def _g2p(text):
    """Convert *text* into parallel phone, tone and word2ph lists.

    The text is run through ``converter``; '-' and '$' are removed from the
    result and every character rejected by ``valid_tone_char`` is dropped.
    The remaining characters are then scanned left to right:

    * a digit closes the current group of phones: its integer value becomes
      the tone of every phone in the group, and the group size is appended
      to ``word2ph`` (this may be 0 when no phone precedes the digit);
    * a punctuation mark first closes any open group with tone 0, then is
      emitted as a phone of its own with tone 0 and a ``word2ph`` entry of 1;
    * any other character (including whitespace admitted by
      ``valid_tone_char``) is collected as a phone of the open group.

    A group still open at the end of the input is closed with tone 0.

    Returns:
        ``(phones, tones, word2ph)`` where ``len(tones) == len(phones)`` and
        ``sum(word2ph) == len(phones)``.  Phones whose 'SH_'-prefixed form is
        in ``sh_symbols`` are returned with that prefix.  All three lists are
        empty when nothing survives filtering.
    """
    converted = converter.convert(text).replace('-', '').replace('$', '')
    phone_chars = [ch for ch in converted if valid_tone_char(ch)]

    phones = []
    tones = []
    word2ph = []

    # Index of the first character belonging to the group currently open.
    phone_start_pos = 0

    for pos, char in enumerate(phone_chars):
        # Every character since phone_start_pos was appended to `phones`,
        # so this is the number of phones still waiting for a tone.
        pending = pos - phone_start_pos

        if char.isdigit():
            tone = int(char)
            word2ph.append(pending)
            tones.extend([tone] * pending)
            phone_start_pos = pos + 1

        elif char in punctuation:
            if pending:
                # Phones with no tone digit before the punctuation get tone 0.
                word2ph.append(pending)
                tones.extend([0] * pending)
            phones.append(char)
            tones.append(0)
            word2ph.append(1)
            phone_start_pos = pos + 1

        else:
            phones.append(char)

    # A non-empty remainder means the input ended on a plain phone character,
    # i.e. the last group never received a tone digit.
    trailing = len(phone_chars) - phone_start_pos
    if trailing:
        word2ph.append(trailing)
        tones.extend([0] * trailing)

    # Add the 'SH_' prefix to phones that have a prefixed symbol.
    phones = ["SH_" + i if ("SH_" + i) in sh_symbols else i for i in phones]
    assert len(tones) == len(phones)
    assert sum(word2ph) == len(phones)
    return phones, tones, word2ph


def text_normalize(text):
    """Upper-case *text*, spell out its numbers, then normalize punctuation.

    Applies ``number_to_shanghainese`` to the upper-cased input and passes the
    result through ``replace_punctuation``.
    """
    with_numbers_spelled = number_to_shanghainese(text.upper())
    return replace_punctuation(with_numbers_spelled)


def get_bert_feature(text, word2ph, device):
    """Return the BERT feature for *text* from ``text.shanghainese_bert``.

    All three arguments are forwarded unchanged to
    ``shanghainese_bert.get_bert_feature`` and its result is returned as is.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to defer loading the BERT backend until it is
    # actually needed — confirm.
    from text import shanghainese_bert

    feature = shanghainese_bert.get_bert_feature(text, word2ph, device)
    return feature