import logging

import torch
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

from gcn_lib.sparse.torch_vertex import GENConv
from gcn_lib.sparse.torch_nn import norm_layer, MLP, MM_AtomEncoder
from model.model_encoder import AtomEncoder, BondEncoder


class DeeperGCN(torch.nn.Module):
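    """DeeperGCN graph encoder built from GENConv layers.

    Supports the "res+", "res" and "plain" block orderings, an optional
    virtual node, configurable aggregation and normalization, and
    sum/mean/max graph pooling followed by a linear classification head.
    With ``is_prot=True`` the ``*_prot`` hyperparameters from ``args``
    are used instead of the default ones.
    """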
    def __init__(self, args, is_prot=False, saliency=False):
        super(DeeperGCN, self).__init__()
        # Protein-graph (PM) configuration
        if is_prot:
            self.num_layers = args.num_layers_prot
            mlp_layers = args.mlp_layers_prot
            hidden_channels = args.hidden_channels_prot
            self.msg_norm = args.msg_norm_prot
            learn_msg_scale = args.learn_msg_scale_prot
            self.conv_encode_edge = args.conv_encode_edge_prot
        # Ligand/molecule-graph (LM) configuration
        else:
            self.num_layers = args.num_layers
            mlp_layers = args.mlp_layers
            hidden_channels = args.hidden_channels
            self.msg_norm = args.msg_norm
            learn_msg_scale = args.learn_msg_scale
            self.conv_encode_edge = args.conv_encode_edge
        # Set overall model configuration
        self.dropout = args.dropout
        self.block = args.block
        self.add_virtual_node = args.add_virtual_node
        self.training = True
        self.args = args

        num_classes = args.nclasses
        conv = args.conv
        aggr = args.gcn_aggr
        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        norm = args.norm
        graph_pooling = args.graph_pooling
        # Print model parameters
        print(
            "Number of layers: {}".format(self.num_layers),
            "Aggregation method: {}".format(aggr),
            "Block: {}".format(self.block),
        )
        if self.block == "res+":
            print("LN/BN->ReLU->GraphConv->Res")
        elif self.block == "res":
            print("GraphConv->LN/BN->ReLU->Res")
        elif self.block == "dense":
            raise NotImplementedError("To be implemented")
        elif self.block == "plain":
            print("GraphConv->LN/BN->ReLU")
        else:
            raise Exception("Unknown block Type")
        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        if self.add_virtual_node:
            self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
            self.mlp_virtualnode_list = torch.nn.ModuleList()
            for layer in range(self.num_layers - 1):
                self.mlp_virtualnode_list.append(MLP([hidden_channels] * 3, norm=norm))
        # Set GCN layer configuration
        for layer in range(self.num_layers):
            if conv == "gen":
                gcn = GENConv(
                    hidden_channels,
                    hidden_channels,
                    args,
                    aggr=aggr,
                    t=t,
                    learn_t=self.learn_t,
                    p=p,
                    learn_p=self.learn_p,
                    msg_norm=self.msg_norm,
                    learn_msg_scale=learn_msg_scale,
                    encode_edge=self.conv_encode_edge,
                    bond_encoder=True,
                    norm=norm,
                    mlp_layers=mlp_layers,
                )
            else:
                raise Exception("Unknown Conv Type")
            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))
        # Set embedding layers
        if saliency:
            self.atom_encoder = MM_AtomEncoder(emb_dim=hidden_channels)
        else:
            self.atom_encoder = AtomEncoder(emb_dim=hidden_channels)
        if not self.conv_encode_edge:
            self.bond_encoder = BondEncoder(emb_dim=hidden_channels)
        # Set type of pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise Exception("Unknown Pool Type")

        # Set classification layer
        self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, input_batch, dropout=True, embeddings=False):
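        """Encode a batch of graphs.

        Args:
            input_batch: A torch_geometric ``Batch`` providing ``x``,
                ``edge_index``, ``edge_attr`` and ``batch``.
            dropout: If True, apply dropout inside the ``res+`` block
                (the other blocks always apply it).
            embeddings: If True, return pooled graph embeddings instead of
                class logits.

        Returns:
            Pooled graph embeddings when ``embeddings`` is True or
            ``args.use_prot`` is set; otherwise class logits from the
            linear prediction head.
        """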
        x = input_batch.x
        edge_index = input_batch.edge_index
        edge_attr = input_batch.edge_attr
        batch = input_batch.batch

        h = self.atom_encoder(x)

        if self.add_virtual_node:
            virtualnode_embedding = self.virtualnode_embedding(
                torch.zeros(batch[-1].item() + 1)
                .to(edge_index.dtype)
                .to(edge_index.device)
            )
            h = h + virtualnode_embedding[batch]

        if self.conv_encode_edge:
            edge_emb = edge_attr
        else:
            edge_emb = self.bond_encoder(edge_attr)
        if self.block == "res+":
            h = self.gcns[0](h, edge_index, edge_emb)
            for layer in range(1, self.num_layers):
                h1 = self.norms[layer - 1](h)
                h2 = F.relu(h1)
                if dropout:
                    h2 = F.dropout(h2, p=self.dropout, training=self.training)
                if self.add_virtual_node:
                    virtualnode_embedding_temp = (
                        global_add_pool(h2, batch) + virtualnode_embedding
                    )
                    if dropout:
                        virtualnode_embedding = F.dropout(
                            self.mlp_virtualnode_list[layer - 1](
                                virtualnode_embedding_temp
                            ),
                            self.dropout,
                            training=self.training,
                        )
                    h2 = h2 + virtualnode_embedding[batch]
                h = self.gcns[layer](h2, edge_index, edge_emb) + h
            h = self.norms[self.num_layers - 1](h)
            if dropout:
                h = F.dropout(h, p=self.dropout, training=self.training)
        elif self.block == "res":
            h = F.relu(self.norms[0](self.gcns[0](h, edge_index, edge_emb)))
            h = F.dropout(h, p=self.dropout, training=self.training)
            for layer in range(1, self.num_layers):
                h1 = self.gcns[layer](h, edge_index, edge_emb)
                h2 = self.norms[layer](h1)
                h = F.relu(h2) + h
                h = F.dropout(h, p=self.dropout, training=self.training)
        elif self.block == "dense":
            raise NotImplementedError("To be implemented")
        elif self.block == "plain":
            h = F.relu(self.norms[0](self.gcns[0](h, edge_index, edge_emb)))
            h = F.dropout(h, p=self.dropout, training=self.training)
            for layer in range(1, self.num_layers):
                h1 = self.gcns[layer](h, edge_index, edge_emb)
                h2 = self.norms[layer](h1)
                if layer != (self.num_layers - 1):
                    h = F.relu(h2)
                else:
                    h = h2
                h = F.dropout(h, p=self.dropout, training=self.training)
        else:
            raise Exception("Unknown block Type")

        h_graph = self.pool(h, batch)

        if self.args.use_prot or embeddings:
            return h_graph
        else:
            return self.graph_pred_linear(h_graph)

    def print_params(self, epoch=None, final=False):
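        """Log the learned t, p and message-norm scale of each GENConv layer."""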
        if self.learn_t:
            ts = []
            for gcn in self.gcns:
                ts.append(gcn.t.item())
            if final:
                print("Final t {}".format(ts))
            else:
                logging.info("Epoch {}, t {}".format(epoch, ts))
        if self.learn_p:
            ps = []
            for gcn in self.gcns:
                ps.append(gcn.p.item())
            if final:
                print("Final p {}".format(ps))
            else:
                logging.info("Epoch {}, p {}".format(epoch, ps))
        if self.msg_norm:
            ss = []
            for gcn in self.gcns:
                ss.append(gcn.msg_norm.msg_scale.item())
            if final:
                print("Final s {}".format(ss))
            else:
                logging.info("Epoch {}, s {}".format(epoch, ss))
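

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes
    # OGB-style molecular graphs (9 categorical atom features, 3 categorical
    # bond features) and only sets the args fields read in this file; GENConv
    # in gcn_lib may read additional fields, and the chosen values ("batch"
    # norm, "mean" aggregation, etc.) follow common DeeperGCN conventions and
    # may need adjusting for a specific setup.
    from argparse import Namespace

    from torch_geometric.data import Batch, Data

    args = Namespace(
        num_layers=3, mlp_layers=1, hidden_channels=64,
        msg_norm=False, learn_msg_scale=False, conv_encode_edge=False,
        dropout=0.2, block="res+", add_virtual_node=False,
        nclasses=2, conv="gen", gcn_aggr="mean",
        t=1.0, learn_t=False, p=1.0, learn_p=False,
        norm="batch", graph_pooling="mean", use_prot=False,
    )
    model = DeeperGCN(args)

    # A single toy graph with 4 atoms and 3 bonds (both edge directions).
    data = Data(
        x=torch.zeros((4, 9), dtype=torch.long),
        edge_index=torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]),
        edge_attr=torch.zeros((6, 3), dtype=torch.long),
    )
    batch = Batch.from_data_list([data])

    model.eval()
    logits = model(batch)                       # shape [1, nclasses]
    graph_emb = model(batch, embeddings=True)   # shape [1, hidden_channels]
    print(logits.shape, graph_emb.shape)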