# Spaces:
# Running
# Running
from urllib.parse import quote

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
from plotly.subplots import make_subplots
class DownloadsTracker:
    """Fetch daily PyPI download counts from pypistats.org and plot them.

    Per-package DataFrames are cached in ``all_packages`` for the lifetime of
    the instance. Nothing evicts or refreshes an entry, so a long-running
    process keeps showing the data from the first successful fetch. Packages
    whose fetch failed are not cached and are retried on the next call.
    """

    # pypistats.org endpoint; ``{package}`` is filled with the URL-quoted name.
    API_URL = "https://pypistats.org/api/packages/{package}/overall"
    # Seconds to wait on pypistats.org before giving up, so a stalled request
    # cannot block the calling worker indefinitely.
    REQUEST_TIMEOUT = 10
    # Number of consecutive rows (dates returned by the API) summed to build
    # the "weekly" series.
    ROLLING_WINDOW = 7

    def __init__(self):
        # package name -> DataFrame with columns date, downloads,
        # cumulative_downloads, weekly_downloads (sorted by date).
        self.all_packages = {}
        # Package names from the most recent render() call; plot() falls back
        # to this list when it is not given one explicitly.
        self.current_packages = []

    @classmethod
    def _to_frame(cls, records):
        """Build the download DataFrame from a pypistats ``data`` list.

        Only records whose ``category`` is ``"without_mirrors"`` are kept.

        Args:
            records: iterable of dicts with ``category``, ``date`` and
                ``downloads`` keys.

        Returns:
            A DataFrame sorted by date with ``date``, ``downloads``,
            ``cumulative_downloads`` and ``weekly_downloads`` columns, or
            None when no ``without_mirrors`` record is present.
        """
        rows = [
            {"date": rec["date"], "downloads": rec["downloads"]}
            for rec in records
            if rec["category"] == "without_mirrors"
        ]
        if not rows:
            # Without this guard, df["date"] below raises KeyError on an
            # empty, column-less frame.
            return None
        df = pd.DataFrame(rows)
        df["date"] = pd.to_datetime(df["date"])
        df = df.sort_values("date")
        df["cumulative_downloads"] = df["downloads"].cumsum()
        # Rolling *sum*, matching the "7 days rolling sum" subplot title
        # (the original used .mean()). The first ROLLING_WINDOW - 1 rows are
        # NaN because the window is not yet full.
        df["weekly_downloads"] = df["downloads"].rolling(window=cls.ROLLING_WINDOW).sum()
        return df

    def _fetch_one(self, pkg):
        """Download and cache the stats for one package.

        Returns:
            None on success (the frame is stored in ``all_packages``),
            otherwise a human-readable error message.
        """
        # The name comes straight from a user textbox: quote it so characters
        # like "/", "?" or "#" cannot alter the request path or query.
        url = self.API_URL.format(package=quote(pkg, safe=""))
        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        except requests.RequestException:
            return f"Error fetching {pkg}"
        if response.status_code == 404:
            return f"Package not found: {pkg}"
        if response.status_code != 200:
            # For example, rate limiting (429) or a server error. These are
            # not "not found", so report the status instead.
            return f"Error fetching {pkg} (HTTP {response.status_code})"
        try:
            df = self._to_frame(response.json()["data"])
        except (ValueError, KeyError, TypeError):
            # Malformed JSON, unexpected payload shape, or unparsable dates.
            return f"Error fetching {pkg}"
        if df is None:
            return f"No download data for {pkg}"
        self.all_packages[pkg] = df
        return None

    def fetch_if_needed(self, packages):
        """Fetch every package in *packages* that is not cached yet.

        Returns:
            A newline-joined string of error messages, or None when every
            package is now available.
        """
        errors = []
        for pkg in packages:
            if pkg in self.all_packages:
                continue
            error = self._fetch_one(pkg)
            if error is not None:
                errors.append(error)
        return "\n".join(errors) if errors else None

    def plot(self, use_log_scale, packages=None):
        """Build a two-row figure: cumulative and rolling-weekly downloads.

        Args:
            use_log_scale: if true, both y-axes use a log scale.
            packages: package names to draw. Defaults to
                ``self.current_packages``. Names without cached data are
                skipped but still advance the color cycle.

        Returns:
            A ``plotly.graph_objects.Figure``.
        """
        if packages is None:
            packages = self.current_packages
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Cumulative Downloads (restarted at 0)", "Weekly Downloads (7 days rolling sum)"),
        )
        colors = px.colors.qualitative.Pastel  # Built-in color sequence
        for i, pkg in enumerate(packages):
            df = self.all_packages.get(pkg)
            if df is None:
                continue
            color = colors[i % len(colors)]
            # A shared legendgroup lets a single legend click toggle both of
            # this package's traces; the weekly one has no entry of its own.
            fig.add_trace(
                go.Scatter(x=df["date"], y=df["cumulative_downloads"],
                           name=pkg, legendgroup=pkg, line=dict(color=color)),
                row=1, col=1,
            )
            fig.add_trace(
                go.Scatter(x=df["date"], y=df["weekly_downloads"],
                           name=pkg, legendgroup=pkg, line=dict(color=color),
                           showlegend=False),
                row=2, col=1,
            )
        # With no row/col given, this applies to both subplot y-axes.
        fig.update_yaxes(type="log" if use_log_scale else "linear")
        fig.update_layout(
            height=800,
            font=dict(size=16),
            title_font=dict(size=20),
            legend_font=dict(size=16),
        )
        return fig

    def render(self, package_list, use_log_scale):
        """Parse a comma-separated package list, fetch missing data, and plot.

        Returns:
            ``(figure, error_text_or_None, gr.update(visible=...))``. The
            third element shows the error box only when there were errors.
        """
        # Strip blanks and drop duplicates while preserving the typed order,
        # so a repeated name does not get a second trace.
        package_list = list(dict.fromkeys(
            p.strip() for p in (package_list or "").split(",") if p.strip()
        ))
        self.current_packages = package_list
        errors = self.fetch_if_needed(package_list)
        # Pass the list explicitly: self.current_packages is shared state and
        # may be overwritten by a concurrent render() call.
        return self.plot(use_log_scale, package_list), errors, gr.update(visible=errors is not None)
# One tracker per process: its download cache is shared by every visitor.
tracker = DownloadsTracker()

# Error-box styling: red text in its textarea and a red background on span
# elements inside it (targeted through elem_id="textbox_id").
css = """
#textbox_id textarea {color: red}
#textbox_id span {background-color: red}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    packages = gr.Textbox("transformers, accelerate", label="Package names (comma-separated)")
    log_scale = gr.Checkbox(label="Use logarithmic scale on Y-axis", value=False)
    # Starts hidden; render() toggles its visibility depending on errors.
    error_box = gr.Textbox(label="Errors:", interactive=False, visible=False, elem_id="textbox_id")
    render_btn = gr.Button("Render")
    plot = gr.Plot()

    # render() returns (figure, error text, visibility update), so error_box
    # appears twice: once for its value, once for the visibility update.
    render_outputs = [plot, error_box, error_box]
    render_btn.click(tracker.render, inputs=[packages, log_scale], outputs=render_outputs)
    # Re-render from this session's own textbox rather than calling
    # tracker.plot(). That method reads tracker.current_packages, which is
    # shared by all visitors, so another user's last Render would leak into
    # this plot. Side effect: packages typed but not yet rendered are
    # fetched when the checkbox is toggled.
    log_scale.change(tracker.render, inputs=[packages, log_scale], outputs=render_outputs)
    demo.load(tracker.render, inputs=[packages, log_scale], outputs=render_outputs)

demo.launch()