import gradio as gr
import pandas as pd
from cachetools import TTLCache, cached
from huggingface_hub import list_models
from toolz import groupby
from tqdm.auto import tqdm


@cached(TTLCache(maxsize=10, ttl=60 * 60 * 3))
def get_all_models():
    models = list(
        tqdm(
            iter(list_models(cardData=True, limit=None, sort="downloads", direction=-1))
        )
    )
    models = [model for model in models if model is not None]
    return [
        model for model in models if model.downloads > 1
    ]  # filter out models with 0 downloads


def has_base_model_info(model):
    try:
        if card_data := model.cardData:
            if base_model := card_data.get("base_model"):
                if isinstance(base_model, str):
                    return True
    except AttributeError:
        return False
    return False


grouped_by_has_base_model_info = groupby(has_base_model_info, get_all_models())


def produce_summary():
    return f"""{len(grouped_by_has_base_model_info.get(True)):,} models have base model info. 
            {len(grouped_by_has_base_model_info.get(False)):,} models don't have base model info.
            Currently {round(len(grouped_by_has_base_model_info.get(True))/len(get_all_models())*100,2)}% of models have base model info."""


models_with_base_model_info = grouped_by_has_base_model_info.get(True)
base_models = [
    model.cardData.get("base_model") for model in models_with_base_model_info
]
df = pd.DataFrame(
    pd.DataFrame({"base_model": base_models}).value_counts()
).reset_index()
df_with_org = df.copy(deep=True)
pipeline_tags = [x.pipeline_tag for x in models_with_base_model_info]
# sort pipeline tags alphabetically
unique_pipeline_tags = list(
    {x.pipeline_tag for x in models_with_base_model_info if x.pipeline_tag is not None}
)


def parse_org(hub_id):
    parts = hub_id.split("/")
    if len(parts) == 2:
        return parts[0] if parts[0] != "." else None
    else:
        return "huggingface"


def render_model_hub_link(hub_id):
    link = f"https://huggingface.co/{hub_id}"
    return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{hub_id}</a>'


df_with_org["org"] = df_with_org["base_model"].apply(parse_org)
df_with_org = df_with_org.dropna(subset=["org"])

grouped_by_base_model = groupby(
    lambda x: x.cardData.get("base_model"), models_with_base_model_info
)
print(df.columns)
all_base_models = df["base_model"].to_list()


def get_grandchildren(base_model):
    grandchildren = []
    for model in tqdm(grouped_by_base_model[base_model]):
        model_id = model.modelId
        grandchildren.extend(grouped_by_base_model.get(model_id, []))
    return grandchildren


def return_models_for_base_model(base_model):
    models = grouped_by_base_model.get(base_model)
    # sort models by downloads
    models = sorted(models, key=lambda x: x.downloads, reverse=True)
    results = ""
    results += (
        "## Models fine-tuned from"
        f" [`{base_model}`](https://huggingface.co/{base_model}) \n\n"
    )
    results += f"`{base_model}` has {len(models)} children\n\n"
    total_download_number = sum(model.downloads for model in models)
    results += (
        f"`{base_model}`'s children have been"
        f" downloaded {total_download_number:,} times\n\n"
    )
    grandchildren = get_grandchildren(base_model)
    number_of_grandchildren = len(grandchildren)
    results += f"`{base_model}` has {number_of_grandchildren} grandchildren\n\n"
    grandchildren_download_count = sum(model.downloads for model in grandchildren)
    results += (
        f"`{base_model}`'s grandchildren have been"
        f" downloaded {grandchildren_download_count:,} times\n\n"
    )
    results += f"Including grandchildren, `{base_model}` has {number_of_grandchildren + len(models):,} descendants\n\n"
    results += f"Including grandchildren, `{base_model}`'s descendants have been downloaded {grandchildren_download_count + total_download_number:,} times\n\n"
    results += "### Children models \n\n"
    for model in models:
        url = f"https://huggingface.co/{model.modelId}"
        results += (
            f"- [{model.modelId}]({url}) | number of downloads {model.downloads:,}"
            + "\n\n"
        )
    return results


def return_base_model_popularity(pipeline=None):
    df_with_pipeline_info = (
        pd.DataFrame({"base_model": base_models, "pipeline": pipeline_tags})
        .value_counts()
        .reset_index()
    )

    if pipeline is not None:
        df_with_pipeline_info = df_with_pipeline_info[
            df_with_pipeline_info["pipeline"] == pipeline
        ]
    keep_columns = ["base_model", "count"]
    df_with_pipeline_info["base_model"] = df_with_pipeline_info["base_model"].apply(
        render_model_hub_link
    )
    return df_with_pipeline_info[keep_columns].head(50)


def return_base_model_popularity_by_org(pipeline=None):
    referenced_base_models = [
        f"[`{model}`](https://huggingface.co/{model})" for model in base_models
    ]
    df_with_pipeline_info = pd.DataFrame(
        {"base_model": base_models, "pipeline": pipeline_tags}
    )
    df_with_pipeline_info["org"] = df_with_pipeline_info["base_model"].apply(parse_org)
    df_with_pipeline_info["org"] = df_with_pipeline_info["org"].apply(
        render_model_hub_link
    )
    df_with_pipeline_info = df_with_pipeline_info.dropna(subset=["org"])
    df_with_org = df_with_pipeline_info.copy(deep=True)
    if pipeline is not None:
        df_with_org = df_with_pipeline_info[df_with_org["pipeline"] == pipeline]
    df_with_org = df_with_org.drop(columns=["pipeline"])
    df_with_org = pd.DataFrame(df_with_org.value_counts())
    return pd.DataFrame(
        df_with_org.groupby("org")["count"]
        .sum()
        .sort_values(ascending=False)
        .reset_index()
        .head(50)
    )


with gr.Blocks() as demo:
    gr.Markdown(
        "# Base model explorer: explore the lineage of models on the  &#129303; Hub"
    )
    gr.Markdown(
        """When sharing models to the Hub, it is possible to [specify a base model in the model card](https://huggingface.co/docs/hub/model-cards#specifying-a-base-model), i.e. that your model is a fine-tuned version of [bert-base-cased](https://huggingface.co/bert-base-cased). 
        This Space allows you to find children's models for a given base model and view the popularity of models for fine-tuning.
        You can also optionally filter by the task to see rankings for a particular machine learning task.
        Don't forget to  &#10084;  if you like this space &#129303;"""
    )

    gr.Markdown(produce_summary())
    gr.Markdown("## Find all models trained from a base model")
    base_model = gr.Dropdown(
        all_base_models[:100], label="Base Model", allow_custom_value=True
    )
    results = gr.Markdown()
    base_model.change(return_models_for_base_model, base_model, results)
    gr.Markdown("## Base model rankings ")
    dropdown = gr.Dropdown(
        choices=unique_pipeline_tags,
        value=None,
        label="Filter rankings by task pipeline",
    )
    with gr.Accordion("Base model popularity ranking", open=False):
        df_popularity = gr.DataFrame(
            return_base_model_popularity(None), datatype="markdown"
        )
        dropdown.change(return_base_model_popularity, dropdown, df_popularity)
    with gr.Accordion("Base model popularity ranking by organization", open=False):
        df_popularity_org = gr.DataFrame(
            return_base_model_popularity_by_org(None), datatype="markdown"
        )
        dropdown.change(
            return_base_model_popularity_by_org, dropdown, df_popularity_org
        )


demo.launch()