from copy import deepcopy
import torch
import dgl
import stanza
import networkx as nx

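# Edge-type legend used by both parsers below (the type ids feed a gated graph
# neural network; meanings inferred from the assignments in this file):
#   0 = forward dependency edge, 1 = backward dependency edge,
#   2 = recurrent self-loop, 3 = inter-sentence edge between sentence heads,
#   4 = sequential forward, 5 = sequential backward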
class Sentence2GraphParser:
    """Parses a sentence into a word-level syntactic graph via stanza dependency parsing."""
    def __init__(self, language='zh', use_gpu=False, download=False):
        self.language = language
        if download:
            # the default download_method fetches any missing stanza models on first use
            self.stanza_parser = stanza.Pipeline(lang=language, use_gpu=use_gpu)
        else:
            # download_method=None assumes the models are already installed locally
            self.stanza_parser = stanza.Pipeline(lang=language, use_gpu=use_gpu, download_method=None)

    def parse(self, clean_sentence=None, words=None, ph_words=None):
        if self.language == 'zh':
            assert words is not None and ph_words is not None
            ret = self._parse_zh(words, ph_words)
        elif self.language == 'en':
            assert clean_sentence is not None
            ret = self._parse_en(clean_sentence)
        else:
            raise NotImplementedError
        return ret

    def _parse_zh(self, words, ph_words, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False):
        """
        words: <List of str>, each character in chinese is one item
        ph_words: <List of str>, each character in chinese is one item, represented by the phoneme
        Example:
                text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
                words = ['<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ','
                        , '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>']
                ph_words = ['<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#',
                            'b_o3_#', 'l_uo2_|', 'an1', ',', 'd_iao1_|',
                            'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>']
        """
        words, ph_words = words[1:-1], ph_words[1:-1]  # delete <BOS> and <EOS>
        for i, p_w in enumerate(ph_words):
            if p_w == ',':
                # replace the English ',' with the Chinese one;
                # we found this necessary for stanza's dependency parsing
                words[i], ph_words[i] = ',', ','
        tmp_words = deepcopy(words)
        num_added_space = 0
        for i, p_w in enumerate(ph_words):
            if p_w.endswith("#"):
                # insert a space after a '#'-suffixed p_w to mark a word boundary
                tmp_words.insert(num_added_space + i + 1, " ")
                num_added_space += 1
            if p_w in [',', ',']:
                # insert one space after and one before the comma
                tmp_words.insert(num_added_space + i + 1, " ")  # insert after the comma first
                tmp_words.insert(num_added_space + i, " ")  # then insert before it
                num_added_space += 2
        clean_text = ''.join(tmp_words).strip()
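        # e.g., for the docstring example above, the '#' and comma rules should yield
        # clean_text == '宝马 配挂 跛 骡鞍 ， 貂蝉 怨 枕 董翁 榻'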
        parser_out = self.stanza_parser(clean_text)

        idx_to_word = {i + 1: w for i, w in enumerate(words)}

        vocab_nodes = {}
        vocab_idx_offset = 0
        for sentence in parser_out.sentences:
            num_nodes_in_current_sentence = 0
            for vocab_node in sentence.words:
                num_nodes_in_current_sentence += 1
                vocab_idx = vocab_node.id + vocab_idx_offset
                vocab_text = vocab_node.text.replace(" ", "")  # remove any spaces inside the token
                vocab_nodes[vocab_idx] = vocab_text
            vocab_idx_offset += num_nodes_in_current_sentence

        # start vocab-to-word alignment
        vocab_to_word = {}
        current_word_idx = 1
        for vocab_i in vocab_nodes.keys():
            vocab_to_word[vocab_i] = []
            for w_in_vocab_i in vocab_nodes[vocab_i]:
                if w_in_vocab_i != idx_to_word[current_word_idx]:
                    raise ValueError(f"Word mismatch at word index {current_word_idx}: "
                                     f"got '{w_in_vocab_i}', expected '{idx_to_word[current_word_idx]}'")
                vocab_to_word[vocab_i].append(current_word_idx)  # record the path (vocab_node_idx -> word_global_idx)
                current_word_idx += 1
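        # e.g., if stanza keeps '宝马' as a single token, vocab_to_word[1] == [1, 2];
        # every word index maps into exactly one vocab node, in order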

        # next, compute the vocab-level edges
        if len(parser_out.sentences) > 5:
            print("Detected more than 5 sentences! Please check whether the input is too long!")
        vocab_level_source_id, vocab_level_dest_id = [], []
        vocab_level_edge_types = []
        sentences_heads = []
        vocab_id_offset = 0
        # get forward edges
        for s in parser_out.sentences:
            for w in s.words:
                w_idx = w.id + vocab_id_offset  # indices start from 1, matching the binarizer
                w_dest_idx = w.head + vocab_id_offset
                if w.head == 0:
                    sentences_heads.append(w_idx)
                    continue
                vocab_level_source_id.append(w_idx)
                vocab_level_dest_id.append(w_dest_idx)
            vocab_id_offset += len(s.words)
        vocab_level_edge_types += [0] * len(vocab_level_source_id)
        num_vocab = vocab_id_offset

        # optional: get backward edges
        if enable_backward_edge:
            back_source, back_dest = deepcopy(vocab_level_dest_id), deepcopy(vocab_level_source_id)
            vocab_level_source_id += back_source
            vocab_level_dest_id += back_dest
            vocab_level_edge_types += [1] * len(back_source)

        # optional: get inter-sentence edges if num_sentences > 1
        inter_sentence_source, inter_sentence_dest = [], []
        if enable_inter_sentence_edge and len(sentences_heads) > 1:
            def get_full_graph_edges(nodes):
                tmp_edges = []
                for i, node_i in enumerate(nodes):
                    for j, node_j in enumerate(nodes):
                        if i == j:
                            continue
                        tmp_edges.append((node_i, node_j))
                return tmp_edges

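            # e.g., sentence heads [2, 9] yield the directed edges [(2, 9), (9, 2)]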
            tmp_edges = get_full_graph_edges(sentences_heads)
            for (source, dest) in tmp_edges:
                inter_sentence_source.append(source)
                inter_sentence_dest.append(dest)
            vocab_level_source_id += inter_sentence_source
            vocab_level_dest_id += inter_sentence_dest
            vocab_level_edge_types += [3] * len(inter_sentence_source)

        if sequential_edge:
            # forward chain 1 -> 2 -> ... -> num_vocab, then backward chain num_vocab -> ... -> 1;
            # each chain has num_vocab - 1 edges, matching the edge-type counts below
            seq_source = list(range(1, num_vocab)) + list(range(num_vocab, 1, -1))
            seq_dest = list(range(2, num_vocab + 1)) + list(range(num_vocab - 1, 0, -1))
            vocab_level_source_id += seq_source
            vocab_level_dest_id += seq_dest
            vocab_level_edge_types += [4] * (num_vocab - 1) + [5] * (num_vocab - 1)

        # finally, use the vocab-level edges and the vocab-to-word alignment to construct the word-level graph
        num_word = len(words)
        source_id, dest_id, edge_types = [], [], []
        for (vocab_start, vocab_end, vocab_edge_type) in zip(vocab_level_source_id, vocab_level_dest_id,
                                                             vocab_level_edge_types):
            # connect through the first word of each vocab token
            word_start = min(vocab_to_word[vocab_start])
            word_end = min(vocab_to_word[vocab_end])
            source_id.append(word_start)
            dest_id.append(word_end)
            edge_types.append(vocab_edge_type)

        # sequential edges between adjacent words inside the same vocab token
        for word_indices_in_v in vocab_to_word.values():
            for i, word_idx in enumerate(word_indices_in_v):
                if i + 1 < len(word_indices_in_v):
                    source_id.append(word_idx)
                    dest_id.append(word_idx + 1)
                    edge_types.append(4)
                if i - 1 >= 0:
                    source_id.append(word_idx)
                    dest_id.append(word_idx - 1)
                    edge_types.append(5)

        # optional: get recurrent edges
        if enable_recur_edge:
            recur_source, recur_dest = list(range(1, num_word + 1)), list(range(1, num_word + 1))
            source_id += recur_source
            dest_id += recur_dest
            edge_types += [2] * len(recur_source)

        # add <BOS> and <EOS>
        source_id += [0, num_word + 1, 1, num_word]
        dest_id += [1, num_word, 0, num_word + 1]
        edge_types += [4, 4, 5, 5]  # 4 = sequential forward, 5 = sequential backward

        edges = (torch.LongTensor(source_id), torch.LongTensor(dest_id))
        dgl_graph = dgl.graph(edges)
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)

    def _parse_en(self, clean_sentence, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False, consider_bos_for_index=True):
        """
        clean_sentence: <str>, each word or punctuation should be separated by one blank.
        """
        edge_types = []  # required for gated graph neural network
        clean_sentence = clean_sentence.strip()
        if clean_sentence.endswith((" .", " ,", " ;", " :", " ?", " !")):
            clean_sentence = clean_sentence[:-2]  # drop the trailing punctuation token
        if clean_sentence.startswith(". "):
            clean_sentence = clean_sentence[2:]  # drop a leading period token
        parser_out = self.stanza_parser(clean_sentence)
        if len(parser_out.sentences) > 5:
            print("Detected more than 5 sentences! Please check whether the input is too long!")
            print(clean_sentence)
        source_id, dest_id = [], []
        sentences_heads = []
        word_id_offset = 0
        # get forward edges
        for s in parser_out.sentences:
            for w in s.words:
                w_idx = w.id + word_id_offset  # indices start from 1, matching the binarizer
                w_dest_idx = w.head + word_id_offset
                if w.head == 0:
                    sentences_heads.append(w_idx)
                    continue
                source_id.append(w_idx)
                dest_id.append(w_dest_idx)
            word_id_offset += len(s.words)
        num_word = word_id_offset
        edge_types += [0] * len(source_id)

        # optional: get backward edges
        if enable_backward_edge:
            back_source, back_dest = deepcopy(dest_id), deepcopy(source_id)
            source_id += back_source
            dest_id += back_dest
            edge_types += [1] * len(back_source)

        # optional: get recurrent edges
        if enable_recur_edge:
            recur_source, recur_dest = list(range(1, num_word + 1)), list(range(1, num_word + 1))
            source_id += recur_source
            dest_id += recur_dest
            edge_types += [2] * len(recur_source)

        # optional: get inter-sentence edges if num_sentences > 1
        inter_sentence_source, inter_sentence_dest = [], []
        if enable_inter_sentence_edge and len(sentences_heads) > 1:
            def get_full_graph_edges(nodes):
                tmp_edges = []
                for i, node_i in enumerate(nodes):
                    for j, node_j in enumerate(nodes):
                        if i == j:
                            continue
                        tmp_edges.append((node_i, node_j))
                return tmp_edges

            tmp_edges = get_full_graph_edges(sentences_heads)
            for (source, dest) in tmp_edges:
                inter_sentence_source.append(source)
                inter_sentence_dest.append(dest)
            source_id += inter_sentence_source
            dest_id += inter_sentence_dest
            edge_types += [3] * len(inter_sentence_source)

        # add <BOS> and <EOS>
        source_id += [0, num_word + 1, 1, num_word]
        dest_id += [1, num_word, 0, num_word + 1]
        edge_types += [4, 4, 5, 5]  # 4 = sequential forward, 5 = sequential backward

        # optional: sequential edge
        if sequential_edge:
            # forward chain 1 -> 2 -> ... -> num_word, then backward chain num_word -> ... -> 1;
            # each chain has num_word - 1 edges, matching the edge-type counts below
            seq_source = list(range(1, num_word)) + list(range(num_word, 1, -1))
            seq_dest = list(range(2, num_word + 1)) + list(range(num_word - 1, 0, -1))
            source_id += seq_source
            dest_id += seq_dest
            edge_types += [4] * (num_word - 1) + [5] * (num_word - 1)
        if consider_bos_for_index:
            edges = (torch.LongTensor(source_id), torch.LongTensor(dest_id))
        else:
            # shift every index down by one, dropping the <BOS> offset
            edges = (torch.LongTensor(source_id) - 1, torch.LongTensor(dest_id) - 1)
        dgl_graph = dgl.graph(edges)
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)


def plot_dgl_sentence_graph(dgl_graph, labels):
    """
    labels = {idx: word for idx,word in enumerate(sentence.split(" ")) }
    """
    import matplotlib.pyplot as plt
    nx_graph = dgl_graph.to_networkx()
    pos = nx.random_layout(nx_graph)
    nx.draw(nx_graph, pos, with_labels=False)
    nx.draw_networkx_labels(nx_graph, pos, labels)
    plt.show()

if __name__ == '__main__':

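    # Note: these tests assume the stanza models for 'zh' and 'en' are already
    # installed locally (the constructor defaults to download_method=None);
    # if not, fetch them once first with stanza's standard download API, e.g.:
    #   stanza.download('zh')
    #   stanza.download('en')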
    # Unit Test for Chinese Graph Builder
    parser = Sentence2GraphParser("zh")
    text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
    words = ['<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ',', '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>']
    ph_words = ['<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#', 'b_o3_#', 'l_uo2_|', 'an1', ',', 'd_iao1_|',
                'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>']
    graph1, etypes1 = parser.parse(text1, words, ph_words)
    plot_dgl_sentence_graph(graph1, {i: w for i, w in enumerate(ph_words)})

    # Unit Test for English Graph Builder
    parser = Sentence2GraphParser("en")
    text2 = "I love you . You love me . Mixue ice-scream and tea ."
    graph2, etypes2 = parser.parse(text2)
    plot_dgl_sentence_graph(graph2, {i: w for i, w in enumerate(("<BOS> " + text2 + " <EOS>").split(" "))})