|
import re |
|
|
|
import cn2an |
|
import opencc |
|
|
|
from text.symbols import punctuation, sh_symbols |
|
|
|
# Shared OpenCC converter built from the zaonhe lexicon config; _g2p feeds
# normalised text through it and treats the output as phone/tone characters.
# NOTE(review): the config path is relative, so it is presumably resolved
# against the process working directory -- confirm with the deployment layout.
_ZAONHE_CONFIG = 'text/lexicon/zaonhe.json'
converter = opencc.OpenCC(_ZAONHE_CONFIG)
|
|
|
|
|
def number_to_shanghainese(text):
    """Replace every numeric run in ``text`` with a Shanghainese-style numeral.

    A numeric run is matched by ``\\d+(?:\\.?\\d+)?``. Each run is passed to
    ``cn2an.an2cn`` and the result is then adjusted:

    * ``一十`` becomes ``十`` and ``二十`` becomes ``廿``;
    * every remaining ``二`` becomes ``两``;
    * a ``两`` that directly follows ``廿``, or a ``十`` that is not preceded
      by one of ``三四五六七八九``, is turned back into ``二``.

    Text without digits is returned unchanged.
    """
    numeric_run = re.compile(r'\d+(?:\.?\d+)?')
    units_two = re.compile(r'((?:^|[^三四五六七八九])十|廿)两')
    # Order matters: the tens forms are rewritten before the blanket 二 -> 两.
    substitutions = (('一十', '十'), ('二十', '廿'), ('二', '两'))

    def _spell(match):
        reading = cn2an.an2cn(match.group())
        for old, new in substitutions:
            reading = reading.replace(old, new)
        return units_two.sub(r'\1二', reading)

    return numeric_run.sub(_spell, text)
|
|
|
|
|
# Punctuation substitution table consumed by replace_punctuation: every key
# found in the text is replaced by its value.
# NOTE(review): as rendered here some keys repeat ("(", ")", "~") and some map
# to themselves (",", "!", "?"); presumably the file distinguishes full-width
# from half-width forms -- verify the actual code points before deduplicating.
rep_map = {
    # separators mapped to a comma
    ":": ",", ";": ",", ",": ",",
    # sentence-final marks and line breaks
    "。": ".", "!": "!", "?": "?", "\n": ".",
    # list/name separators mapped to a comma
    "·": ",", "、": ",",
    # ellipsis and "$"
    "...": "…", "$": ".",
    # quotation marks mapped to an apostrophe
    "“": "'", "”": "'", "‘": "'", "’": "'",
    # brackets of all kinds mapped to an apostrophe
    "(": "'", ")": "'", "(": "'", ")": "'",
    "《": "'", "》": "'", "【": "'", "】": "'",
    "[": "'", "]": "'",
    # dashes and tildes mapped to a hyphen
    "—": "-", "~": "-", "~": "-",
    # corner brackets mapped to an apostrophe
    "「": "'", "」": "'",
}
|
|
|
|
|
def replace_punctuation(text):
    """Normalise punctuation in ``text`` and drop unsupported characters.

    Steps, in order:

    1. Replace ``嗯`` with ``恩`` and ``呣`` with ``母``.
    2. Replace every occurrence of a ``rep_map`` key with its mapped value.
    3. Remove every character that is neither in the range U+4E00-U+9FA5
       nor one of the entries of ``punctuation``.

    Args:
        text: The input string.

    Returns:
        The cleaned string.
    """
    text = text.replace("嗯", "恩").replace("呣", "母")

    # Try longer keys first so a multi-character key such as "..." can never
    # be pre-empted by a shorter key that happens to be its prefix. sorted()
    # is stable, so keys of equal length keep their rep_map order.
    keys_longest_first = sorted(rep_map, key=len, reverse=True)
    key_pattern = re.compile("|".join(re.escape(key) for key in keys_longest_first))
    replaced_text = key_pattern.sub(lambda match: rep_map[match.group()], text)

    # Escape each punctuation entry before placing it inside the character
    # class; unescaped, a "]", "\\", "^" or a "-" between two other entries
    # would change the meaning of the class instead of being matched literally.
    allowed_marks = "".join(re.escape(mark) for mark in punctuation)
    return re.sub(r"[^\u4e00-\u9fa5" + allowed_marks + r"]+", "", replaced_text)
|
|
|
|
|
def valid_tone_char(char):
    """Return True if ``char`` should be kept in the converter output.

    A character is kept when ``"SH_" + char`` is in ``sh_symbols``, when it
    is in ``punctuation``, or when it is a digit or whitespace. The checks
    run in that order and stop at the first that succeeds.
    """
    if ("SH_" + char) in sh_symbols:
        return True
    if char in punctuation:
        return True
    return char.isdigit() or char.isspace()
|
|
|
|
|
def g2p(text):
    """Convert ``text`` to phones, tones and word2ph with boundary padding.

    Runs ``_g2p`` and wraps each of its three lists with one extra element
    at both ends: ``"_"`` for phones, ``0`` for tones and ``1`` for word2ph.

    Returns:
        A ``(phones, tones, word2ph)`` tuple of lists.
    """
    core_phones, core_tones, core_word2ph = _g2p(text)
    return (
        ["_", *core_phones, "_"],
        [0, *core_tones, 0],
        [1, *core_word2ph, 1],
    )
|
|
|
|
|
def _g2p(text):
    """Convert ``text`` to aligned phone, tone and group-size lists.

    The text is passed through ``converter.convert``; ``-`` and ``$`` are
    removed from the result and only characters accepted by
    ``valid_tone_char`` are kept. The remaining characters are scanned once:

    * a digit closes the current group: every phone collected since the
      previous boundary receives that digit as its tone, and the group size
      is appended to ``word2ph``;
    * a punctuation character first closes any open group with tone 0, then
      is emitted as a phone of its own with tone 0 and a ``word2ph`` entry
      of 1;
    * any other character is emitted as one phone.

    Phones still open at the end of the scan form a final group with tone 0.
    Finally each phone ``p`` is replaced by ``"SH_" + p`` when that prefixed
    form is in ``sh_symbols``.

    Returns:
        A ``(phones, tones, word2ph)`` tuple of lists with
        ``len(tones) == len(phones)`` and ``sum(word2ph) == len(phones)``.
        All three lists are empty when no character survives filtering.

    Raises:
        AssertionError: If the two length invariants above do not hold.
    """
    converted = converter.convert(text).replace('-', '').replace('$', '')
    phone_chars = [ch for ch in converted if valid_tone_char(ch)]

    phones = []
    tones = []
    word2ph = []

    if not phone_chars:
        return phones, tones, word2ph

    # Index of the first character belonging to the group currently open.
    phone_start_pos = 0

    for pos, char in enumerate(phone_chars):
        if char.isdigit():
            # The digit itself is not a phone; it tones the group before it.
            tone = int(char)
            group_size = pos - phone_start_pos
            word2ph.append(group_size)
            tones.extend([tone] * group_size)
            phone_start_pos = pos + 1
        elif char in punctuation:
            if pos != phone_start_pos:
                # Phones seen since the last boundary never got a tone digit.
                group_size = pos - phone_start_pos
                word2ph.append(group_size)
                tones.extend([0] * group_size)
            phones.append(char)
            tones.append(0)
            word2ph.append(1)
            phone_start_pos = pos + 1
        else:
            phones.append(char)

    last_phone_char = phone_chars[-1]
    if not last_phone_char.isdigit() and last_phone_char not in punctuation:
        # The output ended inside a group: close it with the neutral tone 0.
        group_size = len(phone_chars) - phone_start_pos
        word2ph.append(group_size)
        tones.extend([0] * group_size)

    phones = ["SH_" + p if ("SH_" + p) in sh_symbols else p for p in phones]
    assert len(tones) == len(phones)
    assert sum(word2ph) == len(phones)
    return phones, tones, word2ph
|
|
|
|
|
def text_normalize(text):
    """Normalise raw input text before grapheme-to-phone conversion.

    Upper-cases ``text``, rewrites its numeric runs with
    ``number_to_shanghainese`` and then cleans the result with
    ``replace_punctuation``.
    """
    with_numbers_spelled = number_to_shanghainese(text.upper())
    return replace_punctuation(with_numbers_spelled)
|
|
|
|
|
def get_bert_feature(text, word2ph, device):
    """Return the result of ``text.shanghainese_bert.get_bert_feature``.

    All three arguments are forwarded unchanged.
    """
    # Imported on first call rather than at module import time; presumably to
    # avoid loading the BERT backend unless features are requested -- confirm.
    from text import shanghainese_bert as bert_backend

    features = bert_backend.get_bert_feature(text, word2ph, device)
    return features
|
|