|
|
|
""" |
|
Created on Mon May 1 19:41:07 2023 |
|
|
|
@author: Sen |
|
""" |
|
|
|
import os |
|
import subprocess |
|
import warnings |
|
from tqdm import tqdm |
|
import argparse |
|
import torch |
|
from transformers import AutoTokenizer, GPT2LMHeadModel |
|
|
|
# Suppress every warning category for the whole run.
warnings.filterwarnings('ignore')

# Point both proxy environment variables at the same local endpoint.
os.environ.update(
    http_proxy="http://127.0.0.1:7890",
    https_proxy="http://127.0.0.1:7890",
)
|
|
|
|
|
|
|
# Command-line interface: every option is a single-letter flag.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-p',
    type=str,
    default=None,
    help='Input the protein amino acid sequence. Default value is None. Only one of -p and -f should be specified.',
)
parser.add_argument(
    '-f',
    type=str,
    default=None,
    help='Input the FASTA file. Default value is None. Only one of -p and -f should be specified.',
)
parser.add_argument(
    '-l',
    type=str,
    default='',
    help='Input the ligand prompt. Default value is an empty string.',
)
parser.add_argument(
    '-n',
    type=int,
    default=100,
    help='Number of output molecules to generate. Default value is 100.',
)
parser.add_argument(
    '-d',
    type=str,
    default='cuda',
    help="Hardware device to use. Default value is 'cuda'.",
)
parser.add_argument(
    '-o',
    type=str,
    default='./ligand_output/',
    help="Output directory for generated molecules. Default value is './ligand_output/'.",
)

args = parser.parse_args()

# Unpack the parsed options into the module-level names used further down.
protein_seq, fasta_file = args.p, args.f
ligand_prompt, num_generated = args.l, args.n
device, output_path = args.d, args.o
|
|
|
|
|
def ifno_mkdirs(dirname):
    """Create directory ``dirname`` (with parents) if nothing exists at that path.

    Args:
        dirname: Path of the directory to create.
    """
    if not os.path.exists(dirname):
        # exist_ok=True covers the window between the check above and the
        # creation, during which another process may create the directory.
        os.makedirs(dirname, exist_ok=True)
|
|
|
# Make sure the output directory is present before any file is written to it.
ifno_mkdirs(output_path)
|
|
|
|
|
def read_fasta_file(file_path):
    """Return the sequence text stored in a FASTA file as one string.

    Each line is stripped of surrounding whitespace; lines starting with
    ``>`` (record headers) are dropped and all remaining lines are
    concatenated in file order.

    Args:
        file_path: Path of the FASTA file to read.

    Returns:
        The concatenation of every non-header line.
    """
    with open(file_path, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        residues = [text for text in stripped_lines if not text.startswith('>')]
    return ''.join(residues)
|
|
|
|
|
# Exactly one of -p (raw sequence) and -f (FASTA file) must be supplied.
if (protein_seq is not None) == (fasta_file is not None):
    # Stop here: carrying on would either build the prompt from a missing
    # sequence (neither given) or silently ignore the FASTA file (both given).
    raise SystemExit("The input should be either a protein amino acid sequence or a FASTA file, but not both.")

if fasta_file is not None:
    protein_seq = read_fasta_file(fasta_file)
|
|
|
|
|
# Tokenizer and language-model weights are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained('liyuesen/druggpt')
model = GPT2LMHeadModel.from_pretrained("liyuesen/druggpt")

# Prompt layout: start token, <P> followed by the protein sequence, then <L>
# followed by the (possibly empty) ligand prefix.
p_prompt = "".join(("<|startoftext|><P>", protein_seq, "<L>"))
l_prompt = ligand_prompt
prompt = p_prompt + l_prompt
print(prompt)

# Switch the model to evaluation mode and move it to the requested device.
model.eval()
device = torch.device(device)
model.to(device)
|
|
|
|
|
|
|
|
|
|
|
def get_sdf(ligand_list,output_path):
    """Convert each ligand string to a 3D SDF file by running Open Babel.

    One ``obabel`` process is started per ligand, with the ligand passed
    inline via ``-:`` and the options ``--gen3d --forcefield mmff94``. The
    result is written to ``<output_path>/ligand_<name>.sdf``, where
    ``<name>`` is the ligand string with filesystem-unsafe characters
    percent-encoded (reversible with ``urllib.parse.unquote``). Ligands
    whose conversion exceeds 10 seconds or exits with a non-zero status are
    skipped. A missing ``obabel`` executable is not handled here and
    propagates as ``FileNotFoundError``.

    Args:
        ligand_list: Iterable of ligand strings handed to ``obabel -:``.
        output_path: Directory the ``.sdf`` files are written into.
    """
    from urllib.parse import quote

    for ligand in tqdm(ligand_list):
        # Nothing to convert for an empty / whitespace-only string.
        if not ligand.strip():
            continue

        # '/', '\\', '*', '%', ':' etc. occur in ligand strings but are not
        # usable in file names; the characters kept as-is below are valid on
        # both POSIX and Windows, so simple ligands keep their old file name.
        safe_name = quote(ligand, safe="()[]=#@+")
        filename = os.path.join(output_path, 'ligand_' + safe_name + '.sdf')

        # Argument list (not a single string): without shell=True a string is
        # treated as the program name on POSIX, and a list also keeps special
        # characters in the ligand away from any shell.
        cmd = [
            "obabel",
            "-:" + ligand,
            "-osdf",
            "-O", filename,
            "--gen3d",
            "--forcefield", "mmff94",
        ]

        try:
            subprocess.check_output(cmd, timeout=10)
        except (subprocess.TimeoutExpired, subprocess.CalledProcessError):
            # Best effort: a ligand that cannot be converted is skipped.
            continue
|
|
|
def filter_sdf(output_path):
    """Delete files in ``output_path`` whose text is shorter than two characters.

    Only regular files directly inside ``output_path`` are examined;
    sub-directories and other non-file entries are left untouched.

    Args:
        output_path: Directory to scan.
    """
    for filename in os.listdir(output_path):
        filepath = os.path.join(output_path, filename)

        # A directory cannot be opened as a text file; skip anything that is
        # not a regular file instead of raising.
        if not os.path.isfile(filepath):
            continue

        with open(filepath, 'r') as f:
            # Two characters are enough to decide; no need to load the file.
            head = f.read(2)

        # Remove only after the handle is closed.
        if len(head) < 2:
            os.remove(filepath)
|
|
|
|
|
|
|
|
|
|
|
# Encode the prompt once; every sampling round starts from the same prefix.
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
generated = generated.to(device)

# Sample 64 sequences per round for at most 100 rounds, stopping as soon as
# the output directory holds the requested number of files.
for _ in range(100):
    sample_outputs = model.generate(
        generated,
        do_sample=True,
        top_k=5,
        max_length=1024,
        top_p=0.6,
        num_return_sequences=64,
    )

    ligand_list = []
    for sample_output in sample_outputs:
        decoded = tokenizer.decode(sample_output, skip_special_tokens=True)
        pieces = decoded.split('<L>')
        # The ligand is the text after the first <L> marker; skip a sample
        # that has no marker instead of raising IndexError.
        if len(pieces) > 1:
            ligand_list.append(pieces[1])
    torch.cuda.empty_cache()

    # Identical ligands map to the same output file, so convert each only
    # once per round (order of first appearance is kept).
    ligand_list = list(dict.fromkeys(ligand_list))

    get_sdf(ligand_list, output_path)
    filter_sdf(output_path)

    # ">=" so the loop ends once the requested count is reached, rather than
    # running another full round when exactly num_generated files exist.
    if len(os.listdir(output_path)) >= num_generated:
        break
|
|
|
|
|
|
|
|
|
|
|
|
|
|