import os
import random
from pathlib import Path
from statistics import mean
from typing import Any, Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI
from httpx import AsyncClient, Client, Timeout
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

app = FastAPI()
logger = logging.get_logger(__name__)
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")


BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
headers = {
    "authorization": f"Bearer ${HF_TOKEN}",
}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)
async_client = AsyncClient(headers=headers, timeout=timeout)
# Non-exhaustive list of columns that might contain text usable for language detection.
# Columns are listed in order of preference, i.e. if a column named "text" exists it is used first,
# so this must be an ordered collection rather than a set.
TARGET_COLUMN_NAMES = (
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
)


def datasets_server_valid_rows(hub_id: str) -> bool:
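    """Return True if the datasets server reports that the dataset viewer is available for `hub_id`."""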
    try:
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
        return resp.json()["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return False


async def get_first_config_and_split_name(hub_id: str) -> tuple[str, str] | None:
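    """Return the (config, split) names of the first split reported by the datasets server, or None on failure."""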
    try:
        resp = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}"
        )

        data = resp.json()
        return data["splits"][0]["config"], data["splits"][0]["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None


async def get_dataset_info(hub_id: str, config: str | None = None) -> dict[str, Any] | None:
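    """Fetch dataset info from the datasets server, resolving the first config when none is given."""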
    if config is None:
        config_and_split = await get_first_config_and_split_name(hub_id)
        if config_and_split is None:
            return None
        config = config_and_split[0]
    resp = await async_client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()


async def get_random_rows(
    hub_id: str,
    total_length: int,
    number_of_rows: int,
    max_request_calls: int,
    config="default",
    split="train",
):
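    """Sample up to `number_of_rows` rows by issuing at most `max_request_calls` randomly offset requests to the /rows endpoint."""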
    rows = []
    rows_per_call = min(
        number_of_rows // max_request_calls, total_length // max_request_calls
    )
    # Cap at 100 rows per call (datasets server limit) and ensure at least 1 to avoid division by zero below
    rows_per_call = max(1, min(rows_per_call, 100))
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, total_length - rows_per_call)
        url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
        logger.info(f"Fetching {url}")
        print(url)
        response = await async_client.get(url)
        if response.status_code == 200:
            data = response.json()
            batch_rows = data.get("rows")
            rows.extend(batch_rows)
        else:
            print(f"Failed to fetch data: {response.status_code}")
            print(url)
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]


def load_model(repo_id: str) -> fasttext.FastText._FastText:
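    """Download `model.bin` from the given Hub repo and load it as a fastText model."""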
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)


def yield_clean_rows(rows: list[Union[str, list[str]]], min_length: int = 3) -> Iterator[str]:
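    """Yield cleaned text lines: strings are split on newlines, token lists are joined; entries shorter than `min_length` are skipped."""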
    for row in rows:
        if isinstance(row, str):
            # split on newlines and drop empty or too-short lines
            for line in row.split("\n"):
                if line and len(line) >= min_length:
                    yield line
        elif isinstance(row, list):
            # token lists are joined into a single string
            try:
                line = " ".join(row)
                if len(line) >= min_length:
                    yield line
            except TypeError:
                continue


FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"

# model = load_model(DEFAULT_FAST_TEXT_MODEL)
Path("code/models").mkdir(parents=True, exist_ok=True)
model = fasttext.load_model(
    hf_hub_download(
        "facebook/fasttext-language-identification",
        "model.bin",
        cache_dir="code/models",
        local_dir="code/models",
        local_dir_use_symlinks=False,
    )
)


def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
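    """Predict the top-`k` languages for `inputs` with fastText, stripping the "__label__" prefix from labels."""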
    predictions = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(predictions[0], predictions[1])
    ]


def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])


def filter_by_frequency(counts_dict: dict[str, int], threshold_percent: float = 0.2) -> set[str]:
    """Return the set of keys whose count is at least `threshold_percent` of the total count."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}


def predict_rows(rows, target_column, language_threshold_percent=0.2):
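    """Run language prediction over `target_column` of `rows`, returning mean scores for languages above the frequency threshold plus the raw predictions."""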
    rows = (row.get(target_column) for row in rows)
    rows = (row for row in rows if row is not None)
    rows = list(yield_clean_rows(rows))
    predictions = [model_predict(row) for row in rows]
    predictions = [pred for pred in predictions if pred is not None]
    predictions = list(concat(predictions))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    return {
        "predictions": dict(valmap(get_mean_score, filtered_dict)),
        "pred": predictions,
    }


@app.get("/items/{hub_id}")
async def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
    number_of_rows: int = 1000,
) -> dict[Any, Any]:
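    """Detect the dominant language(s) of a Hub dataset by sampling rows and running fastText language identification on a likely text column."""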
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config_and_split = await get_first_config_and_split_name(hub_id)
        if config_and_split is None:
            raise gr.Error(f"Could not find a config and split for dataset {hub_id}.")
        config, split = config_and_split
    info = await get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset info for {hub_id} is not accessible via the datasets server.")
    if dataset_info := info.get("dataset_info"):
        total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
        features = dataset_info.get("features")
        column_names = set(features.keys())
        logger.info(f"Column names: {column_names}")
        if not column_names.intersection(TARGET_COLUMN_NAMES):
            raise gr.Error(
                f"Dataset {hub_id} columns {column_names} do not include any of the target columns {TARGET_COLUMN_NAMES}"
            )
        for column in TARGET_COLUMN_NAMES:
            if column in column_names:
                target_column = column
                logger.info(f"Using column {target_column} for language detection")
                break
        random_rows = await get_random_rows(
            hub_id,
            total_rows_for_split,
            number_of_rows,
            max_request_calls,
            config,
            split,
        )
        logger.info(f"Predicting language for {len(random_rows)} rows")
        predictions = predict_rows(random_rows, target_column)
        predictions["hub_id"] = hub_id
        predictions["config"] = config
        predictions["split"] = split
        return predictions


@app.get("/")
def read_root():
    return {"Hello": "World!"}


# app_title = "Dataset Language Detection"
# app_description = "Detect the language of a dataset on the Hub"
# inputs = [
#     gr.Text(label="dataset id"),
#     gr.Textbox(
#         None,
#         label="config",
#     ),
#     gr.Textbox(None, label="split"),
# ]
# interface = gr.Interface(
#     predict_language,
#     inputs=inputs,
#     outputs="json",
#     title=app_title,
#     article=app_description,
# )
# interface.queue()
# interface.launch()
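
# To serve the API locally, a typical invocation (assuming this file is saved as main.py) is:
#   uvicorn main:app --reload
# The module name, host, and port are assumptions; adjust them to match your deployment.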