liyuesen commited on
Commit
317fa29
1 Parent(s): 4e4b4b1

Upload drug_generator.py

Browse files
Files changed (1) hide show
  1. drug_generator.py +143 -0
drug_generator.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon May 1 19:41:07 2023
4
+
5
+ @author: Sen
6
+ """
7
+
8
+ import os
9
+ import subprocess
10
+ import warnings
11
+ from tqdm import tqdm
12
+ import argparse
13
+ import torch
14
+ from transformers import AutoTokenizer, GPT2LMHeadModel
15
+
16
+ warnings.filterwarnings('ignore')
17
+ os.environ["http_proxy"] = "http://127.0.0.1:7890"
18
+ os.environ["https_proxy"] = "http://127.0.0.1:7890"
19
+
20
+
21
+ # Set up command line argument parsing
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('-p', type=str, default=None, help='Input the protein amino acid sequence. Default value is None. Only one of -p and -f should be specified.')
24
+ parser.add_argument('-f', type=str, default=None, help='Input the FASTA file. Default value is None. Only one of -p and -f should be specified.')
25
+ parser.add_argument('-l', type=str, default='', help='Input the ligand prompt. Default value is an empty string.')
26
+ parser.add_argument('-n', type=int, default=100, help='Number of output molecules to generate. Default value is 100.')
27
+ parser.add_argument('-d', type=str, default='cuda', help="Hardware device to use. Default value is 'cuda'.")
28
+ parser.add_argument('-o', type=str, default='./ligand_output/', help="Output directory for generated molecules. Default value is './ligand_output/'.")
29
+
30
+ args = parser.parse_args()
31
+
32
+ protein_seq = args.p
33
+ fasta_file = args.f
34
+ ligand_prompt = args.l
35
+ num_generated = args.n
36
+ device = args.d
37
+ output_path = args.o
38
+
39
+
40
+ def ifno_mkdirs(dirname):
41
+ if not os.path.exists(dirname):
42
+ os.makedirs(dirname)
43
+
44
+ ifno_mkdirs(output_path)
45
+
46
+ # Function to read in FASTA file
47
+ def read_fasta_file(file_path):
48
+ with open(file_path, 'r') as fasta_file:
49
+ sequence = []
50
+
51
+ for line in fasta_file:
52
+ line = line.strip()
53
+ if not line.startswith('>'):
54
+ sequence.append(line)
55
+
56
+ protein_sequence = ''.join(sequence)
57
+
58
+ return protein_sequence
59
+
60
+ # Check if the input is either a protein amino acid sequence or a FASTA file, but not both
61
+ if (protein_seq is not None) != (fasta_file is not None):
62
+ if fasta_file is not None:
63
+ protein_seq = read_fasta_file(fasta_file)
64
+ else:
65
+ protein_seq = protein_seq
66
+ else:
67
+ print("The input should be either a protein amino acid sequence or a FASTA file, but not both.")
68
+
69
+ # Load the tokenizer and the model
70
+ tokenizer = AutoTokenizer.from_pretrained('liyuesen/druggpt')
71
+ model = GPT2LMHeadModel.from_pretrained("liyuesen/druggpt")
72
+
73
+ # Generate a prompt for the model
74
+ p_prompt = "<|startoftext|><P>" + protein_seq + "<L>"
75
+ l_prompt = "" + ligand_prompt
76
+ prompt = p_prompt + l_prompt
77
+ print(prompt)
78
+
79
+ # Move the model to the specified device
80
+ model.eval()
81
+ device = torch.device(device)
82
+ model.to(device)
83
+
84
+
85
+
86
+ #Define post-processing function
87
+ #Define function to generate SDF files from a list of ligand SMILES using OpenBabel
88
+ def get_sdf(ligand_list,output_path):
89
+ for ligand in tqdm(ligand_list):
90
+ filename = output_path + 'ligand_' + ligand +'.sdf'
91
+ cmd = "obabel -:" + ligand + " -osdf -O " + filename + " --gen3d --forcefield mmff94"# --conformer --nconf 1 --score rmsd
92
+ #subprocess.check_call(cmd, shell=True)
93
+ try:
94
+ # 设置超时时间为 30 秒
95
+ output = subprocess.check_output(cmd, timeout=10)
96
+ except subprocess.TimeoutExpired:
97
+ pass
98
+ #Define function to filter out empty SDF files
99
+ def filter_sdf(output_path):
100
+ filelist = os.listdir(output_path)
101
+ for filename in filelist:
102
+ filepath = os.path.join(output_path,filename)
103
+ with open(filepath,'r') as f:
104
+ text = f.read()
105
+ if len(text)<2:
106
+ os.remove(filepath)
107
+
108
+
109
+
110
+
111
+ # Generate molecules
112
+ generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
113
+ generated = generated.to(device)
114
+
115
+
116
+ for i in range(100):
117
+ ligand_list = []
118
+ sample_outputs = model.generate(
119
+ generated,
120
+ #bos_token_id=random.randint(1,30000),
121
+ do_sample=True,
122
+ top_k=5,
123
+ max_length = 1024,
124
+ top_p=0.6,
125
+ num_return_sequences=64
126
+ )
127
+
128
+ for i, sample_output in enumerate(sample_outputs):
129
+ ligand_list.append(tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')[1])
130
+ torch.cuda.empty_cache()
131
+
132
+ get_sdf(ligand_list,output_path)
133
+ filter_sdf(output_path)
134
+
135
+ if len(os.listdir(output_path))>num_generated:
136
+ break
137
+ else:pass
138
+
139
+
140
+
141
+
142
+
143
+