plenz commited on
Commit
6fe5f08
1 Parent(s): bdc9297

Upload model

Browse files
Files changed (5) hide show
  1. config.json +65 -0
  2. configuration_t5.py +145 -0
  3. modeling_t5.py +0 -0
  4. pytorch_model.bin +3 -0
  5. wrapper_functions.py +485 -0
config.json ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "t5-large",
3
+ "architectures": [
4
+ "T5EncoderModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_t5.T5Config",
8
+ "AutoModel": "modeling_t5.T5EncoderModel"
9
+ },
10
+ "d_ff": 4096,
11
+ "d_kv": 64,
12
+ "d_model": 1024,
13
+ "decoder_start_token_id": 0,
14
+ "dense_act_fn": "relu",
15
+ "dropout_rate": 0.1,
16
+ "eos_token_id": 1,
17
+ "feed_forward_proj": "relu",
18
+ "initializer_factor": 1.0,
19
+ "is_encoder_decoder": true,
20
+ "is_gated_act": false,
21
+ "layer_norm_epsilon": 1e-06,
22
+ "model_type": "glm-t5",
23
+ "n_positions": 512,
24
+ "num_decoder_layers": 24,
25
+ "num_heads": 16,
26
+ "num_layers": 24,
27
+ "output_past": true,
28
+ "pad_token_id": 0,
29
+ "relative_attention_max_distance": 128,
30
+ "relative_attention_num_additional_buckets": 3,
31
+ "relative_attention_num_buckets": 32,
32
+ "task_specific_params": {
33
+ "summarization": {
34
+ "early_stopping": true,
35
+ "length_penalty": 2.0,
36
+ "max_length": 200,
37
+ "min_length": 30,
38
+ "no_repeat_ngram_size": 3,
39
+ "num_beams": 4,
40
+ "prefix": "summarize: "
41
+ },
42
+ "translation_en_to_de": {
43
+ "early_stopping": true,
44
+ "max_length": 300,
45
+ "num_beams": 4,
46
+ "prefix": "translate English to German: "
47
+ },
48
+ "translation_en_to_fr": {
49
+ "early_stopping": true,
50
+ "max_length": 300,
51
+ "num_beams": 4,
52
+ "prefix": "translate English to French: "
53
+ },
54
+ "translation_en_to_ro": {
55
+ "early_stopping": true,
56
+ "max_length": 300,
57
+ "num_beams": 4,
58
+ "prefix": "translate English to Romanian: "
59
+ }
60
+ },
61
+ "torch_dtype": "float32",
62
+ "transformers_version": "4.27.3",
63
+ "use_cache": true,
64
+ "vocab_size": 32128
65
+ }
configuration_t5.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020, The T5 Authors and HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ T5 model configuration"""
16
+
17
+ from typing import Mapping
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.onnx import OnnxSeq2SeqConfigWithPast
21
+ from transformers.utils import logging
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ ## FROM-T5
26
+ # T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ # "t5-small": "https://huggingface.co/t5-small/resolve/main/config.json",
28
+ # "t5-base": "https://huggingface.co/t5-base/resolve/main/config.json",
29
+ # "t5-large": "https://huggingface.co/t5-large/resolve/main/config.json",
30
+ # "t5-3b": "https://huggingface.co/t5-3b/resolve/main/config.json",
31
+ # "t5-11b": "https://huggingface.co/t5-11b/resolve/main/config.json",
32
+ # }
33
+
34
+
35
+ class T5Config(PretrainedConfig):
36
+ r"""
37
+ This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
38
+ instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
39
+ configuration with the defaults will yield a similar configuration to that of the T5
40
+ [t5-small](https://huggingface.co/t5-small) architecture.
41
+
42
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
43
+ documentation from [`PretrainedConfig`] for more information.
44
+
45
+ Arguments:
46
+ vocab_size (`int`, *optional*, defaults to 32128):
47
+ Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
48
+ `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
49
+ d_model (`int`, *optional*, defaults to 512):
50
+ Size of the encoder layers and the pooler layer.
51
+ d_kv (`int`, *optional*, defaults to 64):
52
+ Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
53
+ be defined as `num_heads * d_kv`.
54
+ d_ff (`int`, *optional*, defaults to 2048):
55
+ Size of the intermediate feed forward layer in each `T5Block`.
56
+ num_layers (`int`, *optional*, defaults to 6):
57
+ Number of hidden layers in the Transformer encoder.
58
+ num_decoder_layers (`int`, *optional*):
59
+ Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
60
+ num_heads (`int`, *optional*, defaults to 8):
61
+ Number of attention heads for each attention layer in the Transformer encoder.
62
+ relative_attention_num_buckets (`int`, *optional*, defaults to 32):
63
+ The number of buckets to use for each attention layer.
64
+ relative_attention_max_distance (`int`, *optional*, defaults to 128):
65
+ The maximum distance of the longer sequences for the bucket separation.
66
+ dropout_rate (`float`, *optional*, defaults to 0.1):
67
+ The ratio for all dropout layers.
68
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
69
+ The epsilon used by the layer normalization layers.
70
+ initializer_factor (`float`, *optional*, defaults to 1):
71
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
72
+ testing).
73
+ feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
74
+ Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
75
+ `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
76
+ use_cache (`bool`, *optional*, defaults to `True`):
77
+ Whether or not the model should return the last key/values attentions (not used by all models).
78
+ """
79
+ model_type = "glm-t5"
80
+ keys_to_ignore_at_inference = ["past_key_values"]
81
+ attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
82
+
83
+ def __init__(
84
+ self,
85
+ vocab_size=32128,
86
+ d_model=512,
87
+ d_kv=64,
88
+ d_ff=2048,
89
+ num_layers=6,
90
+ num_decoder_layers=None,
91
+ num_heads=8,
92
+ relative_attention_num_buckets=32,
93
+ relative_attention_max_distance=128,
94
+ dropout_rate=0.1,
95
+ layer_norm_epsilon=1e-6,
96
+ initializer_factor=1.0,
97
+ feed_forward_proj="relu",
98
+ is_encoder_decoder=True,
99
+ use_cache=True,
100
+ pad_token_id=0,
101
+ eos_token_id=1,
102
+ # GLM parameters
103
+ relative_attention_num_additional_buckets=0,
104
+ **kwargs,
105
+ ):
106
+ self.vocab_size = vocab_size
107
+ self.d_model = d_model
108
+ self.d_kv = d_kv
109
+ self.d_ff = d_ff
110
+ self.num_layers = num_layers
111
+ self.num_decoder_layers = (
112
+ num_decoder_layers if num_decoder_layers is not None else self.num_layers
113
+ ) # default = symmetry
114
+ self.num_heads = num_heads
115
+ self.relative_attention_num_buckets = relative_attention_num_buckets
116
+ self.relative_attention_max_distance = relative_attention_max_distance
117
+ self.dropout_rate = dropout_rate
118
+ self.layer_norm_epsilon = layer_norm_epsilon
119
+ self.initializer_factor = initializer_factor
120
+ self.feed_forward_proj = feed_forward_proj
121
+ self.use_cache = use_cache
122
+ self.relative_attention_num_additional_buckets = relative_attention_num_additional_buckets
123
+
124
+ act_info = self.feed_forward_proj.split("-")
125
+ self.dense_act_fn = act_info[-1]
126
+ self.is_gated_act = act_info[0] == "gated"
127
+
128
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
129
+ raise ValueError(
130
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
131
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
132
+ "'gated-gelu' or 'relu'"
133
+ )
134
+
135
+ # for backwards compatibility
136
+ if feed_forward_proj == "gated-gelu":
137
+ self.dense_act_fn = "gelu_new"
138
+
139
+ super().__init__(
140
+ pad_token_id=pad_token_id,
141
+ eos_token_id=eos_token_id,
142
+ is_encoder_decoder=is_encoder_decoder,
143
+ **kwargs,
144
+ )
145
+
modeling_t5.py ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1451e871757a6393353f120dbff57da48d49d3e1b59ee4806f38f83b1aeea598
3
+ size 1339823129
wrapper_functions.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial, cache
4
+ from argparse import Namespace
5
+ from typing import List, Tuple, Dict, Union, Optional
6
+ from itertools import chain
7
+ import random
8
+ from typing import Literal
9
+
10
+ from models.graph_T5.graph_t5 import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
11
+ from models.graph_T5.graph_t5.modeling_t5 import T5Attention
12
+ import models.graph_T5.graph_t5.modeling_t5
13
+
14
+ class Graph():
15
+ """
16
+ A graph class.
17
+ :param g: A list of tuples, where each tuple is a triple (head, r, tail).
18
+ """
19
+ def __init__(
20
+ self,
21
+ g: List[Tuple[str,str,str]] = []
22
+ ):
23
+ self.g = g
24
+ self.concepts = self.get_concepts() # list of all concepts in the graph
25
+ self.relations = self.get_relations() # list of all relations in the graph
26
+ self.relations_multiple = self.get_relations_multiple() # list of all relations in the graph, including duplicate relations
27
+
28
+ @property
29
+ def g(self) -> List[Tuple[str,str,str]]:
30
+ return self._g
31
+
32
+ @g.setter
33
+ def g(self, g: List[Tuple[str,str,str]]):
34
+ self._g = g
35
+
36
+ def num_triplets(self) -> int:
37
+ """
38
+ Get the number of triplets in the graph.
39
+ """
40
+ return len(self.g)
41
+
42
+ def get_concepts(self) -> List[str]:
43
+ """
44
+ Get the concepts in the graph.
45
+ """
46
+ concepts = list(set([triplet[i] for triplet in self.g for i in [0, 2]]))
47
+ concepts.sort() # not necessary but makes debugging easier
48
+ return concepts
49
+
50
+ def get_relations(self) -> List[str]:
51
+ """
52
+ Get the relations in the graph.
53
+ """
54
+ relations = list(set(self.get_relations_multiple()))
55
+ relations.sort() # not necessary but makes debugging easier
56
+ return relations
57
+
58
+ def get_relations_multiple(self) -> List[str]:
59
+ """
60
+ Get the relations in the graph, including duplicate relations.
61
+ """
62
+ relations = [triplet[1] for triplet in self.g]
63
+ return relations
64
+
65
+ def __str__(self):
66
+ out_str = '\n'.join([str(triplet) for triplet in self.g])
67
+ return out_str
68
+
69
+ class Data(Namespace):
70
+ def __init__(self, **kwargs):
71
+ super().__init__()
72
+ self.__dict__.update(kwargs)
73
+
74
+ def get_dummy_graph(num_triplets:int=3) -> Graph:
75
+ g = [
76
+ ("dog", "IsA", "animal"),
77
+ ("cat", "IsA", "animal"),
78
+ ("black poodle", "IsA", "dog"),
79
+ ("black cat", "IsA", "cat"),
80
+ ]
81
+ assert num_triplets <=4, "num_triplets must be <= 4"
82
+ g = g[:num_triplets]
83
+ g = Graph(g)
84
+ return g
85
+
86
+ def r2nl(r: str) -> str:
87
+ """
88
+ Convert a relation to a natural language string. Can be used to implement necessary changes in the data.
89
+ """
90
+ return r
91
+
92
+ def _get_str2tok(g:Graph, tokenizer: T5Tokenizer) -> dict[str, list[int]]:
93
+ """
94
+ Get a dictionary that maps strings to tokens.
95
+ """
96
+ # tokenize concepts and relations
97
+ c_tok = tokenizer([r2nl(c) for c in g.concepts], padding=False)['input_ids']
98
+ r_tok = tokenizer([r2nl(r) for r in g.relations], padding=False)['input_ids']
99
+
100
+ tokens = c_tok + r_tok
101
+ node_names = g.concepts + g.relations # these are not necessarily all nodes in the Levi Graph, as relations can occur more than once
102
+ assert len(tokens) == len(node_names), f"{len(tokens) = }, {len(node_names) = }"
103
+
104
+ # remove end-of-sequence token
105
+ tokens = [toks[:-1] if toks[-1] == tokenizer.eos_token_id else toks for toks in tokens]
106
+
107
+ # create a dictionary mapping concepts and relations to their tokenized forms
108
+ str2tok = {node: tok for node, tok in zip(node_names, tokens)}
109
+ str2tok['</s>'] = [tokenizer.eos_token_id]
110
+ return str2tok
111
+
112
+ def _get_graphT5_input_sequence(g:Graph, str2tok:dict, use_eos:bool) -> Tuple[list, dict]:
113
+ # get input sequence (i.e. sequence that will be fed into the model for this graph)
114
+ all_nodes = g.relations_multiple + g.concepts # list of all concepts and relations that will be in the final sequence (i.e. all nodes of the Levi Graph) # the order of nodes is first all relations (in the order that they appear in g.g), and then all concepts (in alphabetical order. though here the order is not important)
115
+
116
+ if use_eos:
117
+ all_nodes.append('</s>')
118
+
119
+ all_tokens = [str2tok[node] for node in all_nodes] # list of length #nodes, where each element is a list of token ids
120
+ indices = {node: [] for node in all_nodes} # dictionary mapping each node to its start-index and end- in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts and are as long as the number of occurances of the relation in the graph for relations. # WARNING: this assumes that concepts and realtions have different names. This not always the case for REBEL. For concept_indices this is fixed.
121
+ num_relation_tokens = sum([len(token) for token in all_tokens[:len(g.relations_multiple)]]) # number of tokens that are relations
122
+ num_concept_tokens = sum([len(token) for token in all_tokens[len(g.relations_multiple):len(g.relations_multiple)+len(g.concepts)]]) # number of tokens that are concepts
123
+ num_eos_tokens = 1 if use_eos else 0
124
+
125
+ is_concept = torch.tensor([False] * num_relation_tokens + [True] * num_concept_tokens + [False] * num_eos_tokens, dtype=torch.bool) # tensor of length #nodes, where each element is True if the node is a concept and False if it is a relation
126
+ index_counter = 0
127
+ assert len(all_nodes) == len(all_tokens), (all_nodes, all_tokens)
128
+
129
+ for node, token in zip(all_nodes, all_tokens):
130
+ indices[node].append((index_counter, index_counter + len(token)))
131
+ # assert is_concept[index_counter:index_counter+len(token)].all() == (node in g.concepts), f"{is_concept = }, {node = }, {g.concepts = }, {index_counter = }, {len(token) = }, {is_concept[index_counter:index_counter+len(token)] = }"
132
+ index_counter += len(token)
133
+
134
+ concept_indices = {node: [indices[node][-1]] for node in g.concepts} # [-1] and reput in list in case relations have the same name as a concept (concepts are put in last).
135
+ sequence = torch.tensor(list(chain.from_iterable(all_tokens)), dtype=torch.long)
136
+ sequence = sequence.unsqueeze(0) # add batch dimension
137
+ is_concept = is_concept.unsqueeze(0) # add batch dimension
138
+ return sequence, indices, is_concept, concept_indices
139
+
140
+ def _get_graphT5_relativeposition_sparsitymask(g:Graph, indices:dict, sequence_length:int, use_eos:bool, eos:str) -> Tuple[torch.Tensor, torch.Tensor]:
141
+ ### get relative position of each node in the sequence, as well as the sparsity mask ###
142
+ # initialize relative position matrix)
143
+ relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
144
+ # initialize sparsity mask
145
+ sparsity_mask = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)
146
+ # initialize use_additional_bucket
147
+ use_additional_bucket = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)
148
+
149
+ # relative positions / sparsity within each node
150
+ for start, end in chain.from_iterable(indices.values()):
151
+ relative_position[start:end, start:end] = _get_relative_position(end-start)
152
+ sparsity_mask[start:end, start:end] = True
153
+
154
+ # relative position between nodes of the same triplet
155
+ relation_counter = {relation: 0 for relation in g.relations} # dictionary mapping each relation to the number of times it has already appeared in the graph
156
+ for triplet in g.g:
157
+ pos_h = indices[triplet[0]][0] # position of head; tuple (start_index, end_index)
158
+ pos_r = indices[triplet[1]][relation_counter[triplet[1]]] # position of relation; tuple (start_index, end_index)
159
+ pos_t = indices[triplet[2]][0] # position of tail; tuple (start_index, end_index)
160
+
161
+ l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0] # length (i.e. number of tokens) of head and relation
162
+
163
+ # iterate over all combinations of tokens in each triplet. This implementation is not very elegant, but it is sufficiently fast.
164
+ for ih, ph in enumerate(range(pos_h[0], pos_h[1])): # iterate over all head tokens
165
+ for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
166
+ relative_position[ph, pr] = l_h - ih + ir
167
+ relative_position[pr, ph] = - (l_h - ih + ir)
168
+ sparsity_mask[ph, pr] = True
169
+ sparsity_mask[pr, ph] = True
170
+ for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
171
+ relative_position[ph, pt] = l_h - ih + l_r + it
172
+ relative_position[pt, ph] = - (l_h - ih + l_r + it)
173
+ sparsity_mask[ph, pt] = True
174
+ sparsity_mask[pt, ph] = True
175
+ for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
176
+ for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
177
+ relative_position[pr, pt] = l_r - ir + it
178
+ relative_position[pt, pr] = - (l_r - ir + it)
179
+ sparsity_mask[pr, pt] = True
180
+ sparsity_mask[pt, pr] = True
181
+
182
+ relation_counter[triplet[1]] += 1 # next time when that relation comes, then the next tokens will be used
183
+
184
+ if use_eos:
185
+ assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
186
+ pos_eos = indices['</s>'][0] # position of head; tuple (start_index, end_index)
187
+ assert pos_eos[0] + 1 == pos_eos[1], pos_eos
188
+ pos_eos = pos_eos[0] # position of eos token
189
+
190
+ if eos == 'bidirectional':
191
+ relative_position[:, pos_eos] = +1e6
192
+ relative_position[pos_eos, :] = -1e6
193
+ relative_position[pos_eos, pos_eos] = 0
194
+ sparsity_mask[:, pos_eos] = True
195
+ sparsity_mask[pos_eos, :] = True
196
+ elif eos == 'unidirectional':
197
+ relative_position[:, pos_eos] = 1e6
198
+ relative_position[pos_eos, pos_eos] = 0
199
+ sparsity_mask[pos_eos, :] = False # no messages from eos to other tokens
200
+ sparsity_mask[:, pos_eos] = True
201
+ else:
202
+ raise ValueError(f'{eos = } is not a valid option.')
203
+
204
+ relative_position = relative_position.unsqueeze(0) # add batch dimension
205
+ sparsity_mask = sparsity_mask.unsqueeze(0) # add batch dimension
206
+ use_additional_bucket = use_additional_bucket.unsqueeze(0) # add batch dimension
207
+ return relative_position, sparsity_mask, use_additional_bucket
208
+
209
+ def _get_global_graphT5_relativeposition_sparsitymask(g:Graph, indices:dict, sequence_length:int, use_eos:bool, eos:str) -> Tuple[torch.Tensor, torch.Tensor]:
210
+ ### get relative position of each node in the sequence, as well as the sparsity mask ###
211
+ # initialize relative position matrix)
212
+ # relative_position = torch.ones(size=(sequence_length, sequence_length), dtype=torch.long) * 1e6 # technically should be float('inf'), but it does not matter
213
+ relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
214
+ # initialize sparsity mask
215
+ sparsity_mask = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool) # could switch to None, but then code has to be updated accordingly (in particular get_batch)
216
+ # initialize use_additional_bucket
217
+ use_additional_bucket = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)
218
+
219
+ # relative positions / sparsity within each node
220
+ for start, end in chain.from_iterable(indices.values()):
221
+ relative_position[start:end, start:end] = _get_relative_position(end-start)
222
+ use_additional_bucket[start:end, start:end] = False
223
+
224
+ # relative position between nodes of the same triplet
225
+ relation_counter = {relation: 0 for relation in g.relations} # dictionary mapping each relation to the number of times it has already appeared in the graph
226
+ for triplet in g.g:
227
+ pos_h = indices[triplet[0]][0] # position of head; tuple (start_index, end_index)
228
+ pos_r = indices[triplet[1]][relation_counter[triplet[1]]] # position of relation; tuple (start_index, end_index)
229
+ pos_t = indices[triplet[2]][0] # position of tail; tuple (start_index, end_index)
230
+
231
+ l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0] # length (i.e. number of tokens) of head and relation
232
+
233
+ # iterate over all combinations of tokens in each triplet. This implementation is not very elegant, but it works.
234
+ for ih, ph in enumerate(range(pos_h[0], pos_h[1])): # iterate over all head tokens
235
+ for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
236
+ relative_position[ph, pr] = l_h - ih + ir
237
+ relative_position[pr, ph] = - (l_h - ih + ir)
238
+ use_additional_bucket[ph, pr] = False
239
+ use_additional_bucket[pr, ph] = False
240
+ for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
241
+ relative_position[ph, pt] = l_h - ih + l_r + it
242
+ relative_position[pt, ph] = - (l_h - ih + l_r + it)
243
+ use_additional_bucket[ph, pt] = False
244
+ use_additional_bucket[pt, ph] = False
245
+ for ir, pr in enumerate(range(pos_r[0], pos_r[1])): # iterate over all relation tokens
246
+ for it, pt in enumerate(range(pos_t[0], pos_t[1])): # iterate over all tail tokens
247
+ relative_position[pr, pt] = l_r - ir + it
248
+ relative_position[pt, pr] = - (l_r - ir + it)
249
+ use_additional_bucket[pr, pt] = False
250
+ use_additional_bucket[pt, pr] = False
251
+
252
+ relation_counter[triplet[1]] += 1 # next time when that relation comes, then the next tokens will be used
253
+ if use_eos:
254
+ assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
255
+ pos_eos = indices['</s>'][0] # position of head; tuple (start_index, end_index)
256
+ assert pos_eos[0] + 1 == pos_eos[1], pos_eos
257
+ pos_eos = pos_eos[0] # position of eos token
258
+
259
+ if eos == 'bidirectional':
260
+ relative_position[:, pos_eos] = +1e6
261
+ relative_position[pos_eos, :] = -1e6
262
+ relative_position[pos_eos, pos_eos] = 0
263
+ sparsity_mask[:, pos_eos] = True
264
+ sparsity_mask[pos_eos, :] = True
265
+ use_additional_bucket[:, pos_eos] = False
266
+ use_additional_bucket[pos_eos, :] = False
267
+ elif eos == 'unidirectional':
268
+ relative_position[:, pos_eos] = 1e6
269
+ relative_position[pos_eos, pos_eos] = 0
270
+ sparsity_mask[pos_eos, :] = False # no messages from eos to other tokens
271
+ sparsity_mask[:, pos_eos] = True
272
+ use_additional_bucket[:, pos_eos] = False
273
+ use_additional_bucket[pos_eos, :] = False
274
+ else:
275
+ raise ValueError(f'{eos = } is not a valid option.')
276
+
277
+ relative_position = relative_position.unsqueeze(0) # add batch dimension
278
+ sparsity_mask = sparsity_mask.unsqueeze(0) # add batch dimension
279
+ use_additional_bucket = use_additional_bucket.unsqueeze(0) # add batch dimension
280
+ return relative_position, sparsity_mask, use_additional_bucket
281
+
282
+ def graph_to_graphT5(g:Graph, tokenizer:T5Tokenizer, how:str, eos:str)->Data:
283
+ """
284
+ Convert a graph to a graphT5 input.
285
+ :param g: graph
286
+ :param tokenizer: tokenizer
287
+ :param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
288
+ :param eos: end-of-sequence token. Can be `False` for not using an eos token. When using an eos token, there are two ways to use it: `bidirectional` means that the eos token is connected to every other node in the graph, with a relative position of positive infinity (from node to eos) or negative infinity (from eos to node). `unidirectional` means that the eos token is connected to every node in the graph with a relative position of positive infinity (from node to eos), but not the other way around (i.e. no connection from eos to other node). This means, that nodes do not get messages from the eos token, which perceives locality when using the local GLM
289
+ """
290
+ if not isinstance(g, Graph):
291
+ g = Graph(g)
292
+ eos = str(eos)
293
+ assert eos in ['False', 'bidirectional', 'unidirectional'], f"{eos = } must be either 'False', 'bidirectional', or 'unidirectional'"
294
+ use_eos:bool = eos != 'False'
295
+
296
+ str2tok = _get_str2tok(g, tokenizer) # get a dictionary mapping concepts and relations to their tokenized forms
297
+
298
+ sequence, indices, is_concept, concept_indices = _get_graphT5_input_sequence(g, str2tok, use_eos) # get input sequence (i.e. sequence that will be fed into the model for this graph
299
+ sequence_length = sequence.shape[1]
300
+
301
+ if how == 'local':
302
+ relative_position, sparsity_mask, use_additional_bucket = _get_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
303
+ num_additional_buckets = 0 # lGLM does not use additional buckets
304
+ elif how == 'global':
305
+ relative_position, sparsity_mask, use_additional_bucket = _get_global_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
306
+ num_additional_buckets = 1 # gGLM uses 1 additional bucket for long-ranged G2G connections
307
+ else:
308
+ raise ValueError(f"how must be either 'local' or 'global', but is {how}")
309
+
310
+ input_ids = sequence
311
+
312
+ data = Data(input_ids=input_ids, relative_position=relative_position, sparsity_mask=sparsity_mask, use_additional_bucket=use_additional_bucket, indices=indices, is_concept=is_concept, concept_indices=concept_indices, num_additional_buckets=num_additional_buckets)
313
+
314
+ return data
315
+
316
+ @cache
317
+ def _get_relative_position(size):
318
+ return torch.tensor([[i - j for i in range(size)] for j in range(size)], dtype=torch.long)
319
+
320
+ def get_embedding(
321
+ sequence_embedding: torch.Tensor,
322
+ indices: Dict[str, List[Tuple[int, int]]],
323
+ concept: str,
324
+ embedding_aggregation: str = "mean",
325
+ ):
326
+ """
327
+ Returns the embedding of a concept.
328
+ :param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
329
+ :param indices: dictionary mapping each node to its start-index and end- in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts.
330
+ :param concept: the concept for which the embedding should be returned
331
+ :param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean of all tokens of the concept. "seq" returns the embeddings of the all token of the concept.
332
+ :return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
333
+ """
334
+ assert concept in indices.keys(), f"{concept = } is not a node in the graph. {indices = }"
335
+ assert len(indices[concept]) == 1, f"{concept = } is not a concept, as concepts occur only once in the graph. {indices = }"
336
+
337
+ start, end = indices[concept][0]
338
+ sequence_embedding = sequence_embedding[start:end, :]
339
+ if embedding_aggregation == "mean":
340
+ return torch.mean(sequence_embedding, dim=0, keepdim=True)
341
+ elif embedding_aggregation == "seq":
342
+ return sequence_embedding
343
+ else:
344
+ raise NotImplementedError(f"{embedding_aggregation = } is not supported. Use either 'mean' or 'seq'.")
345
+
346
+ def add_text_to_graph_data(data, text, tokenizer, use_text):
347
+ if use_text in {'False', '', False, None}:
348
+ return None
349
+
350
+ text_seq = torch.tensor(tokenizer(text, padding=False)['input_ids']).unsqueeze(0)
351
+ new_input_ids = torch.cat([data.input_ids, text_seq], dim=1)
352
+
353
+ old_seq_len = data.input_ids.shape[1]
354
+ text_seq_len = text_seq.shape[1]
355
+ new_seq_len = new_input_ids.shape[1]
356
+
357
+ new_is_graph = torch.zeros(size=(1, new_seq_len), dtype=torch.bool)
358
+ new_is_graph[:, :old_seq_len] = True
359
+
360
+ if data.relative_position is None: # sequence transformer
361
+ assert data.sparsity_mask is None
362
+ assert data.use_additional_bucket is None
363
+ data.input_ids = new_input_ids
364
+ data.is_graph = new_is_graph
365
+ return None
366
+
367
+ new_relative_position = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.relative_position.dtype)
368
+ new_relative_position[:, :old_seq_len, :old_seq_len] = data.relative_position
369
+ new_relative_position[:, old_seq_len:, old_seq_len:] = _get_relative_position(text_seq_len)
370
+
371
+ new_sparsity_mask = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.sparsity_mask.dtype)
372
+ new_sparsity_mask[:, :old_seq_len, :old_seq_len] = data.sparsity_mask
373
+ new_sparsity_mask[:, old_seq_len:, old_seq_len:] = True
374
+
375
+ new_use_additional_bucket = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.use_additional_bucket.dtype)
376
+ new_use_additional_bucket[:, :old_seq_len, :old_seq_len] = data.use_additional_bucket
377
+ new_use_additional_bucket[:, old_seq_len:, old_seq_len:] = False # could change that if we want T2T and local G2G relations to be learned separately
378
+
379
+ if use_text in {'FullyConnected', True}:
380
+ new_sparsity_mask[:, old_seq_len:, :old_seq_len] = True
381
+ new_sparsity_mask[:, :old_seq_len, old_seq_len:] = True
382
+
383
+ new_use_additional_bucket[:, old_seq_len:, :old_seq_len] = True
384
+ new_use_additional_bucket[:, :old_seq_len, old_seq_len:] = True
385
+
386
+ new_relative_position[:, old_seq_len:, :old_seq_len] = data.num_additional_buckets
387
+ new_relative_position[:, :old_seq_len, old_seq_len:] = data.num_additional_buckets + 1
388
+
389
+ new_num_additional_buckets = data.num_additional_buckets + 2
390
+ else:
391
+ raise ValueError(f"unknown use_text {use_text} (type {type(use_text)})")
392
+
393
+ data.input_ids = new_input_ids
394
+ data.relative_position = new_relative_position
395
+ data.sparsity_mask = new_sparsity_mask
396
+ data.use_additional_bucket = new_use_additional_bucket
397
+ data.num_additional_buckets = new_num_additional_buckets
398
+ data.is_graph = new_is_graph
399
+ return None
400
+
401
+ class DataProcessor():
402
+ @staticmethod
403
+ def encode_graph(tokenizer, g:Union[Graph,list[tuple[str,str,str]]], text:Optional[str]=None, how:Literal['global', 'local']='global', eos:str="False")->Data:
404
+ """
405
+ convert graph to suitable input for the model.
406
+ :param tokenizer: tokenizer
407
+ :param g: graph
408
+ :param text: text to add to the graph. Can be None if no text should be added.
409
+ :param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
410
+ :param eos: end-of-sequence token. Can be `False` for not using an eos token. This is the method used in the paper. When using an eos token, there are two ways to use it: `bidirectional` means that the eos token is connected to every other node in the graph. `unidirectional` means that the eos token is connected to every node in the graph (from node to eos), but not the other way around (i.e. no connection from eos to other node). This means, that nodes do not get messages from the eos token, which perceives locality when using the local GLM
411
+ :return: Data object
412
+ """
413
+ if not isinstance(g, Graph):
414
+ g = Graph(g)
415
+ data = graph_to_graphT5(g, tokenizer, how, eos)
416
+ if text is not None:
417
+ add_text_to_graph_data(data, text, tokenizer, use_text=True)
418
+ return data
419
+
420
+ @staticmethod
421
+ def to_batch(data_instances:list[Data], tokenizer, max_seq_len:Optional[int]=None, device:str='cpu', **kwargs)->dict:
422
+ """
423
+ converts list of data instances to batched inputs for GLM forward call.
424
+ :param datas: list of Data instances
425
+ :param max_seq_len: maximum sequence length
426
+ :param tokenizer: tokenizer
427
+ :param device: device
428
+ :return: dictionary with keys 'input_ids', 'relative_position', 'sparsity_mask', and 'use_additional_bucket'
429
+ """
430
+ current_max_seq_len = max([data.input_ids.shape[1] for data in data_instances])
431
+ if max_seq_len is None:
432
+ max_seq_len = current_max_seq_len
433
+ else:
434
+ max_seq_len = min(max_seq_len, current_max_seq_len)
435
+
436
+ if data_instances[0].relative_position is None:
437
+ assert data_instances[0].sparsity_mask is None
438
+ assert data_instances[0].use_additional_bucket is None
439
+ is_sequence_transformer = True
440
+ else:
441
+ assert data_instances[0].sparsity_mask is not None
442
+ assert data_instances[0].use_additional_bucket is not None
443
+ is_sequence_transformer = False
444
+
445
+ # intialize tensors
446
+ input_ids = torch.ones((len(data_instances), max_seq_len), dtype=torch.long, device=device) * tokenizer.pad_token_id
447
+ if is_sequence_transformer:
448
+ relative_position = None
449
+ sparsity_mask = None
450
+ use_additional_bucket = None
451
+ else:
452
+ relative_position = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.long, device=device)
453
+ sparsity_mask = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
454
+ use_additional_bucket = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
455
+
456
+ # fill tensors
457
+ for i, data in enumerate(data_instances):
458
+ instance_len = min(data.input_ids.shape[1], max_seq_len)
459
+ input_ids[i, :instance_len] = data.input_ids[:, :instance_len]
460
+ if not is_sequence_transformer:
461
+ relative_position[i, :instance_len, :instance_len] = data.relative_position[:, :instance_len, :instance_len]
462
+ sparsity_mask[i, :instance_len, :instance_len] = data.sparsity_mask[:, :instance_len, :instance_len]
463
+ use_additional_bucket[i, :instance_len, :instance_len] = data.use_additional_bucket[:, :instance_len, :instance_len]
464
+
465
+ model_input = {
466
+ 'input_ids': input_ids,
467
+ 'relative_position': relative_position,
468
+ 'sparsity_mask': sparsity_mask,
469
+ 'use_additional_bucket': use_additional_bucket,
470
+ **kwargs
471
+ }
472
+ return model_input
473
+
474
+ @staticmethod
475
+ def get_embedding(sequence_embedding:torch.Tensor, indices:Dict[str,List[Tuple[int, int]]], concept:str, embedding_aggregation:str="mean"):
476
+ """
477
+ Returns embedding of a concept.
478
+ :param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
479
+ :param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts. indices is part of the Data object.
480
+ :param concept: the concept for which the embedding should be returned.
481
+ :param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean of all tokens of the concept. "seq" returns the embeddings of the all token of the concept.
482
+ :return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
483
+ """
484
+ return get_embedding(sequence_embedding, indices, concept, embedding_aggregation)
485
+