# app.py — Hugging Face Space by m-ric (HF staff)
# Last commit: 53d2375 "Update app.py" (verified)
import gradio as gr
import requests
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
class DownloadsTracker:
    """Fetch PyPI download statistics from pypistats.org and plot them.

    Per-package DataFrames are cached in ``all_packages`` for the lifetime of
    the instance (there is no expiry), so re-rendering an already-fetched
    package does not hit the API again.
    """

    # pypistats.org "overall" endpoint; formatted with the package name.
    API_URL = "https://pypistats.org/api/packages/{}/overall"
    # Seconds to wait for pypistats.org before giving up on one package.
    REQUEST_TIMEOUT = 10
    # Number of rows summed for the "weekly downloads" series
    # (assumes one row per day in the API payload — TODO confirm no gaps).
    WEEK_WINDOW = 7

    def __init__(self):
        # package name -> DataFrame with date / downloads /
        # cumulative_downloads / weekly_downloads columns
        self.all_packages = {}
        # package names selected by the most recent render() call
        self.current_packages = []

    @staticmethod
    def _parse_packages(package_list):
        """Split a comma-separated string into stripped, unique package names.

        Empty entries are dropped. Duplicates are removed, keeping the order
        of first appearance, so the same package is not plotted twice.
        """
        names = (name.strip() for name in (package_list or "").split(","))
        return list(dict.fromkeys(name for name in names if name))

    @classmethod
    def _build_frame(cls, data):
        """Turn the pypistats ``data`` list into a per-date DataFrame.

        Only entries whose ``category`` is ``"without_mirrors"`` are kept.
        The result is sorted by date and gets two derived columns:

        * ``cumulative_downloads``: running total over the returned rows;
        * ``weekly_downloads``: sum over the trailing ``WEEK_WINDOW`` rows
          (NaN for the first ``WEEK_WINDOW - 1`` rows).

        Raises:
            ValueError: if no ``without_mirrors`` entries are present.
            KeyError: if an entry lacks ``category``/``date``/``downloads``.
        """
        rows = [
            {"date": d["date"], "downloads": d["downloads"]}
            for d in data
            if d["category"] == "without_mirrors"
        ]
        if not rows:
            raise ValueError("no 'without_mirrors' rows in payload")
        df = pd.DataFrame(rows)
        df["date"] = pd.to_datetime(df["date"])
        df = df.sort_values("date").reset_index(drop=True)
        df["cumulative_downloads"] = df["downloads"].cumsum()
        # A rolling *sum* over 7 daily rows = downloads in the last week. The
        # original used .mean(), which is a daily average and contradicted
        # the "7 days rolling sum" subplot title.
        df["weekly_downloads"] = df["downloads"].rolling(window=cls.WEEK_WINDOW).sum()
        return df

    def fetch_if_needed(self, packages):
        """Download stats for every package in ``packages`` that is not cached yet.

        Args:
            packages: iterable of PyPI package names.

        Returns:
            None when every package was fetched (or already cached). Otherwise
            a newline-joined string with one message per failed package.
            Failed packages are not cached, so they are retried on the next call.
        """
        errors = []
        for pkg in packages:
            if pkg in self.all_packages:
                continue
            try:
                response = requests.get(
                    self.API_URL.format(pkg), timeout=self.REQUEST_TIMEOUT
                )
            except requests.RequestException:
                # Network failure, DNS error or timeout.
                errors.append(f"Error fetching {pkg}")
                continue
            if response.status_code == 404:
                errors.append(f"Package not found: {pkg}")
                continue
            if response.status_code != 200:
                # e.g. rate limiting or server errors: not a missing package.
                errors.append(f"Error fetching {pkg} (HTTP {response.status_code})")
                continue
            try:
                self.all_packages[pkg] = self._build_frame(response.json()["data"])
            except (ValueError, KeyError, TypeError):
                # Invalid JSON (raised as a ValueError subclass), unexpected
                # payload shape, or no without_mirrors rows.
                errors.append(f"Error fetching {pkg}")
        return "\n".join(errors) if errors else None

    def plot(self, use_log_scale, packages=None):
        """Build a two-row figure: cumulative downloads on top, weekly below.

        Args:
            use_log_scale: if truthy, both y-axes use a log scale, else linear.
            packages: names to plot; defaults to ``self.current_packages``.
                Names missing from ``all_packages`` are skipped, but they still
                consume a color slot, so colors follow list position.

        Returns:
            A plotly Figure.
        """
        if packages is None:
            packages = self.current_packages
        fig = make_subplots(rows=2, cols=1, subplot_titles=("Cumulative Downloads (restarted at 0)", "Weekly Downloads (7 days rolling sum)"))
        colors = px.colors.qualitative.Pastel  # Built-in color sequence
        for i, pkg in enumerate(packages):
            df = self.all_packages.get(pkg)
            if df is None:
                continue
            color = colors[i % len(colors)]
            fig.add_trace(
                go.Scatter(x=df["date"], y=df["cumulative_downloads"],
                           name=pkg, line=dict(color=color)),
                row=1, col=1,
            )
            # Same color as the top trace; legend entry shown only once.
            fig.add_trace(
                go.Scatter(x=df["date"], y=df["weekly_downloads"],
                           name=pkg, line=dict(color=color), showlegend=False),
                row=2, col=1,
            )
        axis_type = "log" if use_log_scale else "linear"
        for row in (1, 2):
            fig.update_yaxes(type=axis_type, row=row, col=1)
        fig.update_layout(
            height=800,
            font=dict(size=16),
            title_font=dict(size=20),
            legend_font=dict(size=16),
        )
        return fig

    def render(self, package_list, use_log_scale):
        """Parse ``package_list``, fetch missing stats and build the figure.

        Returns:
            ``(figure, error_text_or_None, gr.update(visible=...))``. The error
            box is made visible only when some package failed.
        """
        packages = self._parse_packages(package_list)
        self.current_packages = packages
        errors = self.fetch_if_needed(packages)
        # Pass the list explicitly instead of letting plot() re-read the shared
        # self.current_packages, which another session may overwrite.
        return self.plot(use_log_scale, packages), errors, gr.update(visible=errors is not None)
# One tracker shared by every visitor: its download cache is shared on purpose,
# but per-session selections must come from each session's own inputs.
tracker = DownloadsTracker()

# Styles the error textbox (elem_id="textbox_id"): red text in its textarea,
# red background on its span elements.
css = """
#textbox_id textarea {color: red}
#textbox_id span {background-color: red}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    packages = gr.Textbox("transformers, accelerate", label="Package names (comma-separated)")
    log_scale = gr.Checkbox(label="Use logarithmic scale on Y-axis", value=False)
    error_box = gr.Textbox(label="Errors:", interactive=False, visible=False, elem_id="textbox_id")
    render_btn = gr.Button("Render")
    plot = gr.Plot()

    # tracker.render returns (figure, error text, visibility update), so
    # error_box is listed twice: once for its value, once for its visibility.
    render_inputs = [packages, log_scale]
    render_outputs = [plot, error_box, error_box]

    render_btn.click(tracker.render, inputs=render_inputs, outputs=render_outputs)
    # Pressing Enter in the textbox behaves like clicking "Render".
    packages.submit(tracker.render, inputs=render_inputs, outputs=render_outputs)
    # Re-render from this session's textbox rather than calling tracker.plot.
    # tracker.plot would read tracker.current_packages, which is overwritten by
    # whichever visitor rendered last. Already-cached packages are not re-fetched.
    log_scale.change(tracker.render, inputs=render_inputs, outputs=render_outputs)
    demo.load(tracker.render, inputs=render_inputs, outputs=render_outputs)

if __name__ == "__main__":
    demo.launch()