import torch
import torch.nn as nn
import torch.nn.functional as F
from gcn_lib.sparse.torch_nn import MLP
from model.model import DeeperGCN
import numpy as np
import logging


class PLANet(torch.nn.Module):
    """Ligand/protein model: two DeeperGCN encoders fused into a classifier.

    A molecule graph and a target (protein) graph are each encoded by their
    own ``DeeperGCN``. The two embeddings are combined in one of three ways,
    selected by ``args``:

    * ``args.multi_concat``: learned element-wise weighted sum of the two
      embeddings (protein weights start at 0, ligand weights at 1);
    * ``args.MLP``: concatenation followed by a small ``MLP``;
    * otherwise: concatenation followed by a single ``nn.Linear``.

    In both concatenation variants the first linear layer starts as an
    identity on the leading (molecule) features with zero bias, so at
    initialisation the protein half of the concatenation is ignored.
    A final ``nn.Linear`` maps the fused embedding to ``args.nclasses``
    outputs.
    """

    # Width of the hidden layers of the optional fusion MLP.
    _MLP_HIDDEN_CHANNELS = 64

    def __init__(self, args, saliency=False):
        """Build the two encoders, the fusion layer and the classifier.

        Args:
            args: Config namespace. Read here: ``hidden_channels``,
                ``hidden_channels_prot``, ``multi_concat``, ``MLP``,
                ``norm`` (MLP variant only) and ``nclasses``. It is also
                passed to both ``DeeperGCN`` encoders.
            saliency: Forwarded to the molecule encoder only.
        """
        super(PLANet, self).__init__()
        self.args = args

        # Molecule (ligand) and protein graph encoders.
        self.molecule_gcn = DeeperGCN(args, saliency=saliency)
        self.target_gcn = DeeperGCN(args, is_prot=True)

        # Per-encoder embedding sizes and the width of their concatenation.
        output_molecule = args.hidden_channels
        output_protein = args.hidden_channels_prot
        final_output = output_molecule + output_protein
        # Width of the fused embedding fed to the classifier.
        hidden_channels = args.hidden_channels

        if args.multi_concat:
            # NOTE(review): forward() adds protein and molecule features
            # element-wise, which requires hidden_channels_prot to equal
            # hidden_channels.
            self.multiplier_prot = nn.Parameter(torch.zeros(hidden_channels))
            self.multiplier_ligand = nn.Parameter(torch.ones(hidden_channels))
        elif self.args.MLP:
            mlp_hidden = self._MLP_HIDDEN_CHANNELS
            # The input width must equal the concatenated embedding, and the
            # output width must equal the classifier's input.
            channels_concat = [final_output, mlp_hidden, mlp_hidden, hidden_channels]
            self.concatenation_gcn = MLP(channels_concat, norm=args.norm, last_lin=True)
            self._init_passthrough(self.concatenation_gcn[0])
        else:
            # Single linear fusion layer over the concatenated embeddings.
            self.concatenation_gcn = nn.Linear(final_output, hidden_channels)
            self._init_passthrough(self.concatenation_gcn)

        self.classification = nn.Linear(hidden_channels, args.nclasses)

    @staticmethod
    def _init_passthrough(layer):
        """Make ``layer`` copy its leading inputs: identity weight, zero bias.

        ``torch.eye(out, in)`` puts ones at ``weight[i, i]`` for
        ``i < min(out, in)`` and zeros elsewhere. Because forward()
        concatenates molecule features first, only the leading molecule
        features pass through at initialisation.

        Args:
            layer: A linear layer with a bias, whose ``weight`` has shape
                ``(out_features, in_features)``.
        """
        out_features, in_features = layer.weight.shape
        layer.weight = nn.Parameter(torch.eye(out_features, in_features))
        layer.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, molecule, target):
        """Encode both graphs, fuse the embeddings and classify.

        Args:
            molecule: Input for ``self.molecule_gcn``.
            target: Input for ``self.target_gcn``.

        Returns:
            Output of ``self.classification``; last dimension ``args.nclasses``.
        """
        molecule_features = self.molecule_gcn(molecule)
        target_features = self.target_gcn(target)

        if self.args.multi_concat:
            fused = (
                target_features * self.multiplier_prot
                + molecule_features * self.multiplier_ligand
            )
        else:
            # Molecule features must come first: the pass-through init of
            # the first fusion layer relies on this order.
            # NOTE(review): assumes encoder outputs are (batch, features).
            fused = torch.cat((molecule_features, target_features), dim=1)
            fused = self.concatenation_gcn(fused)

        return self.classification(fused)

    def print_params(self, epoch=None, final=False):
        """Log both encoders' parameter summaries and, if used, the multipliers.

        Args:
            epoch: Forwarded to each encoder's ``print_params``.
            final: Not used by this method.
        """
        logging.info("======= Molecule GCN ========")
        self.molecule_gcn.print_params(epoch)
        logging.info("======= Protein GCN ========")
        self.target_gcn.print_params(epoch)
        if self.args.multi_concat:
            # A single tensor reduction yields a plain float. The builtin
            # sum() would build one autograd node per element and log a
            # tensor repr instead.
            sum_prot_multi = self.multiplier_prot.detach().sum().item()
            sum_lig_multi = self.multiplier_ligand.detach().sum().item()
            logging.info("Sumed prot multi: {}".format(sum_prot_multi))
            logging.info("Sumed lig multi: {}".format(sum_lig_multi))