Spaces:
Running
Running
import os | |
from plotly import graph_objects as go | |
import pandas as pd | |
## Evaluation Graphs
# Load the data
#
# Builds ``all_eval_results``: one entry per CSV file found in the evaluation
# directory, keyed by the metric name derived from the file name. Each entry
# maps a dataset key to a float Series of scores, plus a "token" list that
# serves as the shared x-axis.
_EVAL_DIR = "data/txt360_eval"

# Positional CSV column -> key under which that column's scores are stored.
_EVAL_COLUMNS = {
    1: "fineweb",
    2: "txt360-dedup-only",
    3: "txt360-web-only-upsampled",
    4: "txt360-all-upsampled + stackv2",
}

all_eval_results = {}
# os.listdir() returns entries in arbitrary order; sorting keeps the metric
# (and therefore figure) registration order stable across runs and machines.
for fname in sorted(os.listdir(_EVAL_DIR)):
    if not fname.endswith(".csv"):
        continue
    metric_name = fname.replace("CKPT Eval - ", "").replace(".csv", "")
    df = pd.read_csv(os.path.join(_EVAL_DIR, fname))

    metric_results = {}
    for col_idx, dataset_key in _EVAL_COLUMNS.items():
        # Only rows from position 2 onward are used (presumably the first two
        # rows are not checkpoint scores -- confirm against the CSV layout).
        # Missing scores are back-filled from the next available row.
        # Series.bfill() is the supported spelling of fillna(method="bfill"),
        # which has been deprecated since pandas 2.1.
        metric_results[dataset_key] = df.iloc[2:, col_idx].astype(float).bfill()

    # each row is 20B tokens.
    n_rows = len(metric_results["fineweb"])
    metric_results["token"] = [20 * i for i in range(n_rows)]

    all_eval_results[metric_name] = metric_results
# Eval Result Plots
#
# Builds ``all_eval_res_figs``: one Plotly line chart per metric in
# ``all_eval_results``. Every chart overlays the same four dataset curves
# against the shared "token" x-axis.

# (key inside a metric's result dict, legend label), listed in drawing order.
_EVAL_TRACE_SPECS = (
    ("fineweb", 'FineWeb'),
    ("txt360-web-only-upsampled", 'TxT360 - CC Data Upsampled'),
    ("txt360-dedup-only", 'TxT360 - CC Data Dedup'),
    ("txt360-all-upsampled + stackv2", 'TxT360 - Full Upsampled + Stack V2'),
)

all_eval_res_figs = {}
for metric_name, res in all_eval_results.items():
    fig_res = go.Figure()

    # Add lines
    token_axis = res["token"]
    for result_key, legend_label in _EVAL_TRACE_SPECS:
        curve = go.Scatter(
            x=token_axis,
            y=res[result_key],
            mode='lines',
            name=legend_label,
        )
        fig_res.add_trace(curve)

    # Update layout
    fig_res.update_layout(
        title=f"{metric_name} Performance",
        title_x=0.5,  # Centers the title
        xaxis_title="Billion Tokens",
        yaxis_title=metric_name,
        legend_title="Dataset",
    )

    all_eval_res_figs[metric_name] = fig_res