File size: 8,945 Bytes
89650c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
import torch
from torch import nn
from . import tasks, layers
from ultra.base_nbfnet import BaseNBFNet
class Ultra(nn.Module):
    """Two-stage model: a relation-graph encoder followed by an entity-level scorer.

    ``relation_model`` (a ``RelNBFNet``) runs on ``data.relation_graph`` and yields
    relation representations for each query relation in the batch;
    ``entity_model`` (an ``EntityNBFNet``) consumes them to score the triples.
    """

    def __init__(self, rel_model_cfg, entity_model_cfg):
        """Instantiate both sub-models from their configs.

        Args:
            rel_model_cfg: mapping unpacked as keyword arguments of ``RelNBFNet``.
            entity_model_cfg: mapping unpacked as keyword arguments of
                ``EntityNBFNet``.
        """
        super().__init__()
        # Relation encoder is registered first, then the entity scorer.
        self.relation_model = RelNBFNet(**rel_model_cfg)
        self.entity_model = EntityNBFNet(**entity_model_cfg)

    def forward(self, data, batch):
        """Score a batch of triples.

        Args:
            data: graph object; ``data.relation_graph`` is handed to the relation
                model and ``data`` itself to the entity model.
            batch: tensor of shape (bs, 1 + num_negs, 3); index 2 of the last
                axis is read as the relation id.

        Returns:
            Whatever ``entity_model`` returns for this batch (the scores).
        """
        # All positive and negative triples of a sample share the same relation,
        # so it is enough to read it from the first triple of each row.
        rel_ids = batch[:, 0, 2]
        rel_repr = self.relation_model(data.relation_graph, query=rel_ids)
        return self.entity_model(data, rel_repr, batch)
# NBFNet to work on the graph of relations with 4 fundamental interactions
# Doesn't have the final projection MLP from hidden dim -> 1, returns all node representations
# of shape [bs, num_rel, hidden]
class RelNBFNet(BaseNBFNet):
    """NBFNet run on the graph of relations (nodes of that graph are relations).

    There is no final ``hidden -> 1`` scoring head here: ``forward`` returns the
    representations of all nodes, shape ``[bs, num_rel, hidden]``.
    """

    def __init__(self, input_dim, hidden_dims, num_relation=4, **kwargs):
        """Build the message-passing stack and, if requested, the fusion MLP.

        Args:
            input_dim: size of the initial node states; forwarded to ``BaseNBFNet``.
            hidden_dims: per-layer output sizes; forwarded to ``BaseNBFNet``.
            num_relation: number of edge types in the relation graph (default 4).
            **kwargs: forwarded unchanged to ``BaseNBFNet.__init__``.
        """
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        # NOTE(review): self.dims is provided by BaseNBFNet; presumably
        # [input_dim, *hidden_dims] - confirm against the base class.
        self.layers = nn.ModuleList()
        for dim_in, dim_out in zip(self.dims[:-1], self.dims[1:]):
            self.layers.append(
                layers.GeneralizedRelationalConv(
                    dim_in, dim_out, num_relation,
                    self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                    self.activation, dependent=False,
                )
            )

        # The fusion MLP only exists when all layer outputs are concatenated;
        # it maps the concatenation back to ``input_dim``.
        if self.concat_hidden:
            feature_dim = sum(hidden_dims) + input_dim
            self.mlp = nn.Sequential(
                nn.Linear(feature_dim, feature_dim),
                nn.ReLU(),
                nn.Linear(feature_dim, input_dim),
            )

    def bellmanford(self, data, h_index, separate_grad=False):
        """Propagate indicator states from the query relations through all layers.

        Args:
            data: relation graph exposing ``num_nodes``, ``num_edges``,
                ``edge_index`` and ``edge_type``.
            h_index: 1-D integer tensor with one source node (query relation) id
                per batch element.
            separate_grad: if True, each layer gets its own clone of the edge
                weights with ``requires_grad`` enabled, mirroring
                ``EntityNBFNet.bellmanford``. Previously this flag was accepted
                but ignored; the default (False) behaves exactly as before.

        Returns:
            dict with
              * ``"node_feature"``: tensor (batch_size, num_nodes, dim) - the MLP
                output when ``concat_hidden`` is set, otherwise the last layer's
                hidden states;
              * ``"edge_weights"``: list with the edge-weight tensor used by each
                layer.
        """
        batch_size = len(h_index)
        device = h_index.device

        # initialize initial nodes (relations of interest in the batch) with all ones
        query = torch.ones(batch_size, self.dims[0], device=device, dtype=torch.float)
        index = h_index.unsqueeze(-1).expand_as(query)

        # initial (boundary) condition - initialize all node states as zeros
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=device)
        # Indicator function: by the scatter operation we put ones as init features of source (index) nodes
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            if separate_grad:
                # give every layer an independent leaf so gradients can be read per layer
                edge_weight = edge_weight.clone().requires_grad_()
            # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == layer_input.shape:
                # residual connection, only possible when the layer keeps the dimension
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        if self.concat_hidden:
            # append the original query features to every node before fusing
            node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)  # (batch_size, num_nodes, input_dim)
            output = self.mlp(torch.cat(hiddens + [node_query], dim=-1))
        else:
            output = hiddens[-1]

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, rel_graph, query):
        """Return node (i.e. relation) representations for each query relation.

        Args:
            rel_graph: the relation graph passed to ``bellmanford`` as ``data``.
            query: 1-D tensor of query relation ids, used as source nodes.

        Returns:
            Tensor of shape (batch_size, num_nodes, hidden_dim).
        """
        # message passing and updated node representations (that are in fact relations)
        return self.bellmanford(rel_graph, h_index=query)["node_feature"]
class EntityNBFNet(BaseNBFNet):
    """NBFNet over the entity graph, conditioned on externally supplied relation states.

    Relation representations are not learned here: ``forward`` receives them and
    installs them both as the query table and as ``layer.relation`` of every
    message-passing layer. A final MLP maps node features to one logit.
    """

    def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
        """Build the message-passing stack and the scoring MLP.

        Args:
            input_dim: size of the initial node states; forwarded to ``BaseNBFNet``.
            hidden_dims: per-layer output sizes; forwarded to ``BaseNBFNet``.
            num_relation: dummy value (default 1); it is passed to the layers but,
                per the original author's note, not used by them.
            **kwargs: forwarded unchanged to ``BaseNBFNet.__init__``.
        """
        # dummy num_relation = 1 as we won't use it in the NBFNet layer
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        # NOTE(review): self.dims is provided by BaseNBFNet; presumably
        # [input_dim, *hidden_dims] - confirm against the base class.
        self.layers = nn.ModuleList()
        for dim_in, dim_out in zip(self.dims[:-1], self.dims[1:]):
            self.layers.append(
                layers.GeneralizedRelationalConv(
                    dim_in, dim_out, num_relation,
                    self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                    self.activation, dependent=False, project_relations=True,
                )
            )

        # Node features fed to the MLP are the hidden state(s) concatenated with
        # the query embedding, hence the extra ``input_dim``.
        feature_dim = (sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]) + input_dim
        mlp = []
        for _ in range(self.num_mlp_layers - 1):
            mlp.append(nn.Linear(feature_dim, feature_dim))
            mlp.append(nn.ReLU())
        mlp.append(nn.Linear(feature_dim, 1))
        self.mlp = nn.Sequential(*mlp)

    def bellmanford(self, data, h_index, r_index, separate_grad=False):
        """Propagate query embeddings from the head nodes through all layers.

        ``self.query`` must already be set (``forward`` does this); it is indexed
        as ``self.query[batch_element, relation_id]``.

        Args:
            data: entity graph exposing ``num_nodes``, ``num_edges``,
                ``edge_index`` and ``edge_type``.
            h_index: 1-D integer tensor, one head (source) node id per batch element.
            r_index: 1-D integer tensor, one query relation id per batch element.
            separate_grad: if True, each layer gets its own clone of the edge
                weights with ``requires_grad`` enabled (used for visualization).

        Returns:
            dict with
              * ``"node_feature"``: tensor (batch_size, num_nodes, feature_dim) -
                hidden states (all layers if ``concat_hidden``, else the last
                one) concatenated with the query embedding;
              * ``"edge_weights"``: list with the edge-weight tensor used by each
                layer.
        """
        batch_size = len(r_index)
        device = h_index.device

        # initialize queries (relation types of the given triples)
        query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
        index = h_index.unsqueeze(-1).expand_as(query)

        # initial (boundary) condition - initialize all node states as zeros
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=device)
        # by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            # for visualization
            if separate_grad:
                edge_weight = edge_weight.clone().requires_grad_()
            # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == layer_input.shape:
                # residual connection, only possible when the layer keeps the dimension
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # original query (relation type) embeddings
        node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)  # (batch_size, num_nodes, input_dim)
        states = hiddens if self.concat_hidden else [hiddens[-1]]
        output = torch.cat(states + [node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, relation_representations, batch):
        """Score every (head, tail, relation) triple of the batch.

        Side effects: stores ``relation_representations`` in ``self.query`` and in
        ``layer.relation`` of every layer.

        Args:
            data: entity graph; must also expose ``num_relations``.
            relation_representations: relation states used as queries and as the
                per-layer relation table (indexed by batch element, then
                relation id).
            batch: integer tensor whose last axis holds (head, tail, relation).

        Returns:
            Tensor of logits with the shape of ``batch`` minus its last axis.
        """
        h_index, t_index, r_index = batch.unbind(-1)

        # initial query representations are those from the relation graph
        self.query = relation_representations

        # initialize relations in each NBFNet layer (with unique projection internally)
        for layer in self.layers:
            layer.relation = relation_representations

        if self.training:
            # Edge dropout in the training mode
            # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
            # to make NBFNet iteration learn non-trivial paths
            data = self.remove_easy_edges(data, h_index, t_index, r_index)

        shape = h_index.shape
        # turn all triples in a batch into a tail prediction mode
        # NOTE(review): num_relations // 2 presumably excludes inverse relations - confirm
        h_index, t_index, r_index = self.negative_sample_to_tail(
            h_index, t_index, r_index, num_direct_rel=data.num_relations // 2
        )
        # after the conversion every row must share one head and one relation,
        # because bellmanford is run once per row from column 0
        assert (h_index[:, [0]] == h_index).all(), "all triples in a row must share the same head"
        assert (r_index[:, [0]] == r_index).all(), "all triples in a row must share the same relation"

        # message passing and updated node representations
        output = self.bellmanford(data, h_index[:, 0], r_index[:, 0])  # (batch_size, num_nodes, feature_dim)
        feature = output["node_feature"]
        index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        # extract representations of tail entities from the updated node states
        feature = feature.gather(1, index)  # (batch_size, num_negative + 1, feature_dim)

        # probability logit for each tail node in the batch
        # (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
        score = self.mlp(feature).squeeze(-1)
        return score.view(shape)
|