"""Gradio front-end for ESM-Scan: score amino-acid substitutions with ESM models."""
from gradio import Blocks, Button, Checkbox, DataFrame, DownloadButton, Dropdown, Examples, Image, Markdown, Tab, Textbox
from model import get_models
from data import Data

# Scoring strategies, indexed by the "Use higher accuracy scoring" checkbox:
# False -> index 0 ("wt-marginals"), True -> index 1 ("masked-marginals").
SCORING = ["wt-marginals", "masked-marginals"]

# Models offered in the dropdown.
MODELS = get_models()

# Wild-type sequence shared by all bundled examples.
EXAMPLE_SEQ = (
    "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"
)


def read_markdown(path):
    """Return the UTF-8 text of *path*, closing the file handle afterwards."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


def app(*argv):
    """Run a scan and return UI updates for the image, table and download button.

    Positional inputs (in Gradio input order):
        seq: the input sequence text.
        trg: the substitutions text.
        model_name: the selected model identifier.
        use_masked (optional): value of the "higher accuracy" checkbox. When it is
            omitted (e.g. an example row supplying only three inputs), it defaults
            to True, matching the checkbox's default.

    Returns:
        A 3-tuple of component updates:
        - If ``data.image()`` returns a string (a file path), the first element shows
          it as an image and the second hides the table.
        - Otherwise the first element hides the image and the second shows the value
          in the table.
        - The third element is a visible download button for ``data.csv()``.
    """
    seq, trg, model_name, *rest = argv
    # Use the checkbox value that was actually submitted. The component's
    # construction-time `.value` never changes, so it must not be read here.
    use_masked = bool(rest[0]) if rest else True
    scoring = SCORING[int(use_masked)]

    # Calculate the data based on the input parameters.
    data = Data(seq, trg, model_name, scoring).calculate()
    image = data.image()
    if isinstance(image, str):
        # A string result is treated as a file path to render as an image.
        out = Image(value=image, type="filepath", visible=True), DataFrame(visible=False)
    else:
        out = Image(visible=False), DataFrame(value=image, visible=True)
    return *out, DownloadButton(value=data.csv(), visible=True)


# Create the Gradio interface.
with Blocks() as esm_scan:
    Markdown("# [ESM-Scan](https://doi.org/10.1002/pro.5221)")

    # Define the interface components.
    with Tab("App"):
        Markdown(read_markdown("header.md"))
        seq = Textbox(
            lines=2,
            label="Sequence",
            placeholder="FASTA sequence here...",
            value="",
        )
        trg = Textbox(
            lines=1,
            label="Substitutions",
            placeholder="Substitutions here...",
            value="",
        )
        model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
        scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
        btn = Button(value="Run", variant="primary")
        dlb = DownloadButton(label="Download raw data", visible=False)
        out = Image(visible=False)
        ouu = DataFrame(visible=False)

        # The checkbox is passed as an input so that app() sees the user's choice.
        btn.click(
            fn=app,
            inputs=[seq, trg, model_name, scoring_strategy],
            outputs=[out, ouu, dlb],
        )

        ex = Examples(
            examples=[
                [
                    EXAMPLE_SEQ,
                    "deep mutational scanning",
                    "facebook/esm2_t6_8M_UR50D",
                ],
                [
                    EXAMPLE_SEQ,
                    "217 218 219",
                    "facebook/esm2_t12_35M_UR50D",
                ],
                [
                    EXAMPLE_SEQ,
                    "R218K R218S R218N R218A R218V R218D",
                    "facebook/esm2_t30_150M_UR50D",
                ],
                [
                    EXAMPLE_SEQ,
                    "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
                    "facebook/esm2_t33_650M_UR50D",
                ],
            ],
            inputs=[seq, trg, model_name],
            # app() returns three updates, so declare all three outputs
            # (consistent with btn.click above).
            outputs=[out, ouu, dlb],
            fn=app,
            cache_examples=False,
        )

    with Tab("Instructions"):
        Markdown(read_markdown("instructions.md"))

# Launch the Gradio interface.
if __name__ == "__main__":
    esm_scan.launch()