File size: 3,703 Bytes
b6f1234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn as nn
import torch.nn.functional as F

from gcn_lib.sparse.torch_nn import MLP

from model.model import DeeperGCN

import numpy as np
import logging


class PLANet(torch.nn.Module):
    """Protein-ligand network built from a molecule GCN and a protein GCN.

    Two ``DeeperGCN`` encoders embed the ligand (molecule) and the target
    (protein) separately. Their embeddings are fused and passed to a linear
    classifier with ``args.nclasses`` outputs. The fusion depends on ``args``:

    * ``args.multi_concat``: learned per-channel multipliers, summed elementwise.
    * ``args.MLP``: concatenation followed by a small MLP.
    * otherwise: concatenation followed by one ``nn.Linear`` layer.

    Args:
        args: Configuration namespace. Attributes read here are
            ``hidden_channels``, ``hidden_channels_prot``, ``multi_concat``,
            ``MLP``, ``norm`` and ``nclasses``. ``DeeperGCN`` may read more.
        saliency: Forwarded to the molecule ``DeeperGCN`` only.
    """

    def __init__(self, args, saliency=False):
        super().__init__()

        self.args = args
        # Ligand (molecule) and target (protein) encoders.
        self.molecule_gcn = DeeperGCN(args, saliency=saliency)
        self.target_gcn = DeeperGCN(args, is_prot=True)

        # Embedding width of each encoder and of their concatenation.
        output_molecule = args.hidden_channels
        output_protein = args.hidden_channels_prot
        final_output = output_molecule + output_protein
        # Width of the fused embedding that feeds the classifier.
        hidden_channels = args.hidden_channels

        if args.multi_concat:
            # Initialised so the fused embedding starts out equal to the
            # molecule embedding (protein weights 0, ligand weights 1).
            # forward() adds the two weighted embeddings elementwise, so this
            # presumably requires hidden_channels_prot == hidden_channels.
            self.multiplier_prot = torch.nn.Parameter(torch.zeros(hidden_channels))
            self.multiplier_ligand = torch.nn.Parameter(torch.ones(hidden_channels))
        elif args.MLP:
            hidden_channel = 64
            # Input width must equal the concatenated embedding width, and
            # output width must match the classifier's input width.
            channels_concat = [
                final_output,
                hidden_channel,
                hidden_channel,
                hidden_channels,
            ]
            self.concatenation_gcn = MLP(channels_concat, norm=args.norm, last_lin=True)
            # The first layer starts as a pass-through of the first
            # `hidden_channel` concatenated inputs (the leading molecule
            # channels when hidden_channels >= 64).
            self._identity_init(self.concatenation_gcn[0], hidden_channel)
        else:
            self.concatenation_gcn = nn.Linear(final_output, hidden_channels)
            # Start as a pass-through of the molecule embedding; the weights
            # on the protein columns begin at zero.
            self._identity_init(self.concatenation_gcn, output_molecule)

        # Classification layer.
        self.classification = nn.Linear(hidden_channels, args.nclasses)

    @staticmethod
    def _identity_init(linear, n_identity):
        """Overwrite ``linear`` so it initially copies its first ``n_identity`` inputs.

        The weight is replaced by a zero matrix of the same
        ``(out_features, in_features)`` shape with ones on the first
        ``n_identity`` diagonal entries. The bias is replaced by zeros.
        Both are re-registered as new ``torch.nn.Parameter`` objects.
        """
        out_features, in_features = linear.weight.shape
        weight = torch.zeros(out_features, in_features)
        diag = torch.arange(n_identity)
        weight[diag, diag] = 1
        linear.weight = torch.nn.Parameter(weight)
        linear.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, molecule, target):
        """Return class scores for a batch of (molecule, target) inputs.

        Args:
            molecule: Input accepted by ``self.molecule_gcn``.
            target: Input accepted by ``self.target_gcn``.

        Returns:
            The raw output of ``self.classification`` (an ``nn.Linear``)
            applied to the fused embedding.
        """
        molecule_features = self.molecule_gcn(molecule)
        target_features = self.target_gcn(target)
        if self.args.multi_concat:
            all_features = (
                target_features * self.multiplier_prot
                + molecule_features * self.multiplier_ligand
            )
        else:
            # Concatenate the two embeddings along dim 1 (assumed to be the
            # feature axis), then fuse with the MLP or linear layer.
            all_features = self.concatenation_gcn(
                torch.cat((molecule_features, target_features), dim=1)
            )
        return self.classification(all_features)

    def print_params(self, epoch=None, final=False):
        """Log both encoders' parameters and, if used, the multiplier sums.

        Args:
            epoch: Forwarded to each encoder's ``print_params``.
            final: Unused; kept for interface compatibility.
        """
        logging.info("======= Molecule GCN ========")
        self.molecule_gcn.print_params(epoch)
        logging.info("======= Protein GCN ========")
        self.target_gcn.print_params(epoch)
        if self.args.multi_concat:
            # Reduce with tensor ops and log plain floats. The builtin sum()
            # would loop over elements in Python and log an autograd tensor repr.
            sum_prot_multi = self.multiplier_prot.detach().sum().item()
            sum_lig_multi = self.multiplier_ligand.detach().sum().item()
            logging.info("Sumed prot multi: {}".format(sum_prot_multi))
            logging.info("Sumed lig multi: {}".format(sum_lig_multi))