File size: 4,036 Bytes
f50f9f2
cb966df
 
 
ceded7f
 
f50f9f2
cb966df
 
 
b3dc0a7
cb966df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceded7f
 
cb966df
 
 
 
 
 
f50f9f2
b3dc0a7
53d2375
b749eb1
b5dc7a6
b749eb1
cb966df
c87fa05
b749eb1
ceded7f
 
b749eb1
 
ceded7f
 
 
 
b749eb1
 
ceded7f
 
7f620e6
b3dc0a7
ceded7f
 
b3dc0a7
ceded7f
 
 
43675ef
 
 
 
 
 
b3dc0a7
cb966df
ceded7f
 
 
 
 
 
cb966df
 
6124650
 
 
 
b3dc0a7
ceded7f
a4e7be3
ceded7f
6124650
cb966df
 
b3dc0a7
 
cb966df
ceded7f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import requests
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

class DownloadsTracker:
    """Fetch daily PyPI download counts from pypistats.org and plot them.

    Fetched data is cached per package name in ``all_packages`` so that
    re-rendering (or toggling the log scale) does not call the API again.
    """

    # Seconds to wait for pypistats.org before giving up on a package; without
    # a timeout a stalled connection would block the UI handler indefinitely.
    request_timeout = 30

    def __init__(self):
        # package name -> DataFrame with columns date, downloads,
        # cumulative_downloads, weekly_downloads (built by _frame_from_rows).
        self.all_packages = {}
        # Packages selected by the most recent render(); plot() draws these.
        self.current_packages = []

    @staticmethod
    def _frame_from_rows(rows):
        """Build the per-day frame for one package from pypistats ``data`` rows.

        Only rows whose ``category`` is ``"without_mirrors"`` are kept.

        Args:
            rows: iterable of dicts with ``category``, ``date`` and
                ``downloads`` keys.

        Returns:
            A DataFrame sorted by date with ``cumulative_downloads`` (running
            total from the first returned day) and ``weekly_downloads`` (sum of
            the current and previous 6 rows; NaN for the first 6 rows), or
            ``None`` when no ``without_mirrors`` rows are present.
        """
        records = [
            {"date": row["date"], "downloads": row["downloads"]}
            for row in rows
            if row["category"] == "without_mirrors"
        ]
        if not records:
            return None
        df = pd.DataFrame(records)
        df["date"] = pd.to_datetime(df["date"])
        df = df.sort_values("date", ignore_index=True)
        df["cumulative_downloads"] = df["downloads"].cumsum()
        # Rolling *sum* (not mean) so the series is a true weekly total, as the
        # subplot title says. A 7-row window equals 7 days only if there is one
        # row per day — NOTE(review): assumes pypistats returns no date gaps.
        df["weekly_downloads"] = df["downloads"].rolling(window=7).sum()
        return df

    def fetch_if_needed(self, packages):
        """Download data for every package in *packages* that is not cached yet.

        Args:
            packages: iterable of PyPI package names.

        Returns:
            A newline-separated string with one message per package that could
            not be fetched, or ``None`` if all succeeded or were already cached.
            Failed packages are not cached, so a later call retries them.
        """
        from urllib.parse import quote

        errors = []
        for pkg in packages:
            if pkg in self.all_packages:
                continue
            # Percent-encode user input so characters such as "/" or "?" cannot
            # change which API path is requested.
            url = f"https://pypistats.org/api/packages/{quote(pkg, safe='')}/overall"
            try:
                response = requests.get(url, timeout=self.request_timeout)
            except requests.RequestException as exc:
                errors.append(f"Error fetching {pkg}: {exc}")
                continue
            if response.status_code == 404:
                errors.append(f"Package not found: {pkg}")
                continue
            if response.status_code != 200:
                # Any other status (e.g. rate limiting, server error) does not
                # mean the package is missing, so report the code instead.
                errors.append(f"Error fetching {pkg} (HTTP {response.status_code})")
                continue
            try:
                df = self._frame_from_rows(response.json()["data"])
            except (ValueError, KeyError, TypeError) as exc:
                # Malformed JSON, unparsable dates, or an unexpected payload shape.
                errors.append(f"Unexpected response for {pkg}: {exc}")
                continue
            if df is None:
                errors.append(f"No download data for {pkg}")
                continue
            self.all_packages[pkg] = df
        return "\n".join(errors) if errors else None

    def plot(self, use_log_scale):
        """Return a two-row Plotly figure for the packages in ``current_packages``.

        Row 1 shows cumulative downloads and row 2 the 7-day rolling sum.
        Packages without cached data (failed fetches) are skipped.

        Args:
            use_log_scale: if truthy, both y-axes use a log scale; else linear.
        """
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=(
                "Cumulative Downloads (restarted at 0)",
                "Weekly Downloads (7 days rolling sum)",
            ),
        )

        colors = px.colors.qualitative.Pastel  # Built-in color sequence
        for i, pkg in enumerate(self.current_packages):
            df = self.all_packages.get(pkg)
            if df is None:
                continue
            # Color by position so a package has the same color in both rows.
            color = colors[i % len(colors)]

            fig.add_trace(
                go.Scatter(x=df["date"], y=df["cumulative_downloads"],
                           name=pkg, line=dict(color=color)),
                row=1, col=1,
            )
            # showlegend=False: one legend entry per package (from row 1).
            fig.add_trace(
                go.Scatter(x=df["date"], y=df["weekly_downloads"],
                           name=pkg, line=dict(color=color), showlegend=False),
                row=2, col=1,
            )

        axis_type = "log" if use_log_scale else "linear"
        for row in (1, 2):
            fig.update_yaxes(type=axis_type, row=row, col=1)

        fig.update_layout(
            height=800,
            font=dict(size=16),
            title_font=dict(size=20),
            legend_font=dict(size=16),
        )
        return fig

    def render(self, package_list, use_log_scale):
        """Parse the package list, fetch missing data, and build UI outputs.

        Args:
            package_list: comma-separated package names; blank entries are
                ignored.
            use_log_scale: forwarded to plot().

        Returns:
            A tuple ``(figure, error_text_or_None, gr.update(...))``. The
            update shows the error box only when there are errors.
        """
        packages = [p.strip() for p in package_list.split(",") if p.strip()]
        self.current_packages = packages
        errors = self.fetch_if_needed(packages)
        return self.plot(use_log_scale), errors, gr.update(visible=errors is not None)

# One module-level tracker handles every UI event, so its package cache and
# current selection live at module level.
# NOTE(review): this looks shared across all browser sessions — confirm that is
# acceptable for multi-user deployments.
tracker = DownloadsTracker()

# Red styling for the error box, targeted via its elem_id ("textbox_id") below.
css = """
#textbox_id textarea {color: red}
#textbox_id span {background-color: red}
"""

# NOTE: component creation order inside this block presumably determines the
# page layout, so keep these statements in order.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    packages = gr.Textbox("transformers, accelerate", label="Package names (comma-separated)")
    log_scale = gr.Checkbox(label="Use logarithmic scale on Y-axis", value=False)
    # Hidden until tracker.render reports errors (it returns a visibility update).
    error_box = gr.Textbox(label="Errors:", interactive=False, visible=False, elem_id="textbox_id")
    render_btn = gr.Button("Render")
    plot = gr.Plot()
    # render returns (figure, error text, visibility update): error_box is listed
    # twice so the second output sets its text and the third its visibility.
    render_btn.click(tracker.render, inputs=[packages, log_scale], outputs=[plot, error_box, error_box])
    # Toggling the scale only re-plots the already-selected packages; no refetch.
    log_scale.change(tracker.plot, inputs=[log_scale], outputs=[plot])

    # Render the default package list as soon as the page loads.
    demo.load(tracker.render, inputs=[packages, log_scale], outputs=[plot, error_box, error_box])

# Starts the web server when this module runs.
demo.launch()