Build error
Build error
File size: 13,666 Bytes
98f685a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 |
from copy import deepcopy
import torch
import dgl
import stanza
import networkx as nx
class Sentence2GraphParser:
    """Turn a sentence into a DGL graph built from stanza's dependency parse.

    Every edge carries an integer type (returned as a separate tensor, needed
    by gated graph neural networks):

        0  dependency edge, word -> its syntactic head
        1  reversed dependency edge, head -> word
        2  recurrent (self-loop) edge
        3  inter-sentence edge between the root words of different sentences
        4  sequential forward edge (i -> i + 1)
        5  sequential backward edge (i -> i - 1)

    Node 0 is ``<BOS>``, real words are numbered from 1, and ``<EOS>`` is
    ``num_word + 1``.
    """

    # Written as an escape so the character survives text normalisation.
    _FULLWIDTH_COMMA = "\uff0c"

    def __init__(self, language='zh', use_gpu=False, download=False):
        """Create the stanza pipeline for ``language``.

        Args:
            language: 'zh' or 'en' (other values make ``parse`` raise).
            use_gpu: forwarded to ``stanza.Pipeline``.
            download: when False the pipeline is built with
                ``download_method=None``; when True stanza's default
                download behaviour is used.
        """
        self.language = language
        if download:
            self.stanza_parser = stanza.Pipeline(lang=language, use_gpu=use_gpu)
        else:
            self.stanza_parser = stanza.Pipeline(lang=language, use_gpu=use_gpu, download_method=None)

    def parse(self, clean_sentence=None, words=None, ph_words=None):
        """Dispatch to the language-specific graph builder.

        Args:
            clean_sentence: required for 'en'; see ``_parse_en``.
            words, ph_words: required for 'zh'; see ``_parse_zh``.

        Returns:
            ``(dgl_graph, edge_types)`` where ``edge_types`` is a LongTensor
            with one entry per edge.

        Raises:
            NotImplementedError: if ``self.language`` is neither 'zh' nor 'en'.
        """
        if self.language == 'zh':
            assert words is not None and ph_words is not None
            ret = self._parse_zh(words, ph_words)
        elif self.language == 'en':
            assert clean_sentence is not None
            ret = self._parse_en(clean_sentence)
        else:
            raise NotImplementedError
        return ret

    @staticmethod
    def _dependency_edges(sentences):
        """Collect word -> head edges over all parsed sentences.

        Node ids are the 1-based in-sentence word id plus the number of words
        in the preceding sentences. Root words (``head == 0``) produce no
        edge; they are returned separately.

        Returns:
            ``(source_ids, dest_ids, sentence_heads, num_nodes)``.
        """
        source_ids, dest_ids, sentence_heads = [], [], []
        offset = 0
        for sentence in sentences:
            for word in sentence.words:
                node_idx = word.id + offset  # it starts from 1, just same as binarizer
                if word.head == 0:
                    sentence_heads.append(node_idx)
                    continue
                source_ids.append(node_idx)
                dest_ids.append(word.head + offset)
            offset += len(sentence.words)
        return source_ids, dest_ids, sentence_heads, offset

    @staticmethod
    def _fully_connect(nodes):
        """Return every ordered pair of distinct entries of ``nodes``."""
        return [(node_i, node_j)
                for i, node_i in enumerate(nodes)
                for j, node_j in enumerate(nodes)
                if i != j]

    @staticmethod
    def _sequential_edges(num_nodes):
        """Chain nodes ``1..num_nodes`` forward (type 4) and backward (type 5).

        Returns:
            ``(source_ids, dest_ids, edge_types)``, each holding
            ``2 * (num_nodes - 1)`` entries (empty when ``num_nodes <= 1``).
        """
        forward_source = list(range(1, num_nodes))
        forward_dest = list(range(2, num_nodes + 1))
        # Stop at node 2 -> 1: node 0 is not a word, and one extra edge would
        # leave the edge lists longer than the edge-type list.
        backward_source = list(range(num_nodes, 1, -1))
        backward_dest = list(range(num_nodes - 1, 0, -1))
        num_pairs = len(forward_source)
        return (forward_source + backward_source,
                forward_dest + backward_dest,
                [4] * num_pairs + [5] * num_pairs)

    @classmethod
    def _zh_word_level_edges(cls, words, sentences, enable_backward_edge=True, enable_recur_edge=True,
                             enable_inter_sentence_edge=True, sequential_edge=False):
        """Build the character-level edge lists for a parsed Chinese text.

        Args:
            words: characters of the text, without <BOS>/<EOS>.
            sentences: parsed sentences; each item of ``sentence.words`` must
                expose ``id``, ``head`` and ``text``.

        Returns:
            ``(source_id, dest_id, edge_types)`` as three equally long lists.

        Raises:
            ValueError: if the parsed tokens do not spell out ``words``.
        """
        idx_to_word = {i + 1: w for i, w in enumerate(words)}

        # One "vocab" node per parsed token (a token may span several characters).
        vocab_nodes = {}
        vocab_idx_offset = 0
        for sentence in sentences:
            for vocab_node in sentence.words:
                vocab_text = vocab_node.text.replace(" ", "")  # delete blank in vocab
                vocab_nodes[vocab_node.id + vocab_idx_offset] = vocab_text
            vocab_idx_offset += len(sentence.words)

        # start vocab-to-word alignment
        vocab_to_word = {}
        current_word_idx = 1
        for vocab_i, vocab_text in vocab_nodes.items():
            vocab_to_word[vocab_i] = []
            for w_in_vocab_i in vocab_text:
                if w_in_vocab_i != idx_to_word.get(current_word_idx):
                    raise ValueError("Word Mismatch!")
                vocab_to_word[vocab_i].append(current_word_idx)  # add a path (vocab_node_idx, word_global_idx)
                current_word_idx += 1

        # then we compute the vocab-level edges
        vocab_source, vocab_dest, sentences_heads, num_vocab = cls._dependency_edges(sentences)
        vocab_edge_types = [0] * len(vocab_source)

        # optional: get backward edges
        if enable_backward_edge:
            back_source, back_dest = list(vocab_dest), list(vocab_source)
            vocab_source += back_source
            vocab_dest += back_dest
            vocab_edge_types += [1] * len(back_source)

        # optional: get inter-sentence edges if num_sentences > 1
        if enable_inter_sentence_edge and len(sentences_heads) > 1:
            head_pairs = cls._fully_connect(sentences_heads)
            vocab_source += [source for source, _ in head_pairs]
            vocab_dest += [dest for _, dest in head_pairs]
            vocab_edge_types += [3] * len(head_pairs)

        if sequential_edge:
            seq_source, seq_dest, seq_types = cls._sequential_edges(num_vocab)
            vocab_source += seq_source
            vocab_dest += seq_dest
            vocab_edge_types += seq_types

        # Then, we use the vocab-level edges and the vocab-to-word path, to construct the word-level graph
        num_word = len(words)
        source_id, dest_id, edge_types = [], [], []
        for vocab_start, vocab_end, vocab_edge_type in zip(vocab_source, vocab_dest, vocab_edge_types):
            # connect the first word in the vocab
            source_id.append(min(vocab_to_word[vocab_start]))
            dest_id.append(min(vocab_to_word[vocab_end]))
            edge_types.append(vocab_edge_type)

        # sequential connection in words
        for word_indices_in_v in vocab_to_word.values():
            for i, word_idx in enumerate(word_indices_in_v):
                if i + 1 < len(word_indices_in_v):
                    source_id.append(word_idx)
                    dest_id.append(word_idx + 1)
                    edge_types.append(4)
                if i - 1 >= 0:
                    source_id.append(word_idx)
                    dest_id.append(word_idx - 1)
                    edge_types.append(5)

        # optional: get recurrent edges
        if enable_recur_edge:
            recur_nodes = list(range(1, num_word + 1))
            source_id += recur_nodes
            dest_id += recur_nodes
            edge_types += [2] * len(recur_nodes)

        # add <BOS> and <EOS>
        source_id += [0, num_word + 1, 1, num_word]
        dest_id += [1, num_word, 0, num_word + 1]
        edge_types += [4, 4, 5, 5]  # 4 represents sequentially forward, 5 is sequential backward
        return source_id, dest_id, edge_types

    @classmethod
    def _en_word_level_edges(cls, sentences, enable_backward_edge=True, enable_recur_edge=True,
                             enable_inter_sentence_edge=True, sequential_edge=False):
        """Build the word-level edge lists for a parsed English text.

        Args:
            sentences: parsed sentences; each item of ``sentence.words`` must
                expose ``id`` and ``head``.

        Returns:
            ``(source_id, dest_id, edge_types)`` as three equally long lists.
        """
        source_id, dest_id, sentences_heads, num_word = cls._dependency_edges(sentences)
        edge_types = [0] * len(source_id)  # required for gated graph neural network

        # optional: get backward edges
        if enable_backward_edge:
            back_source, back_dest = list(dest_id), list(source_id)
            source_id += back_source
            dest_id += back_dest
            edge_types += [1] * len(back_source)

        # optional: get recurrent edges
        if enable_recur_edge:
            recur_nodes = list(range(1, num_word + 1))
            source_id += recur_nodes
            dest_id += recur_nodes
            edge_types += [2] * len(recur_nodes)

        # optional: get inter-sentence edges if num_sentences > 1
        if enable_inter_sentence_edge and len(sentences_heads) > 1:
            head_pairs = cls._fully_connect(sentences_heads)
            source_id += [source for source, _ in head_pairs]
            dest_id += [dest for _, dest in head_pairs]
            edge_types += [3] * len(head_pairs)

        # add <BOS> and <EOS>
        source_id += [0, num_word + 1, 1, num_word]
        dest_id += [1, num_word, 0, num_word + 1]
        edge_types += [4, 4, 5, 5]  # 4 represents sequentially forward, 5 is sequential backward

        # optional: sequential edge
        if sequential_edge:
            seq_source, seq_dest, seq_types = cls._sequential_edges(num_word)
            source_id += seq_source
            dest_id += seq_dest
            edge_types += seq_types
        return source_id, dest_id, edge_types

    def _parse_zh(self, words, ph_words, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False):
        """Build the character-level graph of a Chinese sentence.

        words: <List of str>, each character in chinese is one item
        ph_words: <List of str>, each character in chinese is one item, represented by the phoneme
        Both lists start with '<BOS>' and end with '<EOS>'. A phoneme item
        ending in '#' marks the last character of a word.

        Example:
            text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
            words = ['<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ','
                     , '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>']
            ph_words = ['<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#',
                        'b_o3_#', 'l_uo2_|', 'an1', ',', 'd_iao1_|',
                        'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>']

        Returns:
            ``(dgl_graph, edge_types)``; ``edge_types`` is a LongTensor with
            one entry per edge.

        Raises:
            ValueError: if stanza's tokens do not spell out ``words``.
        """
        # Slicing copies, so the caller's lists are not modified below.
        words, ph_words = words[1:-1], ph_words[1:-1]  # delete <BOS> and <EOS>
        zh_comma = self._FULLWIDTH_COMMA
        for i, p_w in enumerate(ph_words):
            if p_w == ',':
                # change english ',' into chinese
                # we found it necessary in stanza's dependency parsing
                words[i], ph_words[i] = zh_comma, zh_comma

        tmp_words = list(words)
        num_added_space = 0
        for i, p_w in enumerate(ph_words):
            if p_w.endswith("#"):
                # add a blank after the p_w with '#', to separate words
                tmp_words.insert(num_added_space + i + 1, " ")
                num_added_space += 1
            if p_w in (',', zh_comma):
                # add one blank before and after the comma, respectively
                tmp_words.insert(num_added_space + i + 1, " ")  # insert behind the comma first
                tmp_words.insert(num_added_space + i, " ")  # insert before
                num_added_space += 2
        clean_text = ''.join(tmp_words).strip()
        parser_out = self.stanza_parser(clean_text)

        if len(parser_out.sentences) > 5:
            print("Detect more than 5 input sentence! pls check whether the sentence is too long!")

        source_id, dest_id, edge_types = self._zh_word_level_edges(
            words, parser_out.sentences,
            enable_backward_edge=enable_backward_edge,
            enable_recur_edge=enable_recur_edge,
            enable_inter_sentence_edge=enable_inter_sentence_edge,
            sequential_edge=sequential_edge,
        )

        edges = (torch.LongTensor(source_id), torch.LongTensor(dest_id))
        dgl_graph = dgl.graph(edges)
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)

    def _parse_en(self, clean_sentence, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False, consider_bos_for_index=True):
        """Build the word-level graph of an English sentence.

        clean_sentence: <str>, each word or punctuation should be separated by one blank.
        consider_bos_for_index: when True node ids are used as built
            (0 is <BOS>, words start at 1); when False every id is shifted
            down by one.

        Returns:
            ``(dgl_graph, edge_types)``; ``edge_types`` is a LongTensor with
            one entry per edge.
        """
        clean_sentence = clean_sentence.strip()
        # Drop one trailing punctuation token and one leading '.' token.
        if clean_sentence.endswith((" .", " ,", " ;", " :", " ?", " !")):
            clean_sentence = clean_sentence[:-2]
        if clean_sentence.startswith(". "):
            clean_sentence = clean_sentence[2:]
        parser_out = self.stanza_parser(clean_sentence)
        if len(parser_out.sentences) > 5:
            print("Detect more than 5 input sentence! pls check whether the sentence is too long!")

        source_id, dest_id, edge_types = self._en_word_level_edges(
            parser_out.sentences,
            enable_backward_edge=enable_backward_edge,
            enable_recur_edge=enable_recur_edge,
            enable_inter_sentence_edge=enable_inter_sentence_edge,
            sequential_edge=sequential_edge,
        )

        source_tensor, dest_tensor = torch.LongTensor(source_id), torch.LongTensor(dest_id)
        if not consider_bos_for_index:
            # NOTE(review): the <BOS> edges are always added, so this shift
            # maps node 0 to -1 -- confirm this mode is actually usable.
            source_tensor, dest_tensor = source_tensor - 1, dest_tensor - 1
        edges = (source_tensor, dest_tensor)
        dgl_graph = dgl.graph(edges)
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)
def plot_dgl_sentence_graph(dgl_graph, labels):
    """Draw a sentence graph with networkx and matplotlib.

    Args:
        dgl_graph: graph to draw; converted with ``to_networkx()``.
        labels: mapping from node index to the text drawn at that node, e.g.
            labels = {idx: word for idx,word in enumerate(sentence.split(" ")) }
    """
    import matplotlib.pyplot as plt

    nx_graph = dgl_graph.to_networkx()
    # Node positions are random, so the layout differs between calls.
    pos = nx.random_layout(nx_graph)
    nx.draw(nx_graph, pos, with_labels=False)
    nx.draw_networkx_labels(nx_graph, pos, labels)
    plt.show()
if __name__ == '__main__':
    # Demo 1: build and plot the graph of a Chinese sentence.
    parser = Sentence2GraphParser("zh")
    text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
    words = [
        '<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ',',
        '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>',
    ]
    ph_words = [
        '<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#', 'b_o3_#', 'l_uo2_|', 'an1', ',',
        'd_iao1_|', 'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>',
    ]
    graph1, etypes1 = parser.parse(text1, words, ph_words)
    # Nodes are labelled with the phoneme items, <BOS>/<EOS> included.
    zh_labels = dict(enumerate(ph_words))
    plot_dgl_sentence_graph(graph1, zh_labels)

    # Demo 2: build and plot the graph of an English text.
    parser = Sentence2GraphParser("en")
    text2 = "I love you . You love me . Mixue ice-scream and tea ."
    graph2, etypes2 = parser.parse(text2)
    en_tokens = ("<BOS> " + text2 + " <EOS>").split(" ")
    en_labels = dict(enumerate(en_tokens))
    plot_dgl_sentence_graph(graph2, en_labels)