Spaces:
Running
Running
import pandas as pd | |
import plotly.express as px | |
import requests | |
import json | |
import gradio as gr | |
from src.assets.text_content import SHORT_NAMES, TEXT_NAME, MULTIMODAL_NAME | |
from src.leaderboard_utils import get_github_data | |
def plotly_plot(df: pd.DataFrame, list_op: list, list_co: list,
                show_all: list, show_names: list, show_legend: list,
                mobile_view: list):
    """
    Takes in a list of models for a plotly plot

    Args:
        df: Leaderboard dataframe of the latest version. Column 0 holds model
            names; columns 2 and 3 are plotted on the x and y axes.
        list_op: The list of open source models to show in the plot, updated from frontend
        list_co: The list of commercial models to show in the plot, updated from frontend
        show_all: Either [] or ["Select All Models"] - toggle view to plot all models
        show_names: Either [] or ["Show Names"] - toggle view to show model names on plot
        show_legend: Either [] or ["Show Legend"] - toggle view to show legend on plot
        mobile_view: Either [] or ["Mobile View"] - toggle view for smaller screens

        Only the truthiness of the four toggle lists is inspected.

    Returns:
        Fig: plotly figure of % played v/s quality score. The caller's ``df``
        is left unmodified.
    """
    list_columns = list(df.columns)
    name_col, x_col, y_col = list_columns[0], list_columns[2], list_columns[3]

    # Attach short labels on a copy so the caller's (possibly shared) frame is
    # not mutated. Mapping by name keeps every label aligned with its own row,
    # even when a model name appears on more than one row.
    all_models = list(df[name_col].unique())
    short_names = label_map(all_models)
    df = df.copy()
    df["Short"] = df[name_col].map(short_names)

    # Filter dataframe based on the provided list of models
    selected = all_models if show_all else list_op + list_co
    df = df[df[name_col].isin(selected)]

    fig = px.scatter(
        df, x=x_col, y=y_col, color=name_col, symbol=name_col,
        # NOTE(review): these placeholder keys match no model name. Kept as-is
        # because dropping them may shift plotly's default colour assignment;
        # confirm before removing.
        color_discrete_map={"category1": "blue", "category2": "red"},
        hover_name=name_col, template="plotly_white",
        text="Short" if show_names else None,
    )
    if show_names:
        fig.update_traces(textposition='top center')

    if not show_legend:
        fig.update_layout(showlegend=False)
    fig.update_layout(
        xaxis_title='% Played',
        yaxis_title='Quality Score',
        title='Overview of benchmark results',
        height=1000,
    )
    fig.update_xaxes(range=[-5, 105])
    fig.update_yaxes(range=[-5, 105])

    if mobile_view:
        fig.update_layout(height=300)
    if mobile_view and show_legend:
        # Taller canvas so the legend can sit below the plot.
        fig.update_layout(height=450)
        fig.update_layout(legend=dict(
            yanchor="bottom",
            y=-5.52,
            xanchor="left",
            x=0.01,
        ))
        # NOTE(review): confirm this title/axis reset is meant only for the
        # mobile + legend view.
        fig.update_layout(
            xaxis_title="",
            yaxis_title="",
            title="% Played v/s Quality Score",
        )
    return fig
def shorten_model_name(full_name):
    """
    Derive a compact plot label from a full model name.

    A name without any hyphen is cut to its first three characters. A
    hyphenated name becomes its first character, a hyphen, and then the
    hyphen-separated segments that contain at least one digit, e.g.
    "gpt-4-0613" -> "g-4-0613".

    Args:
        full_name: Full model name.

    Returns:
        The abbreviated name.
    """
    segments = full_name.split('-')
    if len(segments) == 1:
        # No hyphen at all: a three-character prefix is enough.
        return full_name[:3]

    # Keep only segments carrying a digit (sizes, versions, dates).
    numeric = '-'.join(seg for seg in segments if any(ch.isdigit() for ch in seg))
    # NOTE(review): a hyphenated name without digit segments (e.g. "command-r")
    # yields a dangling hyphen ("c-"); such models presumably rely on SHORT_NAMES.
    return full_name[0] + '-' + numeric.strip('-')
def label_map(model_list: list) -> dict:
    """
    Generate a map from long names to short names, to plot them in frontend graph.

    A curated label from SHORT_NAMES (defined in src/assets/text_content.py)
    wins; any other model gets a label derived by shorten_model_name().

    Args:
        model_list: A list of long model names

    Returns:
        A dict from long to short name; repeated inputs collapse to one key.
    """
    return {
        name: (SHORT_NAMES[name] if name in SHORT_NAMES
               else shorten_model_name(name))
        for name in model_list
    }
def _fetch_model_registry(timeout: float = 30) -> list:
    """
    Download clembench's model registry JSON from GitHub (best effort).

    Returns:
        The decoded JSON, expected to be a list of entries with "model_name"
        and "open_weight" keys. Returns [] after printing a message when the
        request fails or times out, the status is not 200, or the body is not
        valid JSON.
    """
    # Load model registry data from main repo
    model_registry_url = "https://raw.githubusercontent.com/clp-research/clembench/main/backends/model_registry.json"
    try:
        # Without a timeout a stalled connection would block the caller indefinitely.
        response = requests.get(model_registry_url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to fetch model registry: {err}")
        return []
    if response.status_code != 200:
        print(f"Failed to read JSON file: Status Code : {response.status_code}")
        return []
    try:
        return json.loads(response.text)
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        print(f"Failed to parse model registry JSON: {err}")
        return []


def split_models(model_list: list):
    """
    Split the models into open source and commercial

    Each model is classified by the ``open_weight`` flag of its entry in
    clembench's model registry, which is downloaded on every call. An entry
    without ``open_weight`` counts as commercial.

    Args:
        model_list: Model names to classify.

    Returns:
        (open_models, commercial_models), each sorted case-insensitively.
        Models with no registry entry are left out, except
        "dolphin-2.5-mixtral-8x7b", which is added to the open models when it
        is requested and the registry does not list it. If the registry cannot
        be loaded, a message is printed and only that fallback can appear.
    """
    registry = _fetch_model_registry()

    # Index entries by name once instead of rescanning the registry per model;
    # setdefault keeps the first entry when a name repeats.
    entries_by_name = {}
    for entry in registry:
        entries_by_name.setdefault(entry.get("model_name"), entry)

    open_models = []
    commercial_models = []
    for model_name in model_list:
        entry = entries_by_name.get(model_name)
        if entry is None:
            continue
        if entry.get("open_weight"):
            open_models.append(model_name)
        else:
            commercial_models.append(model_name)

    # Add missing model from the model_registry. Guarded so that a registry
    # which does list it never yields a duplicate, and added before sorting so
    # it lands in alphabetical order with the rest.
    fallback_model = "dolphin-2.5-mixtral-8x7b"
    if fallback_model in model_list and fallback_model not in entries_by_name:
        open_models.append(fallback_model)

    open_models.sort(key=lambda name: name.upper())
    commercial_models.sort(key=lambda name: name.upper())
    return open_models, commercial_models


"""
Update Functions, for when the leaderboard selection changes
"""
def update_open_models(leaderboard: str = TEXT_NAME):
    """
    Change the checkbox group of Open Models based on the leaderboard selected

    Args:
        leaderboard: Selected leaderboard from the frontend [Default - Text Leaderboard].
            Any value other than TEXT_NAME selects the multimodal leaderboard.

    Return:
        Updated checkbox group for Open Models, based on the leaderboard selected,
        with nothing pre-selected.
    """
    # Reuse get_plot_df so the leaderboard -> dataframe lookup lives in one place.
    models = get_plot_df(leaderboard).iloc[:, 0].unique().tolist()
    open_models, _ = split_models(models)
    return gr.CheckboxGroup(
        open_models,
        value=[],
        elem_id="value-select-1",
        interactive=True,
    )
def update_closed_models(leaderboard: str = TEXT_NAME):
    """
    Change the checkbox group of Closed Models based on the leaderboard selected

    Args:
        leaderboard: Selected leaderboard from the frontend [Default - Text Leaderboard].
            Any value other than TEXT_NAME selects the multimodal leaderboard.

    Return:
        Updated checkbox group for Closed Models, based on the leaderboard selected,
        with nothing pre-selected.
    """
    # Reuse get_plot_df so the leaderboard -> dataframe lookup lives in one place.
    models = get_plot_df(leaderboard).iloc[:, 0].unique().tolist()
    _, commercial_models = split_models(models)
    return gr.CheckboxGroup(
        commercial_models,
        value=[],
        elem_id="value-select-2",
        interactive=True,
    )
def get_plot_df(leaderboard: str = TEXT_NAME) -> pd.DataFrame:
    """
    Fetch the dataframe that backs the plot for the chosen leaderboard.

    Args:
        leaderboard: Selected leaderboard. TEXT_NAME selects the "text" data;
            any other value selects the "multimodal" data.

    Returns:
        The first entry of that leaderboard's 'dataframes' from get_github_data().
    """
    section = "text" if leaderboard == TEXT_NAME else "multimodal"
    leaderboard_frames = get_github_data()[section]['dataframes']
    return leaderboard_frames[0]


"""
Reset Functions for when the Leaderboard selection changes
"""
def reset_show_all():
    """
    Return the "Select All Models" checkbox group in its unchecked state.

    Used to reset the toggle when the leaderboard selection changes.
    """
    return gr.CheckboxGroup(
        ["Select All Models"],
        label="Show plot for all models 🤖",
        value=[],
        elem_id="value-select-3",
        interactive=True,
    )
def reset_show_names():
    """
    Return the "Show Names" checkbox group in its unchecked state.

    Used to reset the toggle when the leaderboard selection changes.
    """
    return gr.CheckboxGroup(
        ["Show Names"],
        label="Show names of models on the plot 🏷️",
        value=[],
        elem_id="value-select-4",
        interactive=True,
    )
def reset_show_legend():
    """
    Return the "Show Legend" checkbox group in its unchecked state.

    Used to reset the toggle when the leaderboard selection changes.
    """
    return gr.CheckboxGroup(
        ["Show Legend"],
        label="Show legend on the plot 💡",
        value=[],
        elem_id="value-select-5",
        interactive=True,
    )
def reset_mobile_view():
    """
    Return the "Mobile View" checkbox group in its unchecked state.

    Used to reset the toggle when the leaderboard selection changes.
    """
    return gr.CheckboxGroup(
        ["Mobile View"],
        label="View plot on smaller screens 📱",
        value=[],
        elem_id="value-select-6",
        interactive=True,
    )
if __name__ == '__main__':
    # Manual check: classify a sample of multimodal models against the model
    # registry that split_models downloads. text_model_list is defined but not
    # used below; it can be passed to split_models instead.
    mm_model_list = ['gpt-4o-2024-05-13', 'gpt-4-1106-vision-preview', 'claude-3-opus-20240229', 'gemini-1.5-pro-latest',
                     'gemini-1.5-flash-latest', 'llava-v1.6-34b-hf', 'llava-v1.6-vicuna-13b-hf', 'idefics-80b-instruct',
                     'llava-1.5-13b-hf', 'idefics-9b-instruct']
    text_model_list = ['vicuna-33b-v1.3', 'gpt-4-0125-preview', 'gpt-4-turbo-2024-04-09', 'claude-3-5-sonnet-20240620', 'gpt-4-1106-preview',
                       'gpt-4-0613', 'gpt-4o-2024-05-13', 'claude-3-opus-20240229', 'gemini-1.5-pro-latest',
                       'Meta-Llama-3-70B-Instruct-hf', 'claude-2.1', 'gemini-1.5-flash-latest', 'claude-3-sonnet-20240229',
                       'Qwen1.5-72B-Chat', 'mistral-large-2402', 'gpt-3.5-turbo-0125', 'gemini-1.0-pro', 'command-r-plus', 'openchat_3.5',
                       'claude-3-haiku-20240307', 'sheep-duck-llama-2-70b-v1.1', 'Meta-Llama-3-8B-Instruct-hf', 'openchat-3.5-1210',
                       'WizardLM-70b-v1.0', 'openchat-3.5-0106', 'Qwen1.5-14B-Chat', 'mistral-medium-2312', 'Qwen1.5-32B-Chat',
                       'codegemma-7b-it', 'dolphin-2.5-mixtral-8x7b', 'CodeLlama-34b-Instruct-hf', 'command-r', 'gemma-1.1-7b-it',
                       'SUS-Chat-34B', 'Mixtral-8x22B-Instruct-v0.1', 'tulu-2-dpo-70b', 'Nous-Hermes-2-Mixtral-8x7B-SFT',
                       'WizardLM-13b-v1.2', 'Mistral-7B-Instruct-v0.2', 'Yi-34B-Chat', 'Mixtral-8x7B-Instruct-v0.1',
                       'Mistral-7B-Instruct-v0.1', 'Yi-1.5-34B-Chat', 'vicuna-13b-v1.5', 'Yi-1.5-6B-Chat', 'Starling-LM-7B-beta',
                       'sheep-duck-llama-2-13b', 'Yi-1.5-9B-Chat', 'gemma-1.1-2b-it', 'Qwen1.5-7B-Chat', 'gemma-7b-it',
                       'llama-2-70b-chat-hf', 'Qwen1.5-0.5B-Chat', 'Qwen1.5-1.8B-Chat']
    open_found, closed_found = split_models(mm_model_list)
    for heading, group in (("Open", open_found), ("Closed", closed_found)):
        print(heading)
        print(group)