Spaces:
Runtime error
Runtime error
# credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py | |
import os | |
import sys | |
from urllib import request | |
import esm | |
import gradio as gr | |
import progres as pg | |
import requests | |
import torch | |
from transformers import (AutoModel, AutoModelForMaskedLM, AutoTokenizer, | |
EsmModel) | |
import msa | |
import proteinbind_new | |
# Nucleotide Transformer: tokenizer + model used by nt_embed.
tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt.eval()

# ESM-2 (t12, 35M): tokenizer + model used by aa_embed.
tokenizer_aa = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa.eval()

# all-mpnet-base-v2: tokenizer + model used by se_embed.
tokenizer_se = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se.eval()

# MSA Transformer (fair-esm) plus the batch converter used by msa_embed.
msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
msa_transformer = msa_transformer.eval()
msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()

# ProteinBind projection model shared by every *_embed function through
# pass_through. NOTE(review): the `True` flag presumably requests pretrained
# weights — confirm against proteinbind_new.create_proteinbind.
model = proteinbind_new.create_proteinbind(True)
# Like every encoder above, run ProteinBind in inference mode so that layers
# such as dropout (if present) do not make the embeddings non-deterministic.
model.eval()
def pass_through(torch_output, key: str):
    """Project an encoder output through the ProteinBind model for one modality.

    Args:
        torch_output: Tensor produced by one of the upstream encoders; it is
            cast to float32 and moved to the CPU before the forward pass.
        key: Modality name used both as the input key and the output key of
            ``model`` (this file uses ``"dna"``, ``"aa"``, ``"text"``,
            ``"msa"`` and ``"pdb"``).

    Returns:
        ``output[key]`` detached and converted with ``.numpy()``.
    """
    device = torch.device("cpu")
    input_data = {key: torch_output.type(torch.float32).to(device)}
    # Inference only: skip autograd bookkeeping, as every encoder call does.
    with torch.no_grad():
        output = model(input_data)
    return output[key].detach().numpy()
def nt_embed(sequence: str):
    """Embed a nucleotide sequence with the Nucleotide Transformer + ProteinBind.

    The sequence is tokenised and run through ``model_nt`` with hidden states
    enabled. The first-position (CLS) vector of the final hidden layer of the
    single batch item is then projected by the ProteinBind ``"dna"`` head.

    Args:
        sequence: Nucleotide string such as ``"ATCG..."``.

    Returns:
        The array returned by ``pass_through`` for the ``"dna"`` modality.
    """
    encoded = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")
    input_ids = encoded["input_ids"]
    # Attend to every non-padding position.
    mask = input_ids != tokenizer_nt.pad_token_id

    with torch.no_grad():
        outputs = model_nt(input_ids, attention_mask=mask, output_hidden_states=True)

    final_layer = outputs.hidden_states[-1].detach()
    cls_vector = final_layer[0, 0, :]  # batch item 0, token position 0
    return pass_through(cls_vector, "dna")
def aa_embed(sequence: str):
    """Embed an amino-acid sequence with ESM-2 and project it with ProteinBind.

    Args:
        sequence: Protein sequence in one-letter amino-acid code.

    Returns:
        The array returned by ``pass_through`` for the ``"aa"`` modality.
    """
    encoded = tokenizer_aa([sequence], return_tensors="pt")
    with torch.no_grad():
        # First model output; for Hugging Face models this is typically the
        # per-token last hidden state. NOTE(review): unlike nt_embed and
        # msa_embed no pooling is applied here — confirm the "aa" head
        # expects unpooled input.
        hidden = model_aa(**encoded)[0]
    return pass_through(hidden, "aa")
def se_embed(sentence: str):
    """Embed free text with all-mpnet-base-v2 and project it with ProteinBind.

    Args:
        sentence: Arbitrary input text.

    Returns:
        The array returned by ``pass_through`` for the ``"text"`` modality.
    """
    batch = tokenizer_se([sentence], return_tensors='pt')
    with torch.no_grad():
        # First model output, forwarded without pooling. NOTE(review):
        # sentence-transformers checkpoints are usually mean-pooled — confirm
        # the "text" head expects unpooled input.
        first_output = model_se(**batch)[0]
    return pass_through(first_output, "text")
def msa_embed(sequences: list, num_seqs: int = 128):
    """Embed a multiple sequence alignment with the MSA Transformer + ProteinBind.

    Args:
        sequences: Alignment as returned by ``msa.read_msa``.
        num_seqs: Number of rows kept by ``msa.greedy_select`` before the
            forward pass; more rows cost more memory and time. Defaults to
            128, the previously hard-coded value.

    Returns:
        The array returned by ``pass_through`` for the ``"msa"`` modality.
    """
    selected = msa.greedy_select(sequences, num_seqs=num_seqs)
    _labels, _strs, batch_tokens = msa_transformer_batch_converter([selected])
    # Run on whatever device the MSA Transformer's weights live on.
    device = next(msa_transformer.parameters()).device
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        reprs = msa_transformer(batch_tokens, repr_layers=[12])["representations"]

    # Layer-12 representations are indexed as a rank-4 tensor; assumes the
    # fair-esm layout (batch, rows, positions, dim) — TODO confirm. Take
    # position 0 of every row and average over the first two axes to get a
    # single vector.
    first_position = reprs[12][:, :, 0, :]
    pooled = torch.mean(first_position, (0, 1))
    return pass_through(pooled, "msa")
def go_embed(terms):
    """Placeholder for Gene Ontology term embeddings.

    GO embedding is not implemented yet: *terms* is ignored and ``None`` is
    always returned.
    """
    return None
def download_data_if_required():
    """Download progres' trained model on first run.

    For each (path, URL) pair, the file is downloaded if it is missing and then
    loaded once to check that it contains the expected entry. Currently the
    only pair is ``pg.trained_model_fp`` and ``trained_model.pt`` from the
    progres Zenodo record. On any failure the partial file is deleted and the
    process exits with status 1.
    """
    url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
    fps = [pg.trained_model_fp]
    urls = [f"{url_base}/trained_model.pt"]
    # progres' pre-embedded search databases are not downloaded here; only the
    # trained model is fetched.
    os.makedirs(pg.trained_model_dir, exist_ok=True)

    printed = False
    for fp, url in zip(fps, urls):
        if os.path.isfile(fp):
            continue
        if not printed:
            print("Downloading data as first time setup (~340 MB) to ", pg.progres_dir,
                  ", internet connection required, this can take a few minutes",
                  sep="", file=sys.stderr)
            printed = True
        try:
            request.urlretrieve(url, fp)
            # NOTE(review): torch.load unpickles the downloaded file, which can
            # execute arbitrary code; only acceptable because the URL is the
            # progres Zenodo record.
            d = torch.load(fp, map_location="cpu")
            # Explicit check rather than `assert`, which is stripped under -O.
            required = "model" if fp == pg.trained_model_fp else "embeddings"
            if required not in d:
                raise ValueError(f"{fp} is missing the {required!r} entry")
        except Exception:
            # Startup boundary: drop any partial download and abort.
            if os.path.isfile(fp):
                os.remove(fp)
            print("Failed to download from", url, "and save to", fp, file=sys.stderr)
            print("Exiting", file=sys.stderr)
            sys.exit(1)
    if printed:
        print("Data downloaded successfully", file=sys.stderr)
def get_pdb(pdb_code="", filepath=""):
    """Return the text of a PDB structure.

    If *pdb_code* is non-blank, the entry is downloaded from RCSB. Otherwise
    *filepath* is read. It may be an upload object exposing ``.name``
    (presumably what the Gradio ``File`` component passes) or a plain path.

    Args:
        pdb_code: PDB ID such as ``"2CBA"``; surrounding whitespace is ignored.
        filepath: Uploaded file object or path, used when no code is given.

    Returns:
        The PDB file contents, or ``None`` when neither a code nor a file was
        provided.

    Raises:
        requests.HTTPError: If RCSB answers with an error status (e.g. an
            unknown PDB ID) instead of returning the error page as data.
        requests.Timeout: If RCSB does not respond in time.
    """
    from urllib.parse import quote

    code = (pdb_code or "").strip()
    if not code:
        path = getattr(filepath, "name", filepath)
        if not path or not isinstance(path, (str, os.PathLike)):
            return None
        with open(path) as f:
            return f.read()

    # Bound the request so a stalled connection cannot hang the UI, and keep
    # user input from altering the URL path beyond the ID itself.
    response = requests.get(
        f"https://files.rcsb.org/view/{quote(code, safe='')}.pdb", timeout=30
    )
    response.raise_for_status()
    return response.content.decode()
def molecule(pdb):
    """Return an ``<iframe>`` that renders *pdb* with 3Dmol.js.

    A standalone HTML page is built that loads jQuery and 3Dmol.js, adds the
    structure as a spectrum-coloured cartoon with a translucent white surface,
    and zooms to it. The page is placed in the iframe's ``srcdoc`` attribute.

    Args:
        pdb: PDB-format text of the structure to display.

    Returns:
        An HTML string containing a single ``<iframe>`` element.
    """
    import html
    import json

    # json.dumps produces a valid JS string literal (quotes, backslashes and
    # newlines escaped). Escaping "</" stops a "</script>" inside the data from
    # closing the inline <script> element early.
    pdb_js = json.dumps(pdb).replace("</", "<\\/")
    page = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 100%;
    height: 600px;
    position: relative;
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = """
        + pdb_js
        + """;
$(document).ready(function () {
    let element = $("#container");
    let config = { backgroundColor: "black" };
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");
    viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
    viewer.addSurface("MS", { opacity: .5, color: "white" });
    viewer.zoomTo();
    viewer.render();
    viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""
    )
    # srcdoc is an HTML attribute, so its value must be entity-escaped.
    # Otherwise any quote in the document ends the attribute early, such as
    # the primes in nucleic-acid atom names (O5') or COMPND lines like
    # 5'-D(...)-3'.
    srcdoc = html.escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def str2coords(s):
    """Extract C-alpha coordinates from PDB-format text.

    ATOM/HETATM records whose atom name (columns 13-16) is ``CA`` contribute
    their x, y, z (columns 31-54), in file order. Parsing stops at the first
    ``ENDMDL``, so only the first model of a multi-model file is used.
    HETATM records are kept so that modified residues (e.g. MSE) count.
    Calcium ions are skipped: they are HETATM records whose residue name
    (columns 18-20) is also ``CA``.

    Args:
        s: PDB file contents.

    Returns:
        A list of ``[x, y, z]`` float lists; empty if no C-alpha atom is found.
    """
    coords = []
    for line in s.splitlines():
        if line.startswith("ENDMDL"):
            break
        if not line.startswith(("ATOM ", "HETATM")):
            continue
        if line[12:16].strip() != "CA":
            continue
        # Calcium ions share the atom name "CA" with alpha carbons but are
        # single-atom HETATM groups of residue "CA", not amino-acid residues.
        if line.startswith("HETATM") and line[17:20].strip() == "CA":
            continue
        coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
    return coords
def update_st(inp, file):
    """Gradio callback for the structure tab.

    Gets the structure from a PDB code or an uploaded file and renders it with
    3Dmol.js. Its C-alpha coordinates are embedded with ``pg.embed_coords``
    and then the ProteinBind ``"pdb"`` head.

    Args:
        inp: PDB code from the textbox (may be empty).
        file: Value of the file-upload component (may be empty).

    Returns:
        ``(html, embedding_text)`` for the HTML viewer and the embedding
        textbox; the embedding is stringified like in the other tabs.

    Raises:
        gr.Error: If no structure was supplied or it has no C-alpha atoms.
    """
    pdb = get_pdb(inp, file)
    if pdb is None:
        raise gr.Error("Enter a PDB code or upload a PDB file.")
    coords = str2coords(pdb)
    if not coords:
        raise gr.Error("No C-alpha atoms found in the structure.")
    embedding = pass_through(pg.embed_coords(coords), "pdb")
    return molecule(pdb), str(embedding)
def update_nt(inp):
    """Gradio callback: nucleotide embedding of *inp* as text (None -> '')."""
    sequence = inp if inp else ''
    embedding = nt_embed(sequence)
    return str(embedding)
def update_aa(inp):
    """Gradio callback: amino-acid embedding of *inp* as text.

    ``None`` is treated as an empty sequence, matching update_nt.
    """
    return str(aa_embed(inp or ""))
def update_se(inp):
    """Gradio callback: sentence embedding of *inp* as text.

    ``None`` is treated as an empty sentence, matching update_nt.
    """
    return str(se_embed(inp or ""))
def update_go(inp):
    """Gradio callback: GO-term embedding of *inp* as text."""
    embedding = go_embed(inp)
    return str(embedding)
def update_msa(inp):
    """Gradio callback: embed an uploaded MSA file and return it as text.

    Args:
        inp: Value of the ``gr.File`` component. This is presumably an upload
            object exposing a ``.name`` path; a plain path string is also
            accepted.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if inp is None:
        raise gr.Error("Please upload an MSA file.")
    path = getattr(inp, "name", inp)
    return str(msa_embed(msa.read_msa(path)))
# Gradio UI: one tab per input modality. In each tab a button sends the inputs
# to the matching update_* callback, whose result is shown in a read-only
# textbox. The names inp/btn/emb are rebound in every tab. Each tab's
# btn.click(...) is registered before the next rebinding, so every tab is
# wired to its own widgets.
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below is a reconstruction — confirm the intended layout. gr.Box is a
# Gradio 3.x layout block that newer releases do not provide — confirm the
# pinned Gradio version.
demo = gr.Blocks()
with demo:
    with gr.Tabs():
        # Structure tab: PDB code or uploaded file -> 3Dmol.js view + embedding.
        with gr.TabItem("PDB Structural Embeddings"):
            with gr.Row():
                with gr.Box():
                    inp = gr.Textbox(
                        placeholder="PDB Code or upload file below", label="Input structure"
                    )
                    file = gr.File(file_count="single")
                    gr.Examples(["2CBA", "6VXX"], inp)
                    btn = gr.Button("View structure")
                    gr.Markdown("# PDB viewer using 3Dmol.js")
                    mol = gr.HTML()
                    emb = gr.Textbox(interactive=False)
            btn.click(fn=update_st, inputs=[inp, file], outputs=[mol, emb])
        # DNA tab: nucleotide string -> update_nt.
        with gr.TabItem("Nucleotide Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="ATCGCTGCCCGTAGATAATAAGAGACACTGAGGCC", label="Input Nucleotide Sequence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_nt, inputs=[inp], outputs=emb)
        # Protein tab: amino-acid string -> update_aa.
        with gr.TabItem("Amino Acid Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="AAGQCYRGRCSGGLCCSKYGYCGSGPAYCG", label="Input Amino Acid Sequence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_aa, inputs=[inp], outputs=emb)
        # Text tab: free text -> update_se.
        with gr.TabItem("Sentence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="Your text here", label="Input Sentence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_se, inputs=[inp], outputs=emb)
        # MSA tab: uploaded alignment file -> update_msa.
        with gr.TabItem("MSA Embeddings"):
            with gr.Box():
                inp = gr.File(file_count="single", label="Input MSA")
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_msa, inputs=[inp], outputs=emb)
        # GO tab: go_embed returns None, so this tab currently displays "None".
        with gr.TabItem("GO Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="", label="Input GO Terms"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_go, inputs=[inp], outputs=emb)
if __name__ == "__main__":
    # First-run setup: fetch progres' trained model (presumably required by
    # pg.embed_coords in the structure tab), then start the Gradio server.
    download_data_if_required()
    demo.launch()