File size: 3,164 Bytes
c12bd84 e3642ff c12bd84 e3642ff c12bd84 e3642ff c12bd84 e3642ff c12bd84 e3642ff c12bd84 43b4e29 e3642ff 43b4e29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import streamlit as st
import pandas as pd
import os
import fnmatch
import json
class MultiURLData:
    """Leaderboard table assembled from per-model evaluation result files.

    Every file under ``results_dir`` whose basename matches ``pattern`` is read
    as JSON; its ``'results'`` mapping contributes one row to ``data``, with
    one column per task holding that task's ``'acc'`` value. ``data`` is
    indexed by model name, and its first column, ``'Model Name'``, repeats that
    index so it can be displayed and filtered like any other column.
    """

    def __init__(self, results_dir='results', pattern='results*.json'):
        # The defaults reproduce the previously hard-coded location and pattern.
        self.data = self.process_data(results_dir, pattern)

    @staticmethod
    def _find_files(directory, pattern):
        """Yield paths under *directory* (recursively) whose basename matches *pattern*."""
        for root, _dirs, files in os.walk(directory):
            for basename in fnmatch.filter(files, pattern):
                yield os.path.join(root, basename)

    @staticmethod
    def _model_name(filename, results_dir):
        """Return the second path component of *filename* relative to *results_dir*.

        NOTE(review): presumably the layout is
        ``<results_dir>/<org>/<model>/results*.json`` -- confirm. Splitting on
        ``os.sep`` rather than a literal ``'/'`` keeps this working on Windows.
        A file fewer than two levels below *results_dir* raises IndexError.
        """
        return os.path.relpath(filename, results_dir).split(os.sep)[1]

    @staticmethod
    def _load_scores(filename, model_name):
        """Return a one-column frame, named *model_name*, of per-task ``'acc'`` scores.

        Raises KeyError if the file has no ``'results'`` key or no task in it
        reports an ``'acc'`` metric.
        """
        with open(filename, encoding='utf-8') as f:
            payload = json.load(f)
        # 'results' maps each task key to a dict of metrics (including 'acc');
        # transposing makes tasks the rows and metrics the columns.
        scores = pd.DataFrame(payload['results']).T
        scores = scores.rename(columns={'acc': model_name})
        # regex=False makes the prefix stripping independent of the pandas
        # version: the str.replace default flipped from regex=True to
        # regex=False in pandas 2.0, which silently disabled the old escaped
        # 'harness\|' pattern.
        scores.index = (
            scores.index
            .str.replace('hendrycksTest-', '', regex=False)
            .str.replace('harness|', '', regex=False)
        )
        return scores[[model_name]]

    def process_data(self, results_dir='results', pattern='results*.json'):
        """Build the leaderboard frame from every matching results file.

        Returns a frame with one row per results file (indexed by model name),
        ``'Model Name'`` as the first column and one score column per task.
        When no file matches, an empty frame with only a ``'Model Name'``
        column is returned instead of letting ``pd.concat`` raise on an empty
        list.
        """
        per_model = [
            self._load_scores(path, self._model_name(path, results_dir))
            for path in self._find_files(results_dir, pattern)
        ]
        if not per_model:
            return pd.DataFrame(columns=['Model Name'])
        # After concat each model is a column; transpose so each model is a row.
        data = pd.concat(per_model, axis=1).transpose()
        data.insert(0, 'Model Name', data.index)
        return data

    def get_data(self, selected_models):
        """Return the rows of ``data`` whose ``'Model Name'`` is in *selected_models*."""
        return self.data[self.data['Model Name'].isin(selected_models)]
# plotly backs the scatter plot below; `px` was previously referenced without
# ever being imported, which raised NameError at runtime.
import plotly.express as px


def _task_column(frame, task):
    """Return the column of *frame* holding *task*'s score, or None if absent.

    Depending on the pandas version the loader ran under, task columns may or
    may not still carry the ``'harness|'`` prefix, so both spellings are tried.
    """
    for candidate in (task, 'harness|' + task):
        if candidate in frame.columns:
            return candidate
    return None


data_provider = MultiURLData()
all_columns = data_provider.data.columns.tolist()
all_models = data_provider.data['Model Name'].tolist()

st.title('Leaderboard')

# Model and column selection are hidden by default and only appear once the
# user ticks "Add filters"; without filters every model and column is shown.
filters = st.checkbox('Add filters')
if filters:
    selected_columns = st.multiselect(
        'Select Columns',
        all_columns,
        default=all_columns,
    )
    selected_models = st.multiselect(
        'Select Models',
        all_models,
        default=all_models,
    )
else:
    selected_columns = all_columns
    selected_models = all_models

# Rows are filtered by model. The column selection only narrows the table
# (kept in the frame's original column order), not the plot below.
filtered_data = data_provider.get_data(selected_models)
wanted_columns = set(selected_columns)
st.dataframe(filtered_data[[c for c in filtered_data.columns if c in wanted_columns]])

# Scatter plot of each selected model's ARC-Challenge rank vs. its
# moral-scenarios rank.
arc_column = _task_column(filtered_data, 'arc:challenge|25')
moral_column = _task_column(filtered_data, 'moral_scenarios|5')

if filtered_data.empty or arc_column is None or moral_column is None:
    st.info('Not enough data to plot ARC-Challenge vs. moral scenarios ranks.')
else:
    # Rank models within the current selection. NOTE(review): assumes a higher
    # 'acc' is better, so rank 1 is the best model; method='min' gives tied
    # models the same (best) rank.
    df = pd.DataFrame({
        'Model': filtered_data['Model Name'].tolist(),
        'harness|arc:challenge|25_rank':
            filtered_data[arc_column].rank(ascending=False, method='min').tolist(),
        'harness|moral_scenarios|5_rank':
            filtered_data[moral_column].rank(ascending=False, method='min').tolist(),
    })

    # Colour each model by which of its two ranks is better (numerically lower);
    # equal or missing ranks stay purple.
    df['color'] = 'purple'
    df.loc[df['harness|moral_scenarios|5_rank'] < df['harness|arc:challenge|25_rank'], 'color'] = 'red'
    df.loc[df['harness|moral_scenarios|5_rank'] > df['harness|arc:challenge|25_rank'], 'color'] = 'blue'

    # color_discrete_map='identity' makes plotly use the values of the 'color'
    # column as the actual marker colours instead of as category labels that
    # get mapped onto its default palette.
    fig = px.scatter(
        df,
        x='harness|arc:challenge|25_rank',
        y='harness|moral_scenarios|5_rank',
        color='color',
        color_discrete_map='identity',
        hover_data=['Model'],
    )
    fig.update_layout(
        showlegend=False,
        xaxis=dict(autorange="reversed"),  # reversed: rank 1 on the right
        yaxis=dict(autorange="reversed"),  # reversed: rank 1 at the top
    )
    st.plotly_chart(fig)
|