import math
import os
from datetime import datetime
from typing import Any, Dict, List

import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import HfApi
from huggingface_hub.utils import logging
from tqdm.auto import tqdm

load_dotenv()

# Required configuration is read from the environment (``load_dotenv`` above
# first loads a local ``.env`` file, if present). Explicit raises are used
# instead of ``assert`` so the checks still run under ``python -O``; an empty
# value is rejected just like a missing one.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")
# NOTE(review): USER_AGENT is not referenced in the code visible here;
# presumably it is used elsewhere in the file — confirm.
USER_AGENT = os.getenv("USER_AGENT")
if not USER_AGENT:
    raise RuntimeError("You need to set USER_AGENT in your environment variables")

logger = logging.get_logger(__name__)

# Hub client authenticated with HF_TOKEN, shared by the functions below.
api = HfApi(token=HF_TOKEN)
# Passed as ``limit=`` to ``api.list_datasets``; None requests no limit.
MAX_DATASETS = None


def has_card_data(dataset):
    """Return True when ``dataset`` carries non-None card metadata.

    Hub listing objects expose a ``card_data`` attribute that may be None, so
    a plain ``hasattr`` check would report True for datasets without a card.
    Objects lacking the attribute entirely also return False.
    """
    return getattr(dataset, "card_data", None) is not None


def check_dataset_has_dataset_info(dataset):
    """Tell whether the dataset's card metadata defines a ``dataset_info`` entry.

    Returns False when ``has_card_data`` rejects the dataset, when the card
    metadata has no ``dataset_info`` attribute, or when that attribute is None.
    """
    if not has_card_data(dataset):
        return False
    return getattr(dataset.card_data, "dataset_info", None) is not None


def parse_single_config_dataset(data):
    """Extract the config name, column names and feature specs of one config.

    Args:
        data: One ``dataset_info`` entry from a dataset card, read with
            ``.get``. Presumably a mapping with optional ``config_name`` and
            ``features`` keys, where ``features`` is a list of dicts — confirm
            against the card schema.

    Returns:
        A dict with ``config_name`` (``"default"`` when the key is absent),
        ``column_names`` (each feature's ``"name"``, None where missing) and
        ``features`` (the feature list, ``[]`` when absent or null).
    """
    config_name = data.get("config_name", "default")
    features = data.get("features")
    # ``features: null`` in the card YAML yields None; treat it like a
    # missing key instead of failing when iterating it.
    if features is None:
        features = []
    column_names = [feature.get("name") for feature in features]
    return {
        "config_name": config_name,
        "column_names": column_names,
        "features": features,
    }


def parse_multiple_config_dataset(data: List[Dict[str, Any]]):
    """Parse every config entry of a multi-config ``dataset_info`` list.

    Returns one ``parse_single_config_dataset`` result per entry, in input
    order.
    """
    return list(map(parse_single_config_dataset, data))


def parse_dataset(dataset):
    """Collect Hub-level metadata for one dataset listing entry.

    Args:
        dataset: A dataset object from the Hub listing. ``dataset.card_data``
            must be set, because ``license`` and ``language`` are read from it.

    Returns:
        A dict with ``hub_id``, ``likes``, ``downloads``, ``tags``,
        ``created_at``, ``last_modified``, ``license`` and ``language``.
    """
    # Values are evaluated in literal order, matching a top-to-bottom read
    # of the listing attributes followed by the card fields.
    return {
        "hub_id": dataset.id,
        "likes": dataset.likes,
        "downloads": dataset.downloads,
        "tags": dataset.tags,
        "created_at": dataset.created_at,
        "last_modified": dataset.last_modified,
        "license": dataset.card_data.license,
        "language": dataset.card_data.language,
    }


def parsed_column_info(dataset_info):
    """Normalise a card's ``dataset_info`` into a list of per-config dicts.

    A list is parsed entry by entry. A single dict is wrapped in a
    one-element list. Any other value yields None.
    """
    if isinstance(dataset_info, list):
        return parse_multiple_config_dataset(dataset_info)
    if isinstance(dataset_info, dict):
        return [parse_single_config_dataset(dataset_info)]
    return None


def ensure_list_of_strings(value):
    """Coerce a metadata value into a list of strings.

    Missing values (None, or a float NaN as pandas uses for empty cells)
    become ``[]``. Lists and tuples have each item converted with ``str``.
    Any other scalar is wrapped as a single-element list.
    """
    if value is None:
        return []
    # pandas marks missing object-column cells with float NaN; treat it as
    # "no value" rather than producing ["nan"].
    if isinstance(value, float) and math.isnan(value):
        return []
    # Tuples are sequences of values too; without this they would be
    # stringified whole, e.g. ["('mit',)"].
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return [str(value)]


def refresh_data() -> List[Dict[str, Any]]:
    """Build one record per (dataset, config) for Hub datasets with ``dataset_info``.

    If ``datasets_<YYYY-MM-DD>.parquet`` for today's local date exists in the
    current working directory, its rows are returned directly. Otherwise all
    datasets are listed from the Hub (``full=True``, ``limit=MAX_DATASETS``).
    Datasets whose card lacks ``dataset_info`` are skipped, and each remaining
    dataset is expanded into one row per config.

    Returns:
        A list of row dicts combining the keys produced by ``parse_dataset``
        and ``parse_single_config_dataset``. ``license``, ``tags`` and
        ``language`` are normalised to lists of strings, and ``column_names``
        to a list. An empty list is returned when no dataset could be parsed.
    """
    now = datetime.now()
    # Build the snapshot path once so the existence check and the read agree.
    snapshot_path = f"datasets_{now.strftime('%Y-%m-%d')}.parquet"
    if os.path.exists(snapshot_path):
        return pd.read_parquet(snapshot_path).to_dict(orient="records")

    # Filter while the listing streams in, so datasets without dataset_info
    # are never accumulated in memory.
    datasets = [
        dataset
        for dataset in tqdm(api.list_datasets(limit=MAX_DATASETS, full=True))
        if check_dataset_has_dataset_info(dataset)
    ]

    parsed_datasets = []
    for dataset in tqdm(datasets):
        try:
            datasetinfo = parse_dataset(dataset)
            column_info = parsed_column_info(dataset.card_data.dataset_info)
        except Exception as e:
            # Best-effort: one malformed card must not abort the whole refresh.
            logger.warning("Error processing dataset %s: %s", dataset.id, e)
            continue
        if column_info is None:
            # parsed_column_info returns None when dataset_info is neither a
            # dict nor a list.
            logger.warning(
                "Skipping dataset %s: unrecognised dataset_info format", dataset.id
            )
            continue
        parsed_datasets.extend({**datasetinfo, **info} for info in column_info)

    if not parsed_datasets:
        # pd.DataFrame([]) has no columns, so the fix-ups below would raise
        # KeyError.
        return []

    df = pd.DataFrame(parsed_datasets)

    # Card metadata may hold a scalar, a list, or None; store lists of strings.
    for column in ("license", "tags", "language"):
        df[column] = df[column].apply(ensure_list_of_strings)

    df["column_names"] = df["column_names"].apply(
        lambda x: x if isinstance(x, list) else []
    )

    df = df.astype({"hub_id": "string", "config_name": "string"})

    # NOTE: persisting the snapshot (``df.to_parquet(snapshot_path)`` plus a
    # JSON-lines export) is currently disabled. The cached-read branch above
    # therefore only fires when that parquet file was written by something else.

    return df.to_dict(orient="records")