# Captured from a Hugging Face Space page (banner: "Spaces: Sleeping").
import marimo

# Version of marimo that generated this notebook file.
__generated_with = "0.9.14"

# The notebook application object; `app.run()` at the bottom of the file runs it.
app = marimo.App(width="medium")
@app.cell
def __():
    # Import the notebook's dependencies, export them to the other cells, and
    # render the title as this cell's output.
    # NOTE(review): pandas, numpy and altair are not used by the cells visible
    # here; they are still exported so any other cell depending on them works.
    import marimo as mo
    import duckdb
    import pandas
    import numpy
    import altair as alt
    import plotly.express as px

    mo.md("# 🤗 Hub Model Tree")
    return alt, duckdb, mo, numpy, pandas, px
@app.cell
def __(mo):
    # Introductory text: where the data comes from and how to use the search box.
    mo.md("""This is powered by the [Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) dataset which you can query via the [SQL Console](https://huggingface.co/datasets/cfahlgren1/hub-stats?sql_console=true). The model tree metric is where a model tags a parent model as a `base_model`. The `hub-stats` dataset gets updated daily. Try it out by putting an organization or model author in search box and hit enter.""")
    return
@app.cell
def __(duckdb):
    # Register a DuckDB view named `models` over the remote Hub Stats parquet
    # file on duckdb's default connection. The SQL built elsewhere in this
    # notebook queries that view by name.
    # OR REPLACE keeps the cell re-runnable; a plain CREATE VIEW errors when the
    # view already exists from a previous run.
    duckdb.sql(
        "CREATE OR REPLACE VIEW models as SELECT * FROM "
        "'hf://datasets/cfahlgren1/hub-stats/models.parquet'"
    )
    # Export the view's name so the cell's declared output is actually defined
    # and other cells can take a dependency on it.
    models = "models"
    return (models,)
@app.cell
def __(mo):
    # Free-text box for the author / organization to look up.
    author_input = mo.ui.text(placeholder="Search...", label="Author")

    # Shared CTE prefix; `{}` is filled with the quote-escaped author name.
    #   author_models: ids of the models whose `author` equals the search term.
    #   model_tags:    one row per (model id, tag) pair, produced by UNNEST(tags).
    ctes = """
    WITH author_models AS (
        SELECT id
        FROM models
        WHERE author = '{}'
    ),
    model_tags AS (
        SELECT
            id,
            UNNEST(tags) AS tag
        FROM models
    )
    """

    def get_model_children_counts(author: str) -> str:
        """Build SQL counting the direct children of each model owned by *author*.

        A child is any model carrying the tag ``base_model:<parent id>``.
        Parents with no children are omitted (INNER JOIN). The result has the
        columns ``parent_model_id`` and ``num_direct_children``, sorted by the
        count in descending order.

        Args:
            author: Exact author / organization name. Single quotes are doubled
                so the value cannot terminate the SQL string literal.

        Returns:
            The SQL query text.
        """
        # Doubling ' is the SQL string-literal escape; without it a name such as
        # "o'neil" breaks the query and the text box allows SQL injection.
        safe_author = author.replace("'", "''")
        return f"""
        {ctes.format(safe_author)}
        SELECT
            am.id as parent_model_id,
            COUNT(DISTINCT m.id) as num_direct_children
        FROM author_models am
        INNER JOIN model_tags m
            ON m.tag = 'base_model:' || am.id
        GROUP BY am.id
        ORDER BY num_direct_children DESC;
        """

    def get_total_childen_count(author: str) -> str:
        """Build SQL counting distinct models that tag any *author* model as base.

        The query returns a single row with one column (aliased
        ``num_direct_children``). It holds the number of distinct models tagged
        ``base_model:<id>`` for at least one of the author's models, so a model
        that is a child of several of them is counted once. The misspelled
        function name is kept because other cells reference it.

        Args:
            author: Exact author / organization name. Single quotes are doubled
                so the value cannot terminate the SQL string literal.

        Returns:
            The SQL query text.
        """
        safe_author = author.replace("'", "''")
        # LEFT JOIN leaves m.id NULL for childless parents; COUNT ignores NULLs.
        return f"""
        {ctes.format(safe_author)}
        SELECT
            COUNT(DISTINCT m.id) as num_direct_children
        FROM author_models am
        LEFT JOIN model_tags m
            ON m.tag = 'base_model:' || am.id
        """

    return (
        author_input,
        ctes,
        get_model_children_counts,
        get_total_childen_count,
    )
@app.cell
def __(mo):
    # Section heading for the search UI.
    mo.md("## Search by Author")
    return
@app.cell
def __(author_input, mo):
    # Display the search box with a line of example author names below it.
    _examples = mo.md("_ex: meta-llama, google, mistralai, Qwen_")
    mo.vstack([author_input, _examples])
    return
@app.cell
def __(author_input, duckdb, get_total_childen_count, mo):
    # Headline stat: how many models tag one of the author's models as base_model.
    # Surrounding whitespace would otherwise make the exact-match lookup fail.
    _author = author_input.value.strip()
    # Until an author is entered there is nothing to look up. Skip the query,
    # which reads the remote parquet behind the `models` view, and show a prompt.
    mo.stop(not _author, mo.md("_Enter an author above to see results._"))
    result = duckdb.sql(get_total_childen_count(_author)).fetchall()
    # An aggregate query without GROUP BY always yields exactly one row.
    mo.vstack(
        [
            mo.md("### Direct Child Models"),
            mo.md(
                f"_The number of models that have tagged a {_author} model as a `base_model`_"
            ),
            mo.stat(result[0][0]),
        ]
    )
    return (result,)
@app.cell
def __(author_input, duckdb, get_model_children_counts, mo):
    # Per-model table of direct-child counts for the searched author.
    _author = author_input.value.strip()
    # Wait for a search term before querying the remote dataset. Stopping here
    # also keeps cells that depend on `df` from running.
    mo.stop(not _author)
    df = duckdb.sql(get_model_children_counts(_author)).fetchdf()
    df
    return (df,)
@app.cell
def __(df, mo, px):
    # Bar chart of direct-child counts per parent model on a logarithmic y-axis.
    # The counts come from an INNER JOIN, so every plotted value is >= 1 and
    # none is lost to the log scale.
    _plot = px.bar(df, x="parent_model_id", y="num_direct_children", log_y=True)
    mo.ui.plotly(_plot)
    return
if __name__ == "__main__":
    # Run the notebook app when this file is executed directly as a script.
    app.run()