# PLA-Net: model/model_concatenation.py
# Author: juliocesar-io, commit b6f1234 ("Added initial app").
import torch
import torch.nn as nn
import torch.nn.functional as F
from gcn_lib.sparse.torch_nn import MLP
from model.model import DeeperGCN
import numpy as np
import logging
class PLANet(torch.nn.Module):
    """Protein-ligand classifier built from two DeeperGCN encoders.

    A ligand (molecule) encoder and a protein (target) encoder each produce an
    embedding. The two embeddings are fused by one of three strategies chosen
    through ``args``, and the fused vector is fed to a linear classifier:

    * ``args.multi_concat``: learned element-wise weighted sum
      ``target * multiplier_prot + molecule * multiplier_ligand``. The
      multipliers start at 0 (protein) and 1 (ligand), so at initialisation
      the fused vector equals the molecule embedding.
    * ``args.MLP``: concatenation ``[molecule, target]`` followed by a small
      MLP whose first linear layer is initialised as a pass-through of the
      leading concatenated features.
    * otherwise: concatenation followed by one linear layer initialised so it
      outputs the molecule embedding (the first ``hidden_channels``
      concatenated features) unchanged.

    Args:
        args: configuration namespace. Attributes read directly here are
            ``hidden_channels``, ``hidden_channels_prot``, ``multi_concat``,
            ``MLP``, ``norm`` and ``nclasses``. The whole object is also
            passed to ``DeeperGCN``.
        saliency: forwarded to the molecule ``DeeperGCN`` only.
    """

    # Width of the two hidden layers of the optional fusion MLP (args.MLP).
    MLP_HIDDEN_CHANNELS = 64

    def __init__(self, args, saliency=False):
        super().__init__()
        self.args = args

        # Ligand (molecule) and protein (target) encoders.
        self.molecule_gcn = DeeperGCN(args, saliency=saliency)
        self.target_gcn = DeeperGCN(args, is_prot=True)

        # Per-encoder embedding sizes and the size of their concatenation.
        output_molecule = args.hidden_channels
        output_protein = args.hidden_channels_prot
        final_output = output_molecule + output_protein
        # Size of the fused embedding that feeds the classifier.
        hidden_channels = args.hidden_channels

        if args.multi_concat:
            # Learned element-wise weights. Starting at (0, 1) makes the
            # initial fused vector equal to the molecule embedding.
            # NOTE(review): forward() multiplies the target features by a
            # vector of length hidden_channels. This mode therefore presumably
            # requires hidden_channels_prot == hidden_channels; confirm.
            self.multiplier_prot = nn.Parameter(torch.zeros(hidden_channels))
            self.multiplier_ligand = nn.Parameter(torch.ones(hidden_channels))
        elif args.MLP:
            # Both ends are derived from the encoders instead of hard-coded
            # (previously 256 in / 128 out). The MLP then always consumes the
            # full concatenation and emits exactly the width the classifier
            # expects.
            mlp_hidden = self.MLP_HIDDEN_CHANNELS
            channels_concat = [final_output, mlp_hidden, mlp_hidden, hidden_channels]
            self.concatenation_gcn = MLP(
                channels_concat, norm=args.norm, last_lin=True
            )
            # Assumes element 0 of gcn_lib's MLP is its first linear layer,
            # as the indexing here implies.
            self._init_as_passthrough(self.concatenation_gcn[0])
        else:
            # Single linear fusion layer. It initially forwards the molecule
            # embedding unchanged and ignores the protein part.
            self.concatenation_gcn = nn.Linear(final_output, hidden_channels)
            self._init_as_passthrough(self.concatenation_gcn)

        # Classification head on top of the fused embedding.
        self.classification = nn.Linear(hidden_channels, args.nclasses)

    @staticmethod
    def _init_as_passthrough(linear):
        """Replace a linear layer's weight with a rectangular identity and zero its bias.

        After this call, output feature ``i`` equals input feature ``i`` for
        ``i < min(out_features, in_features)``. The remaining inputs are ignored.
        """
        out_features, in_features = linear.weight.shape
        linear.weight = nn.Parameter(torch.eye(out_features, in_features))
        linear.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, molecule, target):
        """Encode both inputs, fuse the embeddings and classify.

        Args:
            molecule: input for the molecule encoder (presumably a batched
                graph; confirm against the data loader).
            target: input for the protein encoder.

        Returns:
            The raw output of ``self.classification``. No softmax or sigmoid
            is applied here.
        """
        molecule_features = self.molecule_gcn(molecule)
        target_features = self.target_gcn(target)

        if self.args.multi_concat:
            all_features = (
                target_features * self.multiplier_prot
                + molecule_features * self.multiplier_ligand
            )
        else:
            # Molecule first, so pass-through initialisation keeps the
            # molecule embedding.
            all_features = torch.cat((molecule_features, target_features), dim=1)
            all_features = self.concatenation_gcn(all_features)

        return self.classification(all_features)

    def print_params(self, epoch=None, final=False):
        """Log both encoders' parameters and, in multi_concat mode, the multiplier sums.

        Args:
            epoch: forwarded to each encoder's ``print_params``.
            final: unused; kept for interface compatibility.
        """
        logging.info("======= Molecule GCN ========")
        self.molecule_gcn.print_params(epoch)
        logging.info("======= Protein GCN ========")
        self.target_gcn.print_params(epoch)
        if self.args.multi_concat:
            # Tensor.sum().item() reduces in a single tensor op and yields a
            # plain float. The builtin sum() looped element by element in
            # Python and logged a tensor repr that included its grad_fn.
            sum_prot_multi = self.multiplier_prot.sum().item()
            sum_lig_multi = self.multiplier_ligand.sum().item()
            logging.info("Sumed prot multi: {}".format(sum_prot_multi))
            logging.info("Sumed lig multi: {}".format(sum_lig_multi))