from typing import Any, Dict, Iterator, List

import requests
from huggingface_hub import add_collection_item, create_collection
from tqdm.auto import tqdm


class DatasetSearchClient:
    """Thin client for the dataset column-search HTTP API.

    The client issues GET requests against ``{base_url}/search`` and walks the
    paginated response until every reported result has been yielded.

    Args:
        base_url (str, optional): Root URL of the search API. A trailing slash
            is tolerated.
        timeout (float, optional): Timeout in seconds passed to every HTTP
            request. Pass ``None`` to wait indefinitely. Defaults to 30.0.
    """

    def __init__(
        self,
        base_url: str = "https://librarian-bots-dataset-column-search-api.hf.space",
        timeout: float = 30.0,
    ):
        self.base_url = base_url
        self.timeout = timeout

    def search(
        self, columns: List[str], match_all: bool = False, page_size: int = 100
    ) -> Iterator[Dict[str, Any]]:
        """
        Search datasets using the provided API, automatically handling pagination.

        This is a generator: no request is sent (and no argument is validated)
        until the first item is requested.

        Args:
            columns (List[str]): List of column names to search for.
            match_all (bool, optional): If True, match all columns. If False, match any column. Defaults to False.
            page_size (int, optional): Number of results per page. Must be at least 1. Defaults to 100.

        Yields:
            Dict[str, Any]: Each dataset result from all pages.

        Raises:
            requests.RequestException: If there's an error with the HTTP request.
            ValueError: If ``page_size`` is smaller than 1, or if the API
                returns an unexpected response format.
        """
        # A non-positive page size would make the pagination condition below
        # meaningless (and never terminate for page_size == 0).
        if page_size < 1:
            raise ValueError("page_size must be a positive integer")

        # Avoid producing "//search" when base_url ends with a slash.
        url = f"{self.base_url.rstrip('/')}/search"
        required_keys = {"total", "page", "page_size", "results"}

        page = 1
        total_results = None

        # Keep requesting pages until the pages fetched so far cover the total
        # reported by the first response.
        while total_results is None or (page - 1) * page_size < total_results:
            params = {
                "columns": columns,
                # The flag is sent as the lowercase string "true"/"false".
                "match_all": str(match_all).lower(),
                "page": page,
                "page_size": page_size,
            }

            # Only the request and the response validation are guarded; the
            # yield happens outside the try so the handlers stay narrow.
            try:
                response = requests.get(url, params=params, timeout=self.timeout)
                response.raise_for_status()
                data = response.json()

                if not isinstance(data, dict) or not required_keys.issubset(data):
                    raise ValueError("Unexpected response format from the API")
            except requests.RequestException as e:
                raise requests.RequestException(
                    f"Error connecting to the API: {e}"
                ) from e
            except ValueError as e:
                raise ValueError(f"Error processing API response: {e}") from e

            if total_results is None:
                total_results = data["total"]

            yield from data["results"]
            page += 1


# Module-level client used by `update_collection_for_dataset` below. It is
# built with the default base URL; constructing it sends no request.
client = DatasetSearchClient()


def update_collection_for_dataset(
    collection_name: str = None,
    dataset_columns: List[str] = None,
    collection_description: str = None,
    collection_namespace: str = None,
) -> str:
    """Create (or reuse) a collection and add every matching dataset to it.

    The module-level ``client`` is queried for datasets that contain all of
    ``dataset_columns``; each result's ``hub_id`` is then added to the
    collection as a dataset item. Failures to add an individual dataset are
    printed and do not abort the run.

    Args:
        collection_name (str): Title of the collection. Required.
        dataset_columns (List[str]): Column names a dataset must contain.
            Required and must be non-empty.
        collection_description (str, optional): Description passed to
            ``create_collection``.
        collection_namespace (str, optional): Namespace passed to
            ``create_collection``; ``None`` leaves the choice to that function.

    Returns:
        str: URL of the collection on huggingface.co.

    Raises:
        ValueError: If ``collection_name`` or ``dataset_columns`` is missing
            or empty.
    """
    # Fail fast on missing input rather than sending a title of None to the
    # Hub or running a search with no columns.
    if not collection_name:
        raise ValueError("collection_name is required")
    if not dataset_columns:
        raise ValueError("dataset_columns must be a non-empty list of column names")

    # exists_ok=True makes this call safe to repeat for an existing collection.
    collection = create_collection(
        collection_name,
        exists_ok=True,
        description=collection_description,
        namespace=collection_namespace,
    )

    # Drain the search first so the second progress bar knows its length.
    results = list(
        tqdm(
            client.search(dataset_columns, match_all=True),
            desc="Searching datasets...",
            leave=False,
        )
    )
    for result in tqdm(results, desc="Adding datasets to collection...", leave=False):
        # Read the id once, outside the try, so the error handler below can
        # never raise a second KeyError while formatting its message.
        hub_id = result.get("hub_id")
        if not hub_id:
            print(f"Skipping search result without a 'hub_id': {result!r}")
            continue
        try:
            add_collection_item(
                collection.slug, hub_id, item_type="dataset", exists_ok=True
            )
        except Exception as e:
            # Best effort: report the failure and continue with the next dataset.
            print(
                f"Error adding dataset {hub_id} to collection {collection_name}: {str(e)}"
            )
    return f"https://huggingface.co/collections/{collection.slug}"


# Collection definitions. Each dict's keys match keyword parameters of
# `update_collection_for_dataset`, so an entry can be unpacked with `**`.
# NOTE(review): this name would shadow the standard-library `collections`
# module if that module were ever imported here.
collections = [
    {
        "dataset_columns": ["chosen", "rejected", "prompt"],
        "collection_description": "Datasets suitable for DPO based on having 'chosen', 'rejected', and 'prompt' columns. Created using librarian-bots/dataset-column-search-api",
        "collection_name": "Direct Preference Optimization Datasets",
    },
    {
        "dataset_columns": ["image", "chosen", "rejected"],
        "collection_description": "Datasets suitable for Image Preference Optimization based on having  'image','chosen', and 'rejected' columns",
        "collection_name": "Image Preference Optimization Datasets",
    },
    {
        "collection_name": "Alpaca Style Datasets",
        "dataset_columns": ["instruction", "input", "output"],
        "collection_description": "Datasets which follow the Alpaca Style format based on having 'instruction', 'input', and 'output' columns",
    },
]

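# Example driver (left disabled): uncommenting the lines below calls
# `update_collection_for_dataset` once per entry in `collections`, using the
# "librarian-bots" namespace, and prints the returned collection URLs.
# results = [
#     update_collection_for_dataset(**collection, collection_namespace="librarian-bots")
#     for collection in collections
# ]
# print(results)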