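# File-viewer header (page residue, kept as a comment so the file parses):
#   author: davanstrien (HF staff)
#   last commit: 82b6a11 "fix imports"
#   size: 6.15 kB (raw / history / blame views)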
import os
from datetime import datetime, timedelta
import argilla as rg
import gradio as gr
import pandas as pd
import plotly.colors as colors
import plotly.graph_objects as go
from cachetools import TTLCache, cached
# Module-level Argilla client; the server URL and API key are read from the
# ARGILLA_API_URL / ARGILLA_API_KEY environment variables.
client = rg.Argilla(
    api_url=os.getenv("ARGILLA_API_URL"),
    api_key=os.getenv("ARGILLA_API_KEY"),
)

# TTL cache backing fetch_data: at most 100 entries, each kept for 10 minutes.
# Uses cachetools' default monotonic timer (ttl expressed in seconds) instead of
# datetime.now: naive wall-clock local time can jump backwards (DST fall-back,
# NTP corrections), which would keep cached entries alive longer than intended.
cache = TTLCache(maxsize=100, ttl=timedelta(minutes=10).total_seconds())
@cached(cache)
def fetch_data(dataset_name: str, workspace: str):
    """Return the Argilla dataset called *dataset_name* in *workspace*.

    Memoised through the module-level ``cache`` keyed on the call arguments,
    so repeated calls within the cache's TTL return the same object without
    calling ``client.datasets`` again.

    NOTE(review): only the dataset handle is cached; whether iterating
    ``dataset.records`` later goes back to the server is up to the Argilla
    client -- confirm before relying on this for request reduction.
    """
    found = client.datasets(dataset_name, workspace=workspace)
    return found
def get_progress(dataset) -> dict:
    """Summarise how many records of *dataset* are completed.

    Records are streamed once and counted on the fly instead of being copied
    into a list first, so memory use does not grow with the dataset size.

    Args:
        dataset: object exposing an iterable ``records`` attribute whose items
            carry a ``status`` attribute.

    Returns:
        dict with keys
            "total": number of records,
            "annotated": number of records whose status == "completed",
            "progress": annotated / total * 100 (0 when there are no records).
    """
    total_records = 0
    annotated_records = 0
    for record in dataset.records:
        total_records += 1
        if record.status == "completed":
            annotated_records += 1
    # Avoid dividing by zero for an empty dataset.
    progress = (annotated_records / total_records) * 100 if total_records > 0 else 0
    return {
        "total": total_records,
        "annotated": annotated_records,
        "progress": progress,
    }
def get_leaderboard(dataset) -> dict:
    """Count responses per annotator username across all records of *dataset*.

    Every response attached to a record is counted; nothing is filtered here.
    Each distinct ``user_id`` is resolved to a username via
    ``client.users(id=...)`` only once per call. Previously the lookup was
    repeated for every single response, costing one client call per response.

    Returns:
        dict mapping username -> number of responses, in first-seen order.
    """
    usernames = {}  # user_id -> username, memoised for the duration of this call
    user_annotations = {}
    for record in dataset.records:
        for response in record.responses:
            user_id = response.user_id
            if user_id not in usernames:
                usernames[user_id] = client.users(id=user_id).username
            username = usernames[user_id]
            user_annotations[username] = user_annotations.get(username, 0) + 1
    return user_annotations
def create_gauge_chart(progress):
    """Build a Plotly gauge figure showing overall annotation progress.

    Args:
        progress: mapping with "total", "annotated" and "progress" keys, as
            produced by ``get_progress`` ("progress" is a percentage).

    Returns:
        A ``plotly.graph_objects.Figure`` holding one ``Indicator`` trace plus
        two paper-anchored text annotations summarising the counts.
    """
    pct = progress["progress"]
    total = progress["total"]
    done = progress["annotated"]

    gauge_style = {
        "axis": {"range": [None, 100], "tickwidth": 1, "tickcolor": "darkblue"},
        "bar": {"color": "deepskyblue"},
        "bgcolor": "white",
        "borderwidth": 2,
        "bordercolor": "gray",
        # Completed portion of the dial vs. the remaining portion.
        "steps": [
            {"range": [0, pct], "color": "royalblue"},
            {"range": [pct, 100], "color": "lightgray"},
        ],
        # Red marker at the 100% target.
        "threshold": {
            "line": {"color": "red", "width": 4},
            "thickness": 0.75,
            "value": 100,
        },
    }
    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=pct,
        title={"text": "Dataset Annotation Progress", "font": {"size": 24}},
        # The delta is measured against the 100% goal.
        delta={"reference": 100, "increasing": {"color": "RebeccaPurple"}},
        number={"font": {"size": 40}, "valueformat": ".1f", "suffix": "%"},
        gauge=gauge_style,
    )
    fig = go.Figure(indicator)

    summary_text = (
        f"Total records: {total}<br>"
        f"Annotated: {done} ({pct:.1f}%)<br>"
        f"Remaining: {total - done} ({100 - pct:.1f}%)"
    )
    # NOTE(review): this annotation sets no x/y, so Plotly's default placement
    # applies -- confirm it does not overlap the gauge.
    fig.update_layout(
        annotations=[
            dict(
                text=summary_text,
                showarrow=False,
                xref="paper",
                yref="paper",
                font=dict(size=16),
            )
        ],
    )

    headline_text = (
        f"Current Progress: {pct:.1f}% complete<br>"
        f"({done} out of {total} records annotated)"
    )
    # Headline centred just above the plotting area (y > 1 in paper units).
    fig.add_annotation(
        text=headline_text,
        xref="paper",
        yref="paper",
        x=0.5,
        y=1.1,
        showarrow=False,
        font=dict(size=18),
        align="center",
    )
    return fig
def create_treemap(user_annotations, total_records):
    """Build a Plotly treemap of per-user annotation counts.

    Args:
        user_annotations: mapping of username -> number of annotations, as
            returned by ``get_leaderboard``.
        total_records: number of records in the dataset. Used as the value of
            the root "Annotations" node and as the denominator of each user's
            percentage. May be 0, in which case percentages are shown as 0.

    Returns:
        A ``plotly.graph_objects.Figure`` containing a single ``Treemap``.
    """
    # Largest contributors first, so they also get the first palette colours.
    sorted_users = sorted(user_annotations.items(), key=lambda x: x[1], reverse=True)
    color_scale = colors.qualitative.Pastel + colors.qualitative.Set3
    labels, parents, values, text, user_colors = [], [], [], [], []
    for i, (user, contribution) in enumerate(sorted_users):
        # Guard the division so total_records == 0 cannot raise ZeroDivisionError.
        percentage = (contribution / total_records) * 100 if total_records > 0 else 0.0
        labels.append(user)
        parents.append("Annotations")
        values.append(contribution)
        text.append(f"{contribution} annotations<br>{percentage:.2f}%")
        # Cycle through the palette when there are more users than colours.
        user_colors.append(color_scale[i % len(color_scale)])

    # Root node that every user tile hangs off.
    # NOTE(review): a user literally named "Annotations" would collide with
    # this root label, since no explicit ``ids`` are passed -- confirm acceptable.
    labels.append("Annotations")
    parents.append("")
    values.append(total_records)
    text.append(f"Total: {total_records} annotations")
    user_colors.append("#FFFFFF")

    # NOTE(review): branchvalues is left at Plotly's default ("remainder"), so
    # the root's value is drawn *in addition to* the user tiles rather than
    # as their total -- confirm this is the intended layout.
    fig = go.Figure(
        go.Treemap(
            labels=labels,
            parents=parents,
            values=values,
            text=text,
            textinfo="label+text",
            hoverinfo="label+text+value",
            marker=dict(colors=user_colors, line=dict(width=2)),
        )
    )
    fig.update_layout(
        title_text="User contributions to the total end dataset",
        height=500,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor="#F0F0F0",  # Light gray background
        plot_bgcolor="#F0F0F0",  # Light gray background
    )
    return fig
def update_dashboard():
    """Fetch the configured dataset and rebuild every dashboard component.

    The dataset name and workspace are read from the ``DATASET_NAME`` and
    ``WORKSPACE`` environment variables.

    Returns:
        ``(gauge_chart, treemap, leaderboard_df)``. These match, in order, the
        three output components wired to this function in the Gradio UI.
        ``leaderboard_df`` has columns "User" and "Annotations", sorted by
        annotation count (descending) with a fresh 0..n-1 index.
    """
    dataset_name = os.getenv("DATASET_NAME")
    workspace = os.getenv("WORKSPACE")
    dataset = fetch_data(dataset_name, workspace)

    # NOTE(review): both helpers iterate dataset.records separately; whether
    # that means two round-trips to the server depends on the Argilla client.
    progress = get_progress(dataset)
    user_annotations = get_leaderboard(dataset)

    gauge_chart = create_gauge_chart(progress)
    treemap = create_treemap(user_annotations, progress["total"])

    rows = list(user_annotations.items())
    leaderboard_df = (
        pd.DataFrame(rows, columns=["User", "Annotations"])
        .sort_values("Annotations", ascending=False)
        .reset_index(drop=True)
    )
    return gauge_chart, treemap, leaderboard_df
with gr.Blocks() as demo:
    gr.Markdown("# Argilla Dataset Dashboard")

    # First row: progress gauge next to the per-user treemap.
    with gr.Row():
        gauge_output = gr.Plot(label="Overall Progress")
        treemap_output = gr.Plot(label="User contributions")

    # Second row: table of annotation counts per user.
    with gr.Row():
        leaderboard_output = gr.Dataframe(
            label="Leaderboard", headers=["User", "Annotations"]
        )

    # Both triggers run the same refresh and fill the same three components,
    # in the order update_dashboard returns them.
    dashboard_outputs = [gauge_output, treemap_output, leaderboard_output]

    # Populate the dashboard when the page loads...
    demo.load(update_dashboard, inputs=None, outputs=dashboard_outputs)

    # ...and again whenever the user clicks "Refresh".
    refresh_button = gr.Button("Refresh")
    refresh_button.click(update_dashboard, inputs=None, outputs=dashboard_outputs)


if __name__ == "__main__":
    demo.launch()