jinysun committed on
Commit
c5f2040
·
1 Parent(s): 5fdd261

Upload 5 files

Browse files
Files changed (5) hide show
  1. .gitignore +160 -0
  2. app.py +43 -0
  3. requirements.txt +18 -0
  4. run.py +102 -0
  5. train.py +454 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import rdkit
4
+ import streamlit_ketcher
5
+ from streamlit_ketcher import st_ketcher
6
+ import run
7
+
8
# Page setup
st.set_page_config(page_title="DeepDAP", page_icon="🔋", layout="wide")
st.title("🔋DeepDAP")

# Load the donor/acceptor table from the Google Sheet's CSV export endpoint.
url1 = r"https://docs.google.com/spreadsheets/d/1AKkZS04VF3osFT36aNHIb4iUbV8D1uNfsldcpHXogj0/gviz/tq?tqx=out:csv&sheet=dap"
df1 = pd.read_csv(url1, dtype=str, encoding='utf-8')

# --- Search -----------------------------------------------------------------
text_search = st.text_input("🔍Search papers or molecules", value="")
if text_search:
    # Match the query literally (regex=False): names and SMILES routinely
    # contain characters such as "(" or "+" that would otherwise be parsed as
    # a regular expression and could raise. na=False keeps empty cells from
    # putting NaN into the boolean mask.
    mask = pd.Series(False, index=df1.index)
    for column in ("Donor_Name", "reference", "Acceptor_Name"):
        mask |= df1[column].str.contains(text_search, regex=False, na=False)
    df_search = df1[mask]

    st.write(df_search)
    st.download_button("⬇️Download edited files as .csv", df_search.to_csv(), "df_search.csv", use_container_width=True)

# --- Editable copy of the full table ------------------------------------------
edited_df = st.data_editor(df1, num_rows="dynamic")

st.download_button(
    "⬇️ Download edited files as .csv", edited_df.to_csv(), "edited_df.csv", use_container_width=True
)

# --- Molecule editor ----------------------------------------------------------
molecule = st.text_input("👨‍🔬Molecule")
smile_code = st_ketcher(molecule)
# f-string: without the prefix the placeholder text itself was rendered.
st.markdown(f"🏆New SMILES of edited molecules: {smile_code}")

# --- PCE prediction -----------------------------------------------------------
acceptor = st.text_input("🎈SMILES of acceptor")

donor = st.text_input("🎈SMILES of donor")

pce = None
if acceptor and donor:
    # Only query the model once both SMILES strings have been entered.
    try:
        result = run.smiles_aas_test(str(acceptor), str(donor))
    except Exception:
        # Top-level UI boundary: a failed prediction must not take the page down.
        result = None
    # run.smiles_aas_test reports failures as {'Error_message': ...} instead of
    # raising, so a dict result is not a prediction.
    if result is not None and not isinstance(result, dict):
        pce = result

if pce is None:
    st.markdown("⚡PCE: None ")
else:
    st.markdown(f"⚡PCE: ``{pce}``")
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair
2
+ streamlit
3
+ streamlit-ketcher
4
+ torch
5
+ tqdm
6
+ transformers
7
+ pytorch_lightning
8
+ scipy
9
+ pandas
10
+ rdkit
11
+ scikit-learn
12
+ matplotlib
13
+ easydict
14
+ wandb
15
+ networkx
16
+ seaborn
17
+
18
+
run.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from transformers import AutoTokenizer
7
+
8
+ from util.utils import *
9
+
10
+ from tqdm import tqdm
11
+ from train import markerModel
12
# Expose GPUs 0 and 1 to this process, numbered in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'

device_count = torch.cuda.device_count()
# Device the model runs on.
device_biomarker = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Device the tokenized inputs are placed on by biomarker_prediction().
device = torch.device('cpu')
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'
p_model_name = 'DeepChem/ChemBERTa-10M-MLM'

tokenizer = AutoTokenizer.from_pretrained(d_model_name)
prot_tokenizer = AutoTokenizer.from_pretrained(p_model_name)

# -- Prediction model: hyper-parameters come from config/predict.json, weights
# -- from the checkpoint path named in that config.
config = load_hparams('config/predict.json')
config = DictX(config)
model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)

model.eval()
model.freeze()

if device_biomarker.type == 'cuda':
    # DataParallel expects the wrapped module's parameters to already live on
    # its first CUDA device, so move the model there before wrapping it.
    model = torch.nn.DataParallel(model.to(device_biomarker))
38
+
39
def get_biomarker(drug_inputs, prot_inputs):
    """Run the module-level ``model`` on one pair of tokenized inputs.

    Both arguments are passed straight through to ``model``. The returned
    tensor has every size-1 dimension squeezed away and is converted with
    ``tolist()``.
    """
    raw_output = model(drug_inputs, prot_inputs)
    return raw_output.squeeze().tolist()
49
+
50
+
51
def biomarker_prediction(smile_acc, smile_don):
    """Tokenize an acceptor/donor SMILES pair and return the model prediction.

    ``smile_acc`` goes through ``tokenizer`` and ``smile_don`` through
    ``prot_tokenizer``; both are padded/truncated to 400 tokens and moved to
    the module-level ``device``. Any exception is printed and reported as
    ``{'Error_message': <exception>}`` instead of being raised.
    """

    def _encode(tok, smiles):
        # Fixed-length encoding, keeping only the two tensors the model reads.
        encoded = tok(smiles, padding='max_length', max_length=400, truncation=True, return_tensors="pt")
        return {
            'input_ids': encoded['input_ids'].to(device),
            'attention_mask': encoded['attention_mask'].to(device),
        }

    try:
        drug_inputs = _encode(tokenizer, smile_acc)
        prot_inputs = _encode(prot_tokenizer, smile_don)
        return get_biomarker(drug_inputs, prot_inputs)
    except Exception as e:
        print(e)
        return {'Error_message': e}
76
+
77
+
78
def smiles_aas_test(smile_acc, smile_don):
    """Predict for one acceptor/donor SMILES pair.

    Thin wrapper around ``biomarker_prediction``: returns its result
    unchanged, or ``{'Error_message': <exception>}`` (after printing the
    exception) if the call raises.
    """
    try:
        return biomarker_prediction(smile_acc, smile_don)
    except Exception as e:
        print(e)
        return {'Error_message': e}
96
+
97
+
98
if __name__ == "__main__":
    # The SMILES pair comes from the command line; the names used previously
    # were never defined, so running the script raised NameError.
    import sys

    if len(sys.argv) != 3:
        sys.exit(f"usage: python {sys.argv[0]} <acceptor_smiles> <donor_smiles>")

    a = smiles_aas_test(sys.argv[1], sys.argv[2])
    print(a)
100
+
101
+
102
+
train.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3
+ from curses import delay_output
4
+ import gc, os
5
+ import numpy as np
6
+ import pandas as pd
7
+ import wandb
8
+ from scipy.stats import pearsonr
9
+ from util.utils import *
10
+ from util.attention_flow import *
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ import sklearn as sk
16
+ from torch.utils.data import Dataset, DataLoader
17
+
18
+ import pytorch_lightning as pl
19
+ from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
20
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
21
+ from transformers import AutoConfig, AutoTokenizer, RobertaModel, BertModel
22
+ from sklearn.metrics import r2_score, mean_absolute_error,mean_squared_error
23
+
24
class markerDataset(Dataset):
    """Acceptor/donor SMILES pairs with one label each, tokenized on access.

    Args:
        list_IDs: Sequence mapping a dataset position to a row position in
            ``df_dti`` and to an entry of ``labels``.
        labels: Indexable collection of target values.
        df_dti: DataFrame with ``acceptor`` and ``donor`` columns.
        d_tokenizer: Tokenizer applied to acceptor strings.
        p_tokenizer: Tokenizer applied to donor strings.
    """

    def __init__(self, list_IDs, labels, df_dti, d_tokenizer, p_tokenizer):
        self.labels = labels
        self.list_IDs = list_IDs
        self.df = df_dti

        self.d_tokenizer = d_tokenizer
        self.p_tokenizer = p_tokenizer

    def convert_data(self, acc_data, don_data):
        """Tokenize one pair without padding and return two model-input dicts."""
        d_inputs = self.d_tokenizer(acc_data, return_tensors="pt")
        # The donor goes through the donor tokenizer, as in __getitem__ and
        # tokenize_data (it was previously run through the acceptor tokenizer).
        p_inputs = self.p_tokenizer(don_data, return_tensors="pt")

        acc_inputs = {
            'input_ids': d_inputs['input_ids'],
            'attention_mask': d_inputs['attention_mask'],
        }
        don_inputs = {
            'input_ids': p_inputs['input_ids'],
            'attention_mask': p_inputs['attention_mask'],
        }
        return acc_inputs, don_inputs

    def tokenize_data(self, acc_data, don_data):
        """Return the token strings of both inputs, wrapped in [CLS] ... [SEP]."""
        tokenize_acc = ['[CLS]'] + self.d_tokenizer.tokenize(acc_data) + ['[SEP]']
        tokenize_don = ['[CLS]'] + self.p_tokenizer.tokenize(don_data) + ['[SEP]']
        return tokenize_acc, tokenize_don

    def __len__(self):
        """Number of samples."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``[acc_ids, acc_mask, don_ids, don_mask, label]`` for one sample."""
        index = self.list_IDs[index]
        row = self.df.iloc[index]

        d_inputs = self.d_tokenizer(row['acceptor'], padding='max_length', max_length=400,
                                    truncation=True, return_tensors="pt")
        p_inputs = self.p_tokenizer(row['donor'], padding='max_length', max_length=400,
                                    truncation=True, return_tensors="pt")

        # squeeze(0) removes only the leading batch axis added by
        # return_tensors="pt"; a bare squeeze() would also collapse a
        # length-1 sequence axis.
        d_input_ids = d_inputs['input_ids'].squeeze(0)
        d_attention_mask = d_inputs['attention_mask'].squeeze(0)
        p_input_ids = p_inputs['input_ids'].squeeze(0)
        p_attention_mask = p_inputs['attention_mask'].squeeze(0)

        labels = torch.as_tensor(self.labels[index], dtype=torch.float)

        return [d_input_ids, d_attention_mask, p_input_ids, p_attention_mask, labels]
82
+
83
+
84
class markerDataModule(pl.LightningDataModule):
    """Loads the train/val/test CSV splits and serves them as DataLoaders.

    Args:
        task_name: Task identifier; stored and accepted by ``get_task``.
        acc_model_name: Pretrained name/path for the acceptor tokenizer.
        don_model_name: Pretrained name/path for the donor tokenizer.
        num_workers: Worker count passed to every DataLoader.
        batch_size: Batch size passed to every DataLoader.
        traindata_rate: Leading fraction of the train and validation frames
            to keep (1.0 keeps everything).
    """

    def __init__(self, task_name, acc_model_name, don_model_name, num_workers, batch_size, traindata_rate = 1.0):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.task_name = task_name

        self.traindata_rate = traindata_rate

        self.d_tokenizer = AutoTokenizer.from_pretrained(acc_model_name)
        self.p_tokenizer = AutoTokenizer.from_pretrained(don_model_name)

        self.df_train = None
        self.df_val = None
        self.df_test = None

        # Cleared by get_task('merge'); controls whether test.csv is loaded.
        self.load_testData = True

        self.train_dataset = None
        self.valid_dataset = None
        self.test_dataset = None

    def get_task(self, task_name):
        """Map a task name (case-insensitive) to its dataset folder.

        Returns ``None`` for unknown names. Selecting ``'merge'`` also turns
        off test-set loading.
        """
        task = task_name.lower()
        # Compare against lower-case literals: the previous comparison with
        # 'OSC' could never match a lower-cased name.
        if task == 'osc':
            return './dataset/OSC/'

        elif task == 'merge':
            self.load_testData = False
            return './dataset/MergeDataset'

    def prepare_data(self):
        """Read the CSV splits into DataFrames and apply ``traindata_rate``."""
        # NOTE(review): the folder is fixed here and get_task() is not
        # consulted — confirm whether task_name is meant to select it.
        dataFolder = './dataset/OSC'

        self.df_train = pd.read_csv(dataFolder + '/train.csv')
        self.df_val = pd.read_csv(dataFolder + '/val.csv')

        ## -- Apply the data-length rate -- ##
        traindata_length = int(len(self.df_train) * self.traindata_rate)
        validdata_length = int(len(self.df_val) * self.traindata_rate)

        self.df_train = self.df_train[:traindata_length]
        self.df_val = self.df_val[:validdata_length]

        if self.load_testData is True:
            self.df_test = pd.read_csv(dataFolder + '/test.csv')

    def setup(self, stage=None):
        """Wrap the loaded DataFrames in ``markerDataset`` objects."""
        if stage == 'fit' or stage is None:
            self.train_dataset = markerDataset(self.df_train.index.values, self.df_train.Label.values, self.df_train,
                                               self.d_tokenizer, self.p_tokenizer)
            self.valid_dataset = markerDataset(self.df_val.index.values, self.df_val.Label.values, self.df_val,
                                               self.d_tokenizer, self.p_tokenizer)

        # NOTE(review): built independently of `stage`; the original nesting
        # was not recoverable from the source — confirm.
        if self.load_testData is True:
            self.test_dataset = markerDataset(self.df_test.index.values, self.df_test.Label.values, self.df_test,
                                              self.d_tokenizer, self.p_tokenizer)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Loader over the validation dataset."""
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Loader over the test dataset."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
151
+
152
+
153
class markerModel(pl.LightningModule):
    """Two RoBERTa encoders (acceptor, donor) feeding an MLP regression head.

    The first-token hidden state of each encoder is concatenated and passed
    through ``self.decoder``.

    Args:
        acc_model_name: Pretrained name/path for the acceptor encoder.
        don_model_name: Pretrained name/path for the donor encoder.
        lr: Learning rate for AdamW.
        dropout: Dropout probability inside the decoder; 0 adds no dropout.
        layer_features: Output widths of the decoder's linear layers; the
            last entry is the width of the final layer.
        loss_fn: ``'MSE'`` selects ``MSELoss``; anything else ``SmoothL1Loss``.
        layer_limit: Accepted for interface compatibility; not used here.
        d_pretrained: If False, the acceptor encoder is randomly initialised.
        p_pretrained: If False, the donor encoder is randomly initialised.
    """

    def __init__(self, acc_model_name, don_model_name, lr, dropout, layer_features, loss_fn = "smooth", layer_limit = True, d_pretrained=True, p_pretrained=True):
        super().__init__()
        self.lr = lr
        self.loss_fn = loss_fn
        self.criterion = torch.nn.MSELoss()
        self.criterion_smooth = torch.nn.SmoothL1Loss()

        #-- Encoders. The base config is only fetched when an encoder is
        #-- built without pretrained weights.
        if d_pretrained is False:
            acc_config = AutoConfig.from_pretrained("seyonec/SMILES_BPE_PubChem_100k_shard00")
            self.d_model = RobertaModel(acc_config)
            print('acceptor model without pretraining')
        else:
            self.d_model = RobertaModel.from_pretrained(acc_model_name, num_labels=2,
                                                        output_hidden_states=True,
                                                        output_attentions=True)

        if p_pretrained is False:
            don_config = AutoConfig.from_pretrained("seyonec/SMILES_BPE_PubChem_100k_shard00")
            self.p_model = RobertaModel(don_config)
            print('donor model without pretraining')
        else:
            self.p_model = RobertaModel.from_pretrained(don_model_name,
                                                        output_hidden_states=True,
                                                        output_attentions=True)

        #-- Decoder: Linear -> ReLU [-> Dropout] per hidden width, then a
        #-- final Linear.
        # NOTE(review): Dropout is placed inside the loop (after every ReLU).
        # The source's indentation was lost; if the original added a single
        # Dropout after the loop, the Sequential indices differ and existing
        # checkpoints will not line up — confirm before loading old weights.
        layers = []
        in_features = self.d_model.config.hidden_size + self.p_model.config.hidden_size
        for out_features in layer_features[:-1]:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_features = out_features

        layers.append(nn.Linear(in_features, layer_features[-1]))

        self.decoder = nn.Sequential(*layers)

        self.save_hyperparameters()

    def _encode(self, acc_inputs, don_inputs):
        """Run both encoders and the decoder; return (acc_out, don_out, prediction)."""
        d_outputs = self.d_model(acc_inputs['input_ids'], acc_inputs['attention_mask'])
        p_outputs = self.p_model(don_inputs['input_ids'], don_inputs['attention_mask'])

        # First-token hidden state of each encoder, concatenated feature-wise.
        joint = torch.cat((d_outputs.last_hidden_state[:, 0], p_outputs.last_hidden_state[:, 0]), dim=1)
        return d_outputs, p_outputs, self.decoder(joint)

    def forward(self, acc_inputs, don_inputs):
        """Return the decoder output for a batch of tokenized pairs."""
        _, _, outs = self._encode(acc_inputs, don_inputs)
        return outs

    def attention_output(self, acc_inputs, don_inputs):
        """Return (acceptor attentions, donor attentions, decoder output)."""
        d_outputs, p_outputs, outs = self._encode(acc_inputs, don_inputs)
        return d_outputs['attentions'], p_outputs['attentions'], outs

    def _step(self, batch):
        """Shared step logic: unpack the batch, predict, compute the loss."""
        acc_inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        don_inputs = {'input_ids': batch[2], 'attention_mask': batch[3]}
        labels = batch[4]

        logits = self(acc_inputs, don_inputs).squeeze(dim=1)

        if self.loss_fn == 'MSE':
            loss = self.criterion(logits, labels)
        else:
            loss = self.criterion_smooth(logits, labels)
        return logits, labels, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._step(batch)
        self.log("train_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        logits, labels, loss = self._step(batch)
        self.log("valid_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"logits": logits, "labels": labels}

    def validation_step_end(self, outputs):
        return {"logits": outputs['logits'], "labels": outputs['labels']}

    def validation_epoch_end(self, outputs):
        preds = self.convert_outputs_to_preds(outputs)
        # Labels stay floating point: casting them to int truncated the
        # regression targets before the metrics were computed.
        labels = torch.cat([output['labels'] for output in outputs], dim=0)

        mae, mse, r2, r = self.log_score(preds, labels)

        self.log("mae", mae, on_step=False, on_epoch=True, logger=True)
        self.log("mse", mse, on_step=False, on_epoch=True, logger=True)
        self.log("r2", r2, on_step=False, on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        logits, labels, loss = self._step(batch)
        self.log("test_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"logits": logits, "labels": labels}

    def test_step_end(self, outputs):
        return {"logits": outputs['logits'], "labels": outputs['labels']}

    def test_epoch_end(self, outputs):
        preds = self.convert_outputs_to_preds(outputs)
        # See validation_epoch_end: no integer cast on regression labels.
        labels = torch.cat([output['labels'] for output in outputs], dim=0)

        mae, mse, r2, r = self.log_score(preds, labels)

        self.log("mae", mae, on_step=False, on_epoch=True, logger=True)
        self.log("mse", mse, on_step=False, on_epoch=True, logger=True)
        self.log("r2", r2, on_step=False, on_epoch=True, logger=True)
        self.log("r", r, on_step=False, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """AdamW with weight decay disabled for parameters named in ``no_decay``."""
        param_optimizer = list(self.named_parameters())

        no_decay = ["bias", "gamma", "beta"]
        # The per-group option torch.optim.AdamW reads is "weight_decay"; a
        # "weight_decay_rate" entry is ignored and both groups would fall
        # back to the optimizer's default decay.
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.0001
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.lr,
        )
        return optimizer

    def convert_outputs_to_preds(self, outputs):
        """Concatenate the per-step ``logits`` into one tensor."""
        return torch.cat([output['logits'] for output in outputs], dim=0)

    def log_score(self, preds, labels):
        """Print and return (MAE, MSE, R², Pearson r) for predictions vs labels."""
        y_pred = preds.detach().cpu().numpy()
        y_label = labels.detach().cpu().numpy()

        mae = mean_absolute_error(y_label, y_pred)
        mse = mean_squared_error(y_label, y_pred)
        r2 = r2_score(y_label, y_pred)
        # pearsonr returns (statistic, p-value); only the coefficient is a
        # scalar that can be logged.
        r = pearsonr(y_label, y_pred)[0]
        print(f'\nmae : {mae}')
        print(f'mse : {mse}')
        print(f'r2 : {r2}')
        print(f'r : {r}')

        return mae, mse, r2, r
343
+
344
+
345
def main_wandb(config=None):
    """Train or test a ``markerModel`` with settings taken from ``wandb.config``.

    Args:
        config: Hyper-parameter mapping passed to ``wandb.init``. When None,
            ``wandb.init`` is called without one (e.g. under a sweep agent).

    Reads the module-level ``project_name`` when ``config`` is given.
    Exceptions are reported with a full traceback and not re-raised.
    """
    import traceback

    try:
        if config is not None:
            wandb.init(config=config, project=project_name)
        else:
            wandb.init(settings=wandb.Settings(console='off'))

        config = wandb.config
        pl.seed_everything(seed=config.num_seed)

        # markerDataModule takes six arguments; the extra prot_maxlength that
        # used to be passed here made this call raise TypeError.
        dm = markerDataModule(config.task_name, config.d_model_name, config.p_model_name,
                              config.num_workers, config.batch_size, config.traindata_rate)
        dm.prepare_data()
        dm.setup()

        model_type = str(config.pretrained['chem'])+"To"+str(config.pretrained['prot'])
        # MAE is an error metric, so the best checkpoint is the one with the
        # smallest value.
        checkpoint_callback = ModelCheckpoint(f"{config.task_name}_{model_type}_{config.lr}_{config.num_seed}", save_top_k=1, monitor="mae", mode="min")

        # NOTE(review): precision=16 together with accelerator='cpu' —
        # confirm the installed Lightning version accepts this combination.
        trainer = pl.Trainer(
            max_epochs=config.max_epoch,
            precision=16,
            callbacks=[checkpoint_callback],
            accelerator='cpu',log_every_n_steps=40
        )

        if config.model_mode == "train":
            model = markerModel(config.d_model_name, config.p_model_name,
                                config.lr, config.dropout, config.layer_features, config.loss_fn, config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])
            model.train()
            trainer.fit(model, datamodule=dm)

            model.eval()
            trainer.test(model, datamodule=dm)

        else:
            model = markerModel.load_from_checkpoint(config.load_checkpoint)

            model.eval()
            trainer.test(model, datamodule=dm)

    except Exception:
        # Keep the run from propagating the error, but show where it failed.
        traceback.print_exc()
390
+
391
+
392
def main_default(config):
    """Train or test a ``markerModel`` from a plain hyper-parameter mapping.

    Args:
        config: Mapping of hyper-parameters; wrapped in ``DictX`` for
            attribute access.

    Exceptions are reported with a full traceback and not re-raised.
    """
    import traceback

    try:
        config = DictX(config)
        pl.seed_everything(seed=config.num_seed)

        dm = markerDataModule(config.task_name, config.d_model_name, config.p_model_name,
                              config.num_workers, config.batch_size, config.traindata_rate)

        dm.prepare_data()
        dm.setup()
        model_type = str(config.pretrained['chem'])+"To"+str(config.pretrained['prot'])
        # MSE is an error metric, so the best checkpoint is the one with the
        # smallest value.
        checkpoint_callback = ModelCheckpoint(f"{config.task_name}_{model_type}_{config.lr}_{config.num_seed}", save_top_k=1, monitor="mse", mode="min")

        trainer = pl.Trainer(
            max_epochs=config.max_epoch,
            precision= 32,
            callbacks=[checkpoint_callback],
            accelerator='cpu',log_every_n_steps=40
        )

        if config.model_mode == "train":
            model = markerModel(config.d_model_name, config.p_model_name,
                                config.lr, config.dropout, config.layer_features, config.loss_fn, config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])

            model.train()

            trainer.fit(model, datamodule=dm)

            model.eval()
            trainer.test(model, datamodule=dm)

        else:
            model = markerModel.load_from_checkpoint(config.load_checkpoint)

            model.eval()
            trainer.test(model, datamodule=dm)
    except Exception:
        # Keep the run from propagating the error, but show where it failed.
        traceback.print_exc()
433
+
434
+
435
if __name__ == '__main__':
    # Flip to True to run through Weights & Biases instead of the plain path.
    using_wandb = False

    ##-- Load the hyper-parameter config file --##
    config = load_hparams('config/config_hparam.json')

    if using_wandb:
        # main_wandb reads this module-level name when calling wandb.init.
        project_name = config["name"]
        main_wandb(config)
    else:
        main_default(config)