import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gcn_lib.sparse.torch_nn import MLP
from model.model import DeeperGCN


class PLANet(torch.nn.Module):
    """Two-branch model: a ligand DeeperGCN and a protein DeeperGCN whose
    embeddings are fused (learnable multipliers, an MLP head, or a linear
    concatenation layer) and passed to a linear classifier."""

    def __init__(self, args, saliency=False):
        super(PLANet, self).__init__()
        # Args
        self.args = args
        # Molecule and protein networks
        self.molecule_gcn = DeeperGCN(args, saliency=saliency)
        self.target_gcn = DeeperGCN(args, is_prot=True)
        # Individual modules' final embedding sizes
        output_molecule = args.hidden_channels
        output_protein = args.hidden_channels_prot
        # Concatenated embedding size
        Final_output = output_molecule + output_protein
        # Overall model's final embedding size
        hidden_channels = args.hidden_channels
        # Learnable per-channel multipliers: the protein branch starts at zero
        # and the ligand branch at one, so the fused embedding initially equals
        # the ligand embedding.
        if args.multi_concat:
            self.multiplier_prot = torch.nn.Parameter(torch.zeros(hidden_channels))
            self.multiplier_ligand = torch.nn.Parameter(torch.ones(hidden_channels))
        elif self.args.MLP:
            # MLP fusion head with channel sizes [256, 64, 64, 128]
            hidden_channel = 64
            channels_concat = [256, hidden_channel, hidden_channel, 128]
            self.concatenation_gcn = MLP(channels_concat, norm=args.norm, last_lin=True)
            # Build a (hidden_channel x Final_output) weight with ones on the
            # diagonal and a zero bias, and assign it to the MLP's first layer,
            # so the leading hidden_channel inputs initially pass through unchanged.
            indices = np.diag_indices(hidden_channel)
            tensor_linear_layer = torch.zeros(hidden_channel, Final_output)
            tensor_linear_layer[indices[0], indices[1]] = 1
            self.concatenation_gcn[0].weight = torch.nn.Parameter(tensor_linear_layer)
            self.concatenation_gcn[0].bias = torch.nn.Parameter(
                torch.zeros(hidden_channel)
            )
        else:
            # Concatenation layer: linear projection from the concatenated
            # (ligand + protein) embedding back to hidden_channels. The weight
            # is initialized as identity on the ligand block and zeros on the
            # protein block, with a zero bias, so the fused embedding initially
            # equals the ligand embedding.
            self.concatenation_gcn = nn.Linear(Final_output, hidden_channels)
            indices = np.diag_indices(output_molecule)
            tensor_linear_layer = torch.zeros(hidden_channels, Final_output)
            tensor_linear_layer[indices[0], indices[1]] = 1
            self.concatenation_gcn.weight = torch.nn.Parameter(tensor_linear_layer)
            self.concatenation_gcn.bias = torch.nn.Parameter(
                torch.zeros(hidden_channels)
            )
        # Classification layer
        num_classes = args.nclasses
        self.classification = nn.Linear(hidden_channels, num_classes)
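
    # Sanity-check sketch for the concatenation-layer initialization above
    # (illustrative only, with made-up sizes): at init the fused embedding
    # equals the ligand embedding, because the weight is identity on the
    # ligand block and zero on the protein block.
    #
    #     mol, prot = torch.randn(2, 128), torch.randn(2, 64)
    #     fuse = nn.Linear(128 + 64, 128)
    #     w = torch.zeros(128, 128 + 64)
    #     w[range(128), range(128)] = 1
    #     fuse.weight = nn.Parameter(w)
    #     fuse.bias = nn.Parameter(torch.zeros(128))
    #     x = torch.cat((mol, prot), dim=1)
    #     assert torch.allclose(fuse(x), mol)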

    def forward(self, molecule, target):
        # Per-branch graph embeddings
        molecule_features = self.molecule_gcn(molecule)
        target_features = self.target_gcn(target)
        if self.args.multi_concat:
            # Weighted sum of the two embeddings with learnable multipliers
            All_features = (
                target_features * self.multiplier_prot
                + molecule_features * self.multiplier_ligand
            )
        else:
            # Concatenate ligand and protein embeddings, then fuse them with
            # the concatenation layer (linear or MLP)
            All_features = torch.cat((molecule_features, target_features), dim=1)
            All_features = self.concatenation_gcn(All_features)
        # Classification
        classification = self.classification(All_features)
        return classification

    def print_params(self, epoch=None, final=False):
        logging.info("======= Molecule GCN =======")
        self.molecule_gcn.print_params(epoch)
        logging.info("======= Protein GCN =======")
        self.target_gcn.print_params(epoch)
        if self.args.multi_concat:
            # Log the summed multipliers to track each branch's contribution
            sum_prot_multi = self.multiplier_prot.sum().item()
            sum_lig_multi = self.multiplier_ligand.sum().item()
            logging.info("Summed prot multi: {}".format(sum_prot_multi))
            logging.info("Summed lig multi: {}".format(sum_lig_multi))