AkitoP's picture
Upload 182 files
e82212c verified
raw
history blame
5.01 kB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import re
def wordize_and_map(text: str):
words = []
index_map_from_text_to_word = []
index_map_from_word_to_text = []
while len(text) > 0:
match_space = re.match(r'^ +', text)
if match_space:
space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):]
continue
match_en = re.match(r'^[a-zA-Z0-9]+', text)
if match_en:
en_word = match_en.group(0)
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + len(en_word)
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)] * len(en_word)
words.append(en_word)
text = text[len(en_word):]
else:
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)]
words.append(text[0])
text = text[1:]
return words, index_map_from_text_to_word, index_map_from_word_to_text
def tokenize_and_map(tokenizer, text: str):
words, text2word, word2text = wordize_and_map(text=text)
tokens = []
index_map_from_token_to_text = []
for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word)
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]')
else:
current_word_start = word_start
for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token))
index_map_from_token_to_text.append(
(current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len
tokens.append(word_token)
index_map_from_text_to_token = text2word
for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
for token_pos in range(token_start, token_end):
index_map_from_text_to_token[token_pos] = i
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config
default_config_dict = {
'manual_seed': 1313,
'model_source': 'bert-base-chinese',
'window_size': 32,
'num_workers': 2,
'use_mask': True,
'use_char_phoneme': False,
'use_conditional': True,
'param_conditional': {
'affect_location': 'softmax',
'bias': True,
'char-linear': True,
'pos-linear': False,
'char+pos-second': True,
'char+pos-second_lowrank': False,
'lowrank_size': 0,
'char+pos-second_fm': False,
'fm_size': 0,
'fix_mode': None,
'count_json': 'train.count.json'
},
'lr': 5e-5,
'val_interval': 200,
'num_iter': 10000,
'use_focal': False,
'param_focal': {
'alpha': 0.0,
'gamma': 0.7
},
'use_pos': True,
'param_pos ': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
}
def load_config(config_path: os.PathLike, use_default: bool=False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():
if not hasattr(config, attr):
setattr(config, attr, val)
elif isinstance(val, dict):
d = getattr(config, attr)
for dict_k, dict_v in val.items():
if dict_k not in d:
d[dict_k] = dict_v
return config