|
import torch
import torch.nn as nn
from functools import partial, cache
from argparse import Namespace
from typing import List, Tuple, Dict, Union, Optional, Literal
from itertools import chain
import random

from transformers import T5Tokenizer


class Graph:
    """
    A graph class.
    :param g: A list of tuples, where each tuple is a triple (head, relation, tail).
    """
    def __init__(
        self,
        g: Optional[List[Tuple[str, str, str]]] = None,
    ):
        # avoid a mutable default argument; an empty list means an empty graph
        self.g = g if g is not None else []
        self.concepts = self.get_concepts()
        self.relations = self.get_relations()
        self.relations_multiple = self.get_relations_multiple()

    @property
    def g(self) -> List[Tuple[str, str, str]]:
        return self._g

    @g.setter
    def g(self, g: List[Tuple[str, str, str]]):
        self._g = g

    def num_triplets(self) -> int:
        """
        Get the number of triplets in the graph.
        """
        return len(self.g)

    def get_concepts(self) -> List[str]:
        """
        Get the concepts in the graph.
        """
        concepts = list(set([triplet[i] for triplet in self.g for i in [0, 2]]))
        concepts.sort()
        return concepts

    def get_relations(self) -> List[str]:
        """
        Get the relations in the graph.
        """
        relations = list(set(self.get_relations_multiple()))
        relations.sort()
        return relations

    def get_relations_multiple(self) -> List[str]:
        """
        Get the relations in the graph, including duplicate relations.
        """
        relations = [triplet[1] for triplet in self.g]
        return relations

    def __str__(self):
        out_str = '\n'.join([str(triplet) for triplet in self.g])
        return out_str
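

# Minimal usage sketch (the triples below are made-up examples): shows what the derived
# attributes of a small Graph contain.
def _example_graph_usage() -> None:
    g = Graph([("dog", "IsA", "animal"), ("cat", "IsA", "animal"), ("dog", "HasA", "tail")])
    # concepts are the sorted, de-duplicated heads and tails
    assert g.concepts == ["animal", "cat", "dog", "tail"]
    # relations are de-duplicated; relations_multiple keeps one entry per triplet
    assert g.relations == ["HasA", "IsA"]
    assert g.relations_multiple == ["IsA", "IsA", "HasA"]
    assert g.num_triplets() == 3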
|
|
|
class Data(Namespace):
    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)


def get_dummy_graph(num_triplets: int = 3) -> Graph:
    g = [
        ("dog", "IsA", "animal"),
        ("cat", "IsA", "animal"),
        ("black poodle", "IsA", "dog"),
        ("black cat", "IsA", "cat"),
    ]
    assert num_triplets <= 4, "num_triplets must be <= 4"
    g = g[:num_triplets]
    g = Graph(g)
    return g


def r2nl(r: str) -> str:
    """
    Convert a relation to a natural language string. Can be used to implement necessary changes in the data.
    """
    return r


def _get_str2tok(g: Graph, tokenizer: T5Tokenizer) -> dict[str, list[int]]:
    """
    Get a dictionary that maps each concept and relation string to its token ids.
    """
    c_tok = tokenizer([r2nl(c) for c in g.concepts], padding=False)['input_ids']
    r_tok = tokenizer([r2nl(r) for r in g.relations], padding=False)['input_ids']

    tokens = c_tok + r_tok
    node_names = g.concepts + g.relations
    assert len(tokens) == len(node_names), f"{len(tokens) = }, {len(node_names) = }"

    # strip the end-of-sequence token that the tokenizer appends to each entry
    tokens = [toks[:-1] if toks[-1] == tokenizer.eos_token_id else toks for toks in tokens]

    str2tok = {node: tok for node, tok in zip(node_names, tokens)}
    str2tok['</s>'] = [tokenizer.eos_token_id]
    return str2tok


def _get_graphT5_input_sequence(g: Graph, str2tok: dict, use_eos: bool) -> Tuple[torch.Tensor, dict, torch.Tensor, dict]:
    """
    Build the flat input sequence: one token span per relation instance, followed by one span per
    (unique) concept, and optionally a final </s> token. Also returns the span indices per node,
    a mask marking concept tokens, and the concept spans.
    """
    # relation instances come first in the sequence, followed by the (unique) concepts
    all_nodes = g.relations_multiple + g.concepts

    if use_eos:
        all_nodes.append('</s>')

    all_tokens = [str2tok[node] for node in all_nodes]
    indices = {node: [] for node in all_nodes}
    num_relation_tokens = sum([len(token) for token in all_tokens[:len(g.relations_multiple)]])
    num_concept_tokens = sum([len(token) for token in all_tokens[len(g.relations_multiple):len(g.relations_multiple) + len(g.concepts)]])
    num_eos_tokens = 1 if use_eos else 0

    is_concept = torch.tensor([False] * num_relation_tokens + [True] * num_concept_tokens + [False] * num_eos_tokens, dtype=torch.bool)
    index_counter = 0
    assert len(all_nodes) == len(all_tokens), (all_nodes, all_tokens)

    # map every node to the (start, end) positions of its tokens in the flattened sequence
    for node, token in zip(all_nodes, all_tokens):
        indices[node].append((index_counter, index_counter + len(token)))
        index_counter += len(token)

    concept_indices = {node: [indices[node][-1]] for node in g.concepts}
    sequence = torch.tensor(list(chain.from_iterable(all_tokens)), dtype=torch.long)
    sequence = sequence.unsqueeze(0)
    is_concept = is_concept.unsqueeze(0)
    return sequence, indices, is_concept, concept_indices
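

# Minimal sketch of the sequence layout this helper produces (shapes only, since the exact token
# ids depend on the tokenizer): relation tokens come first, then concept tokens, then the optional
# </s> token. "t5-small" is used purely as an example checkpoint.
def _example_input_sequence() -> None:
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    g = get_dummy_graph(num_triplets=2)
    str2tok = _get_str2tok(g, tokenizer)
    sequence, indices, is_concept, concept_indices = _get_graphT5_input_sequence(g, str2tok, use_eos=False)
    assert sequence.shape == (1, is_concept.shape[1])
    # every concept occurs exactly once, so it has exactly one (start, end) span
    assert all(len(spans) == 1 for spans in concept_indices.values())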
|
|
|
def _get_graphT5_relativeposition_sparsitymask(g: Graph, indices: dict, sequence_length: int, use_eos: bool, eos: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Local (lGLM) attention pattern: only tokens of the same node and tokens of nodes that share a
    triplet can attend to each other. Relative positions follow the head -> relation -> tail order.
    """
    relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
    sparsity_mask = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)
    use_additional_bucket = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)

    # tokens within the same node attend to each other with standard relative positions
    for start, end in chain.from_iterable(indices.values()):
        relative_position[start:end, start:end] = _get_relative_position(end - start)
        sparsity_mask[start:end, start:end] = True

    # connect head, relation, and tail tokens of every triplet
    relation_counter = {relation: 0 for relation in g.relations}
    for triplet in g.g:
        pos_h = indices[triplet[0]][0]
        pos_r = indices[triplet[1]][relation_counter[triplet[1]]]
        pos_t = indices[triplet[2]][0]

        l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0]

        for ih, ph in enumerate(range(pos_h[0], pos_h[1])):
            for ir, pr in enumerate(range(pos_r[0], pos_r[1])):
                relative_position[ph, pr] = l_h - ih + ir
                relative_position[pr, ph] = - (l_h - ih + ir)
                sparsity_mask[ph, pr] = True
                sparsity_mask[pr, ph] = True
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):
                relative_position[ph, pt] = l_h - ih + l_r + it
                relative_position[pt, ph] = - (l_h - ih + l_r + it)
                sparsity_mask[ph, pt] = True
                sparsity_mask[pt, ph] = True
        for ir, pr in enumerate(range(pos_r[0], pos_r[1])):
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):
                relative_position[pr, pt] = l_r - ir + it
                relative_position[pt, pr] = - (l_r - ir + it)
                sparsity_mask[pr, pt] = True
                sparsity_mask[pt, pr] = True

        relation_counter[triplet[1]] += 1

    if use_eos:
        assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
        pos_eos = indices['</s>'][0]
        assert pos_eos[0] + 1 == pos_eos[1], pos_eos
        pos_eos = pos_eos[0]

        if eos == 'bidirectional':
            relative_position[:, pos_eos] = +1e6
            relative_position[pos_eos, :] = -1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[:, pos_eos] = True
            sparsity_mask[pos_eos, :] = True
        elif eos == 'unidirectional':
            relative_position[:, pos_eos] = 1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[pos_eos, :] = False
            sparsity_mask[:, pos_eos] = True
        else:
            raise ValueError(f'{eos = } is not a valid option.')

    relative_position = relative_position.unsqueeze(0)
    sparsity_mask = sparsity_mask.unsqueeze(0)
    use_additional_bucket = use_additional_bucket.unsqueeze(0)
    return relative_position, sparsity_mask, use_additional_bucket


def _get_global_graphT5_relativeposition_sparsitymask(g: Graph, indices: dict, sequence_length: int, use_eos: bool, eos: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Global (gGLM) attention pattern: every token can attend to every other token. Graph-local pairs
    get the same relative positions as in the local version; all other pairs share one additional
    relative-position bucket.
    """
    relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
    # in the global version every token may attend to every other token
    sparsity_mask = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)
    # non-local pairs use the additional relative-position bucket
    use_additional_bucket = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)

    # tokens within the same node use standard relative positions
    for start, end in chain.from_iterable(indices.values()):
        relative_position[start:end, start:end] = _get_relative_position(end - start)
        use_additional_bucket[start:end, start:end] = False

    # connect head, relation, and tail tokens of every triplet
    relation_counter = {relation: 0 for relation in g.relations}
    for triplet in g.g:
        pos_h = indices[triplet[0]][0]
        pos_r = indices[triplet[1]][relation_counter[triplet[1]]]
        pos_t = indices[triplet[2]][0]

        l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0]

        for ih, ph in enumerate(range(pos_h[0], pos_h[1])):
            for ir, pr in enumerate(range(pos_r[0], pos_r[1])):
                relative_position[ph, pr] = l_h - ih + ir
                relative_position[pr, ph] = - (l_h - ih + ir)
                use_additional_bucket[ph, pr] = False
                use_additional_bucket[pr, ph] = False
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):
                relative_position[ph, pt] = l_h - ih + l_r + it
                relative_position[pt, ph] = - (l_h - ih + l_r + it)
                use_additional_bucket[ph, pt] = False
                use_additional_bucket[pt, ph] = False
        for ir, pr in enumerate(range(pos_r[0], pos_r[1])):
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):
                relative_position[pr, pt] = l_r - ir + it
                relative_position[pt, pr] = - (l_r - ir + it)
                use_additional_bucket[pr, pt] = False
                use_additional_bucket[pt, pr] = False

        relation_counter[triplet[1]] += 1

    if use_eos:
        assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
        pos_eos = indices['</s>'][0]
        assert pos_eos[0] + 1 == pos_eos[1], pos_eos
        pos_eos = pos_eos[0]

        if eos == 'bidirectional':
            relative_position[:, pos_eos] = +1e6
            relative_position[pos_eos, :] = -1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[:, pos_eos] = True
            sparsity_mask[pos_eos, :] = True
            use_additional_bucket[:, pos_eos] = False
            use_additional_bucket[pos_eos, :] = False
        elif eos == 'unidirectional':
            relative_position[:, pos_eos] = 1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[pos_eos, :] = False
            sparsity_mask[:, pos_eos] = True
            use_additional_bucket[:, pos_eos] = False
            use_additional_bucket[pos_eos, :] = False
        else:
            raise ValueError(f'{eos = } is not a valid option.')

    relative_position = relative_position.unsqueeze(0)
    sparsity_mask = sparsity_mask.unsqueeze(0)
    use_additional_bucket = use_additional_bucket.unsqueeze(0)
    return relative_position, sparsity_mask, use_additional_bucket


def graph_to_graphT5(g: Graph, tokenizer: T5Tokenizer, how: str, eos: str) -> Data:
    """
    Convert a graph to a graphT5 input.
    :param g: graph
    :param tokenizer: tokenizer
    :param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
    :param eos: end-of-sequence token. Can be `False` to not use an eos token. When using an eos token, there are two options: `bidirectional` means that the eos token is connected to every other node in the graph, with a relative position of positive infinity (from node to eos) or negative infinity (from eos to node). `unidirectional` means that every node is connected to the eos token (from node to eos) with a relative position of positive infinity, but not the other way around (i.e. there is no connection from eos to the other nodes). This means that nodes do not receive messages from the eos token, which preserves locality when using the local GLM.
    """
    if not isinstance(g, Graph):
        g = Graph(g)
    eos = str(eos)
    assert eos in ['False', 'bidirectional', 'unidirectional'], f"{eos = } must be either 'False', 'bidirectional', or 'unidirectional'"
    use_eos: bool = eos != 'False'

    str2tok = _get_str2tok(g, tokenizer)

    sequence, indices, is_concept, concept_indices = _get_graphT5_input_sequence(g, str2tok, use_eos)
    sequence_length = sequence.shape[1]

    if how == 'local':
        relative_position, sparsity_mask, use_additional_bucket = _get_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
        num_additional_buckets = 0
    elif how == 'global':
        relative_position, sparsity_mask, use_additional_bucket = _get_global_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
        num_additional_buckets = 1
    else:
        raise ValueError(f"how must be either 'local' or 'global', but is {how}")

    input_ids = sequence

    data = Data(input_ids=input_ids, relative_position=relative_position, sparsity_mask=sparsity_mask, use_additional_bucket=use_additional_bucket, indices=indices, is_concept=is_concept, concept_indices=concept_indices, num_additional_buckets=num_additional_buckets)

    return data
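

# Minimal sketch of how a graph is converted into model inputs. "t5-small" is only an example
# checkpoint; any T5 tokenizer compatible with the GLM checkpoints should behave the same way.
def _example_graph_to_graphT5() -> None:
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    g = get_dummy_graph(num_triplets=3)
    data = graph_to_graphT5(g, tokenizer, how="global", eos="False")
    seq_len = data.input_ids.shape[1]
    # relative positions and masks are pairwise, hence (1, seq_len, seq_len)
    assert data.relative_position.shape == (1, seq_len, seq_len)
    assert data.sparsity_mask.shape == (1, seq_len, seq_len)
    assert data.use_additional_bucket.shape == (1, seq_len, seq_len)
    assert data.num_additional_buckets == 1  # the gGLM uses one extra bucket for non-local pairs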
|
|
|
@cache
def _get_relative_position(size):
    return torch.tensor([[i - j for i in range(size)] for j in range(size)], dtype=torch.long)
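

# Worked example of the cached helper: for a node that spans three tokens, the entry at
# [query, key] is key - query, i.e. how far the key token lies to the right of the query token.
def _example_relative_position() -> None:
    expected = torch.tensor([[0, 1, 2],
                             [-1, 0, 1],
                             [-2, -1, 0]], dtype=torch.long)
    assert torch.equal(_get_relative_position(3), expected)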
|
|
|
def get_embedding(
    sequence_embedding: torch.Tensor,
    indices: Dict[str, List[Tuple[int, int]]],
    concept: str,
    embedding_aggregation: str = "mean",
):
    """
    Returns the embedding of a concept.
    :param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
    :param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts.
    :param concept: the concept for which the embedding should be returned
    :param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean of all tokens of the concept. "seq" returns the embeddings of all tokens of the concept.
    :return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
    """
    assert concept in indices.keys(), f"{concept = } is not a node in the graph. {indices = }"
    assert len(indices[concept]) == 1, f"{concept = } is not a concept, as concepts occur only once in the graph. {indices = }"

    start, end = indices[concept][0]
    sequence_embedding = sequence_embedding[start:end, :]
    if embedding_aggregation == "mean":
        return torch.mean(sequence_embedding, dim=0, keepdim=True)
    elif embedding_aggregation == "seq":
        return sequence_embedding
    else:
        raise NotImplementedError(f"{embedding_aggregation = } is not supported. Use either 'mean' or 'seq'.")
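

# Minimal sketch of the aggregation logic with a dummy encoder output (random numbers stand in for
# the per-token embeddings a model would produce). The spans below are made-up examples.
def _example_get_embedding() -> None:
    sequence_embedding = torch.randn(10, 8)  # (sequence_length, embedding_size)
    indices = {"dog": [(4, 6)], "IsA": [(0, 2)]}
    mean_emb = get_embedding(sequence_embedding, indices, "dog", "mean")
    seq_emb = get_embedding(sequence_embedding, indices, "dog", "seq")
    assert mean_emb.shape == (1, 8)
    assert seq_emb.shape == (2, 8)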
|
|
|
def add_text_to_graph_data(data, text, tokenizer, use_text):
    """
    Append the tokenized `text` to the graph inputs in `data` (modified in place).
    Does nothing if `use_text` is 'False', '', False, or None.
    """
    if use_text in {'False', '', False, None}:
        return None

    text_seq = torch.tensor(tokenizer(text, padding=False)['input_ids']).unsqueeze(0)
    new_input_ids = torch.cat([data.input_ids, text_seq], dim=1)

    old_seq_len = data.input_ids.shape[1]
    text_seq_len = text_seq.shape[1]
    new_seq_len = new_input_ids.shape[1]

    new_is_graph = torch.zeros(size=(1, new_seq_len), dtype=torch.bool)
    new_is_graph[:, :old_seq_len] = True

    if data.relative_position is None:
        # sequence-transformer case: no graph-specific positions or masks to extend
        assert data.sparsity_mask is None
        assert data.use_additional_bucket is None
        data.input_ids = new_input_ids
        data.is_graph = new_is_graph
        return None

    new_relative_position = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.relative_position.dtype)
    new_relative_position[:, :old_seq_len, :old_seq_len] = data.relative_position
    new_relative_position[:, old_seq_len:, old_seq_len:] = _get_relative_position(text_seq_len)

    new_sparsity_mask = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.sparsity_mask.dtype)
    new_sparsity_mask[:, :old_seq_len, :old_seq_len] = data.sparsity_mask
    new_sparsity_mask[:, old_seq_len:, old_seq_len:] = True

    new_use_additional_bucket = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.use_additional_bucket.dtype)
    new_use_additional_bucket[:, :old_seq_len, :old_seq_len] = data.use_additional_bucket
    new_use_additional_bucket[:, old_seq_len:, old_seq_len:] = False

    if use_text in {'FullyConnected', True}:
        # text and graph attend to each other through two dedicated relative-position buckets
        new_sparsity_mask[:, old_seq_len:, :old_seq_len] = True
        new_sparsity_mask[:, :old_seq_len, old_seq_len:] = True

        new_use_additional_bucket[:, old_seq_len:, :old_seq_len] = True
        new_use_additional_bucket[:, :old_seq_len, old_seq_len:] = True

        new_relative_position[:, old_seq_len:, :old_seq_len] = data.num_additional_buckets
        new_relative_position[:, :old_seq_len, old_seq_len:] = data.num_additional_buckets + 1

        new_num_additional_buckets = data.num_additional_buckets + 2
    else:
        raise ValueError(f"unknown use_text {use_text} (type {type(use_text)})")

    data.input_ids = new_input_ids
    data.relative_position = new_relative_position
    data.sparsity_mask = new_sparsity_mask
    data.use_additional_bucket = new_use_additional_bucket
    data.num_additional_buckets = new_num_additional_buckets
    data.is_graph = new_is_graph
    return None
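

# Minimal sketch of attaching text to an already-encoded graph. The question string and the
# "t5-small" checkpoint are made-up examples.
def _example_add_text() -> None:
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    data = graph_to_graphT5(get_dummy_graph(), tokenizer, how="global", eos="False")
    graph_len = data.input_ids.shape[1]
    add_text_to_graph_data(data, "Is a black poodle an animal?", tokenizer, use_text=True)
    # the first graph_len positions belong to the graph, the rest to the text
    assert data.input_ids.shape[1] > graph_len
    assert bool(data.is_graph[0, :graph_len].all())
    assert not bool(data.is_graph[0, graph_len:].any())
    assert data.num_additional_buckets == 3  # 1 gGLM bucket + 2 text/graph buckets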
|
|
|
class DataProcessor:
    @staticmethod
    def encode_graph(tokenizer, g: Union[Graph, list[tuple[str, str, str]]], text: Optional[str] = None, how: Literal['global', 'local'] = 'global', eos: str = "False") -> Data:
        """
        Convert a graph to a suitable input for the model.
        :param tokenizer: tokenizer
        :param g: graph
        :param text: text to add to the graph. Can be None if no text should be added.
        :param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
        :param eos: end-of-sequence token. Can be `False` to not use an eos token. This is the method used in the paper. When using an eos token, there are two options: `bidirectional` means that the eos token is connected to every other node in the graph. `unidirectional` means that every node is connected to the eos token (from node to eos), but not the other way around (i.e. there is no connection from eos to the other nodes). This means that nodes do not receive messages from the eos token, which preserves locality when using the local GLM.
        :return: Data object
        """
        if not isinstance(g, Graph):
            g = Graph(g)
        data = graph_to_graphT5(g, tokenizer, how, eos)
        if text is not None:
            add_text_to_graph_data(data, text, tokenizer, use_text=True)
        return data

    @staticmethod
    def to_batch(data_instances: list[Data], tokenizer, max_seq_len: Optional[int] = None, device: str = 'cpu', **kwargs) -> dict:
        """
        Converts a list of data instances to batched inputs for the GLM forward call.
        :param data_instances: list of Data instances
        :param max_seq_len: maximum sequence length
        :param tokenizer: tokenizer
        :param device: device
        :return: dictionary with keys 'input_ids', 'relative_position', 'sparsity_mask', and 'use_additional_bucket'
        """
        current_max_seq_len = max([data.input_ids.shape[1] for data in data_instances])
        if max_seq_len is None:
            max_seq_len = current_max_seq_len
        else:
            max_seq_len = min(max_seq_len, current_max_seq_len)

        if data_instances[0].relative_position is None:
            assert data_instances[0].sparsity_mask is None
            assert data_instances[0].use_additional_bucket is None
            is_sequence_transformer = True
        else:
            assert data_instances[0].sparsity_mask is not None
            assert data_instances[0].use_additional_bucket is not None
            is_sequence_transformer = False

        # pad all instances to the same length
        input_ids = torch.ones((len(data_instances), max_seq_len), dtype=torch.long, device=device) * tokenizer.pad_token_id
        if is_sequence_transformer:
            relative_position = None
            sparsity_mask = None
            use_additional_bucket = None
        else:
            relative_position = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.long, device=device)
            sparsity_mask = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
            use_additional_bucket = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)

        for i, data in enumerate(data_instances):
            instance_len = min(data.input_ids.shape[1], max_seq_len)
            input_ids[i, :instance_len] = data.input_ids[:, :instance_len]
            if not is_sequence_transformer:
                relative_position[i, :instance_len, :instance_len] = data.relative_position[:, :instance_len, :instance_len]
                sparsity_mask[i, :instance_len, :instance_len] = data.sparsity_mask[:, :instance_len, :instance_len]
                use_additional_bucket[i, :instance_len, :instance_len] = data.use_additional_bucket[:, :instance_len, :instance_len]

        model_input = {
            'input_ids': input_ids,
            'relative_position': relative_position,
            'sparsity_mask': sparsity_mask,
            'use_additional_bucket': use_additional_bucket,
            **kwargs
        }
        return model_input

    @staticmethod
    def get_embedding(sequence_embedding: torch.Tensor, indices: Dict[str, List[Tuple[int, int]]], concept: str, embedding_aggregation: str = "mean"):
        """
        Returns the embedding of a concept.
        :param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
        :param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts. indices is part of the Data object.
        :param concept: the concept for which the embedding should be returned.
        :param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean of all tokens of the concept. "seq" returns the embeddings of all tokens of the concept.
        :return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
        """
        return get_embedding(sequence_embedding, indices, concept, embedding_aggregation)
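

# Minimal end-to-end sketch: encode two graphs and batch them for a GLM forward call. The model
# itself is not part of this module; the commented line shows the assumed call signature, with the
# keys of `batch` matching the forward arguments.
def _example_batching() -> None:
    tokenizer = T5Tokenizer.from_pretrained("t5-small")  # example checkpoint
    graphs = [get_dummy_graph(2), get_dummy_graph(4)]
    instances = [DataProcessor.encode_graph(tokenizer, g, how="local", eos="False") for g in graphs]
    batch = DataProcessor.to_batch(instances, tokenizer, device="cpu")
    batch_size, max_len = batch['input_ids'].shape
    assert batch_size == 2
    assert batch['relative_position'].shape == (2, max_len, max_len)
    # outputs = model(**batch)  # assuming a GLM encoder that accepts these keyword arguments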
|
|
|
|