Spaces:
Running
Running
import os | |
import base64 | |
import gradio as gr | |
import pandas as pd | |
from apscheduler.schedulers.background import BackgroundScheduler | |
import numpy as np | |
from src.about import ( | |
CITATION_BUTTON_LABEL, | |
CITATION_BUTTON_TEXT, | |
) | |
from src.display.css_html_js import custom_css | |
import copy | |
from src.envs import API, REPO_ID | |
# Absolute directory of this script; asset paths below are resolved against it.
current_dir = os.path.dirname(os.path.realpath(__file__))

# Read the logo once at startup and keep it as a base64 text string so it can
# be inlined into an HTML <img> data URI.
with open(os.path.join(current_dir, "images/pb_logo.png"), "rb") as logo_handle:
    logo_bytes = logo_handle.read()
main_logo = base64.b64encode(logo_bytes).decode('utf-8')
def restart_space():
    """Restart the Space identified by ``REPO_ID`` through the shared ``API`` client.

    Both names come from ``src.envs``. This function is scheduled periodically
    at the bottom of this script.
    """
    target_repo = REPO_ID
    API.restart_space(repo_id=target_repo)
# Page heading rendered with gr.Markdown: a leading newline, then a level-1 heading.
TITLE = "\n# ProteinBench: A Holistic Evaluation of Protein Foundation Models"
# Markdown introduction rendered with gr.Markdown on the leaderboard page:
# a summary of the ProteinBench paper, then a level-2 heading linking to the
# paper and the project website. Inline HTML (<b>) is used for emphasis.
INTRO_TEXT="""
Recent years have witnessed a surge in the development of protein foundation models,
significantly improving performance in protein prediction and generative tasks
ranging from 3D structure prediction and protein design to conformational dynamics.
However, the capabilities and limitations associated with these models remain poorly understood due to the absence of a unified evaluation framework.
To fill this gap, we introduce <b>ProteinBench</b>,
a holistic evaluation framework designed to enhance the transparency of protein foundation models.
Our approach consists of three key components:
(i) A taxonomic classification of tasks that broadly encompass the main challenges in the protein domain,
based on the relationships between different protein modalities;
(ii) A multi-metric evaluation approach that assesses performance across four key dimensions: quality, novelty, diversity, and robustness;
and (iii) In-depth analyses from various user objectives, providing a holistic view of model performance.
Our comprehensive evaluation of protein foundation models reveals several key findings that shed light on their current capabilities and limitations.
To promote transparency and facilitate further research, we release the evaluation dataset, code, and a public leaderboard publicly for further analysis
and a general modular toolkit. We intend for ProteinBench to be a living benchmark for establishing a standardized,
in-depth evaluation framework for protein foundation models, driving their development and application while fostering collaboration within the field.
## [Paper](https://www.arxiv.org/pdf/2409.06744) | [Website](https://proteinbench.github.io/)
"""
def convert_to_float(df, start_col_idx=2):
    """Cast every column from position ``start_col_idx`` onward to float, in place.

    Columns before ``start_col_idx`` are left untouched, so leading label
    columns (for example, text columns) survive.

    Args:
        df: DataFrame to modify. It is mutated in place.
        start_col_idx: Position of the first column to cast (slice semantics).

    Returns:
        The same ``df`` object, so calls can be chained.
    """
    trailing_columns = list(df.columns)[start_col_idx:]
    for column_name in trailing_columns:
        df[column_name] = df[column_name].astype(float)
    return df
def assign_rank_and_get_sorted_csv(src_csv_path, tag_csv_path, ignore_num=0):
    """Load a results CSV, prepend an overall ``Rank`` column and sort by it.

    Each metric named in the tag CSV is ranked on its own (ties share the
    lowest rank). The per-metric ranks, weighted by ``abs(tag)``, are summed,
    and that sum is ranked again to give the overall ``Rank``.

    Args:
        src_csv_path: Path of the results CSV. Every column with a non-zero tag
            in the tag CSV must be castable to float.
        tag_csv_path: Path of a CSV whose first row holds one integer tag per
            metric column:
            - A positive tag means smaller values rank better.
            - A negative tag means larger values rank better.
            - ``abs(tag)`` is the metric's weight in the sum.
            - A tag of 0 excludes the metric.
        ignore_num: Number of leading rows that take no part in ranking. They
            are placed first in the output with ``Rank`` set to NaN.

    Returns:
        A DataFrame with the source columns preceded by a float ``Rank``
        column, sorted by rank. Rows with equal rank keep their CSV order.
    """
    src_csv = pd.read_csv(src_csv_path)
    tag_csv = pd.read_csv(tag_csv_path)

    # Only rows after the first ``ignore_num`` are ranked. The slice is only
    # read below (never written), so there is no chained-assignment
    # (SettingWithCopy) hazard.
    ranked_rows = src_csv.iloc[ignore_num:]

    # Pre-set the index so the sum below stays aligned with ranked_rows.
    rank_csv = pd.DataFrame(index=ranked_rows.index)
    for col in tag_csv.columns:
        tag = int(tag_csv[col].iloc[0])
        if tag == 0:
            continue
        # Multiplying by the tag turns every metric into "smaller is better".
        # Missing values get a huge sentinel so they always rank last.
        scores = (ranked_rows[col].astype(float) * tag).fillna(value=1e12)
        rank_csv[col] = scores.rank(method='min') * abs(tag)

    overall_rank = rank_csv.sum(axis=1).rank(method='min')

    # Excluded rows keep the placeholder -1 so they sort to the top.
    src_csv.insert(loc=0, column='Rank', value=np.full(len(src_csv), -1.0))
    src_csv.loc[overall_rank.index, 'Rank'] = overall_rank

    # A stable sort keeps tied models in their original CSV order.
    sorted_csv = src_csv.sort_values(by=['Rank'], kind='mergesort')
    if ignore_num > 0:
        sorted_csv.loc[src_csv.index[:ignore_num], 'Rank'] = np.nan
    return sorted_csv
# ### Space initialisation

# One entry per leaderboard tab, in display order; the list position is the tab id.
# Each entry is (tab label, HTML elem_id, CSV file stem, number of leading rows
# excluded from ranking).
# Results are read from data_link/<stem>.csv and ranking tags from
# data_rank/<stem>.csv.
# NOTE(review): the leading "π" in the labels looks like a mis-encoded emoji;
# confirm the intended character.
LEADERBOARD_TABS = [
    ("π Inverse Folding Leaderboard", 'inverse-folding-table', 'inverse_folding', 0),
    ("π Structure Design Leaderboard", 'structure-design-table', 'structure_design', 1),
    ("π Sequence Design Leaderboard", 'sequence-design-table', 'sequence_design', 1),
    ("π Sequence-Structure Co-Design Leaderboard", 'co-design-table', 'co_design', 1),
    ("π Motif Scaffolding Leaderboard", 'motif-scaffolding-table', 'motif_scaffolding', 0),
    ("π Antibody Design Leaderboard", 'antibody-design-table', 'antibody_design', 1),
    ("π Protein Folding Leaderboard", 'protein-folding-table', 'protein_folding', 0),
    ("π Multi-State Prediction (BPTI) Leaderboard", 'multi-state-prediction-bpti-table', 'multi_state_prediction_bpti', 0),
    ("π Multi-State Prediction (apo-holo) Leaderboard", 'multi-state-prediction-apo-table', 'multi_state_prediction_apo', 1),
    ("π Distribution Prediction Leaderboard", 'distribution-prediction-table', 'distribution_prediction', 2),
]


def _leaderboard_dataframe(src_csv_path, tag_csv_path, ignore_num=0):
    """Build a read-only Gradio table for one ranked leaderboard CSV."""
    board = assign_rank_and_get_sorted_csv(src_csv_path, tag_csv_path, ignore_num=ignore_num)
    headers = board.columns.to_list()
    return gr.components.DataFrame(
        value=convert_to_float(board).values,
        interactive=False,
        headers=headers,
        # Exactly one datatype per header:
        # - the numeric Rank,
        # - the first source column (left as text and shown as markdown;
        #   presumably the model name),
        # - then the numeric metric columns.
        # The previous "+ (len - 1)" produced one entry too many.
        datatype=['number', 'markdown'] + (len(headers) - 2) * ['number'],
    )


demo = gr.Blocks(css=custom_css)
with demo:
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(INTRO_TEXT)
        with gr.Column(scale=1):
            # The logo file is a PNG (images/pb_logo.png), so declare image/png.
            gr.HTML(f'<img src="data:image/png;base64,{main_logo}" style="width:16em;vertical-align: middle"/>')
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        for tab_id, (label, elem_id, stem, ignore_num) in enumerate(LEADERBOARD_TABS):
            with gr.TabItem(label, elem_id=elem_id, id=tab_id):
                with gr.Row():
                    _leaderboard_dataframe(
                        f'data_link/{stem}.csv',
                        f'data_rank/{stem}.csv',
                        ignore_num=ignore_num,
                    )
    with gr.Row():
        with gr.Accordion("π Citation", open=True):
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                lines=9,
                elem_id="citation-button",
                show_copy_button=True,
            )
# Restart the Space every 1800 seconds (30 minutes) from a background
# scheduler thread, then start serving the app with a request queue that
# allows up to 40 concurrent events per handler by default.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", seconds=30 * 60)
scheduler.start()

queued_app = demo.queue(default_concurrency_limit=40)
queued_app.launch()