zsp / app.py
MassimoGregorioTotaro
instructions reorganisation
8ecc9a8
from gradio import Blocks, Button, Checkbox, DataFrame, DownloadButton, Dropdown, Examples, Image, Markdown, Tab, Textbox
from model import get_models
from data import Data
# Scoring strategies, looked up by index in `app`: index 0 is
# "wt-marginals" and index 1 is "masked-marginals". The index comes from the
# "Use higher accuracy scoring" checkbox (a bool, so False -> 0, True -> 1).
SCORING = ["wt-marginals", "masked-marginals"]
# Model identifiers offered in the "Model" dropdown, as returned by
# `get_models()` from the local `model` module (evaluated once at import time).
MODELS = get_models()
def app(*argv):
    """
    Main application function: run a prediction and build the output widgets.

    Positional arguments, in order:
        seq: content of the "Sequence" textbox.
        trg: content of the "Substitutions" textbox.
        model_name: entry selected in the "Model" dropdown.
        use_masked (optional): current state of the "Use higher accuracy
            scoring" checkbox. When it is not supplied, the value the
            checkbox was created with is used instead.
    Any further positional arguments are ignored.

    Returns:
        A 3-tuple ``(Image, DataFrame, DownloadButton)``. Exactly one of the
        first two components is visible: the Image when ``data.image()``
        yields a string (used as a file path), the DataFrame otherwise.
        The DownloadButton is visible and carries ``data.csv()``.
    """
    # Unpack the arguments
    seq, trg, model_name, *extra = argv
    # Prefer the checkbox state handed in by the caller: `.value` on the
    # component only holds the value it was constructed with, not what the
    # user toggled in the browser.
    use_masked = extra[0] if extra else scoring_strategy.value
    # A bool indexes the list directly: False -> SCORING[0], True -> SCORING[1]
    scoring = SCORING[int(bool(use_masked))]
    # Calculate the data based on the input parameters
    data = Data(seq, trg, model_name, scoring).calculate()
    # Fetch the result once and reuse it for the type check and the widget
    result = data.image()
    if isinstance(result, str):
        out = (
            Image(value=result, type='filepath', visible=True),
            DataFrame(visible=False),
        )
    else:
        out = (
            Image(visible=False),
            DataFrame(value=result, visible=True),
        )
    return *out, DownloadButton(value=data.csv(), visible=True)
# Create the Gradio interface: an "App" tab holding the inputs, the run button
# and the (initially hidden) outputs, plus an "Instructions" tab rendered from
# a markdown file.
with Blocks() as esm_scan:
    Markdown("# [ESM-Scan](https://doi.org/10.1002/pro.5221)")
    # Define the interface components
    with Tab("App"):
        # Read the header through a context manager so the file handle is
        # closed as soon as its content has been rendered.
        with open("header.md", "r", encoding="utf-8") as header_file:
            Markdown(header_file.read())
        seq = Textbox(
            lines=2,
            label="Sequence",
            placeholder="FASTA sequence here...",
            value=''
        )
        trg = Textbox(
            lines=1,
            label="Substitutions",
            placeholder="Substitutions here...",
            value=""
        )
        model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
        scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
        btn = Button(value="Run", variant="primary")
        dlb = DownloadButton(label="Download raw data", visible=False)
        # Both result widgets start hidden; `app` returns replacements that
        # make exactly one of them visible.
        out = Image(visible=False)
        ouu = DataFrame(visible=False)
        btn.click(
            fn=app,
            # The checkbox is passed as the fourth positional argument so the
            # current state of the toggle is sent along with each run.
            inputs=[seq, trg, model_name, scoring_strategy],
            outputs=[out, ouu, dlb]
        )
        # Each example row fills, in order: sequence, substitutions, model.
        ex = Examples(
            examples=[
                [
                    "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
                    "deep mutational scanning",
                    "facebook/esm2_t6_8M_UR50D"
                ],
                [
                    "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
                    "217 218 219",
                    "facebook/esm2_t12_35M_UR50D"
                ],
                [
                    "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
                    "R218K R218S R218N R218A R218V R218D",
                    "facebook/esm2_t30_150M_UR50D",
                ],
                [
                    "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
                    "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
                    "facebook/esm2_t33_650M_UR50D",
                ],
            ],
            inputs=[seq,
                    trg,
                    model_name],
            # List all three components so the outputs line up with the
            # three values `app` returns.
            outputs=[out, ouu, dlb],
            fn=app,
            cache_examples=False
        )
    with Tab("Instructions"):
        with open("instructions.md", "r", encoding="utf-8") as instructions_file:
            Markdown(instructions_file.read())
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    esm_scan.launch()