"""
## more statistics
code:
math:
digit:
whitespace:
top_oov: most frequent oov chars
ranking: thumb_up thumb_down
"""

import json
import os
import sys
from difflib import SequenceMatcher
import pandas as pd
from datasets import load_dataset
from utils.log_util import logger
from vocab import tokenizer_factory, TokenizerConfig
from typing import List, Optional, Union, Literal

# Absolute path of the directory that contains this module.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

# Units reported by `to_dataframe` when the caller does not request any explicitly.
common_units = [
    "g_bytes/b_tokens",
    "b_tokens/g_bytes",
    "t_bytes/t_tokens",
    "t_tokens/t_bytes",
    "n_chars/n_tokens",
]

# Default corpora ("cc100/<lang>") evaluated by `main`, kept in sorted order.
common_corpuses = sorted(
    f"cc100/{lang}" for lang in ("en", "zh-Hans", "es", "fr", "de", "ko", "fa", "ar", "ja")
)

# CC-100 language codes; presumably the valid `<lang>` parts of "cc100/<lang>" corpus
# names — TODO confirm against the configs of the eson/cc100-samples dataset.
VALID_CODES_CC100 = [
    "am", "ar", "as", "az", "be", "bg", "bn", "bn_rom", "br", "bs", "ca", "cs", "cy", "da", "de",
    "el", "en", "eo", "es", "et", "eu", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gn", "gu",
    "ha", "he", "hi", "hi_rom", "hr", "ht", "hu", "hy", "id", "ig", "is", "it", "ja", "jv", "ka",
    "kk", "km", "kn", "ko", "ku", "ky", "la", "lg", "li", "ln", "lo", "lt", "lv", "mg", "mk", "ml",
    "mn", "mr", "ms", "my", "my_zaw", "ne", "nl", "no", "ns", "om", "or", "pa", "pl", "ps", "pt",
    "qu", "rm", "ro", "ru", "sa", "si", "sc", "sd", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv",
    "sw", "ta", "ta_rom", "te", "te_rom", "th", "tl", "tn", "tr", "ug", "uk", "ur", "ur_rom", "uz",
    "vi", "wo", "xh", "yi", "yo", "zh-Hans", "zh-Hant", "zu",
]


# TODO: candidate corpora for further statistics
# code: https://huggingface.co/datasets/codeparrot/github-code-clean (python, java, c, sql, html)
# math: (no corpus chosen yet)

def get_n_bytes_of_string(string_text):
    """Return how many bytes ``string_text`` occupies once encoded as UTF-8."""
    return len(bytes(string_text, encoding="utf-8"))


def unit_convertor(stat, unit):
    """Convert the raw counts of a corpus statistic into a compression metric.

    :param stat: dict holding the raw counts ``_n_tokens``, ``_n_chars`` and
        ``_n_bytes``. ``_n_tokens`` may be ``None`` when the token count could not
        be trusted (too many UNKs).
    :param unit: metric name. Supported:
        ``"n_tokens/n_bytes"``, ``"n_bytes/n_tokens"``,
        ``"char/token"`` (aliases ``"chars_per_token"``, ``"n_chars/n_tokens"``),
        ``"token/char"`` (aliases ``"tokens_per_char"``, ``"n_tokens/n_chars"``),
        ``"g_bytes/b_tokens"``, ``"b_tokens/g_bytes"``,
        ``"t_bytes/t_tokens"``, ``"t_tokens/t_bytes"``.
    :return: the metric rounded to 3 decimals, or ``None`` when the token count is
        unknown or a denominator is zero.
    :raises ValueError: if ``unit`` is not supported.
    """
    n_tokens = stat["_n_tokens"]
    n_chars = stat["_n_chars"]
    n_bytes = stat["_n_bytes"]

    if n_tokens is None:
        return None

    # Token counts use decimal prefixes (10^9, 10^12); byte sizes use binary ones (2^30, 2^40).
    n_tokens_in_billion = n_tokens / (1000 * 1000 * 1000)
    n_tokens_in_trillion = n_tokens / (1000 * 1000 * 1000 * 1000)
    n_bytes_in_mb = n_bytes / (1024 * 1024)
    n_bytes_in_gb = n_bytes_in_mb / 1024
    n_bytes_in_tb = n_bytes_in_gb / 1024

    try:
        if unit == "n_tokens/n_bytes":
            value = n_tokens / n_bytes
        elif unit == "n_bytes/n_tokens":
            value = n_bytes / n_tokens
        elif unit in ("char/token", "chars_per_token", "n_chars/n_tokens"):
            # Important: average number of characters per token.
            value = n_chars / n_tokens
        elif unit in ("token/char", "tokens_per_char", "n_tokens/n_chars"):
            # How many tokens a single character (e.g. one Chinese character) costs.
            value = n_tokens / n_chars
        elif unit == "g_bytes/b_tokens":
            value = n_bytes_in_gb / n_tokens_in_billion
        elif unit == "b_tokens/g_bytes":
            value = n_tokens_in_billion / n_bytes_in_gb
        elif unit == "t_bytes/t_tokens":  # Important
            value = n_bytes_in_tb / n_tokens_in_trillion
        elif unit == "t_tokens/t_bytes":
            value = n_tokens_in_trillion / n_bytes_in_tb
        else:
            # Raising a plain string is itself a TypeError in Python 3; use a real exception.
            raise ValueError(f"measure not support: {unit}")
    except ZeroDivisionError:
        # Empty corpus or zero tokens: the ratio is undefined.
        return None
    return round(value, 3)


def _merge_stats_by_corpus(stats_by_corpus, oov_threshold=0.3):
    """
    """
    all_stats = list(stats_by_corpus.values())
    assert len(set([stats["tokenizer"] for stats in all_stats])) == 1
    lossless = all(stat['lossless'] for stat in all_stats)
    is_support = all(stat['oov_ratio'] < oov_threshold for stat in all_stats)

    merged_stats = {
        "tokenizer": all_stats[0]["tokenizer"],
        "organization": all_stats[0]["organization"],
        "vocab_size": all_stats[0]["vocab_size"],
        "_n_bytes": 0,
        "_n_tokens": 0 if is_support else None,
        "_n_chars": 0,
        "_n_oov_chars": 0,
        "lossless": True,
    }
    for stats in all_stats:
        merged_stats["_n_bytes"] += stats["_n_bytes"]
        merged_stats["_n_chars"] += stats["_n_chars"]
        if is_support:  # The number of tokens cannot be accurately counted, when there are too many UNKs.
            merged_stats["_n_tokens"] += stats["_n_tokens"]
        merged_stats["_n_oov_chars"] += stats["_n_oov_chars"]
        merged_stats["lossless"] &= stats['lossless']

    merged_stats.update({
        "oov_ratio": float("%.4g" % (stats["_n_oov_chars"] / stats["_n_chars"])),
        "lossless": lossless
    })
    return merged_stats


def to_dataframe(stats, units=None):
    """Build a table with one row per stat record.

    :param stats: mapping ``name -> stat dict``. Keys starting with ``"_"`` (raw
        counts) are hidden from the table.
    :param units: a single unit name, any iterable of unit names (list, tuple, ...),
        or ``None`` for ``common_units``. Each unit becomes a column computed by
        ``unit_convertor``.
    :return: ``pandas.DataFrame`` with one row per entry of ``stats``.
    """
    if units is None:
        units = common_units
    elif isinstance(units, str):
        units = [units]
    else:
        # Accept any iterable; previously a tuple was wrapped as one bogus "unit".
        units = list(units)

    table = []
    for stat in stats.values():
        columns = {k: v for k, v in stat.items() if not k.startswith("_")}
        for unit in units:
            if unit in stat:
                # NOTE(review): reached only when ``stat`` already has a key named ``unit``;
                # that field is kept as-is (unless "_"-prefixed), so "not support" is misleading.
                logger.error(f"unit {unit} not support")
                continue
            columns[unit] = unit_convertor(stat, unit)
        table.append(columns)
    return pd.DataFrame(table)


# Module-level memo of `tokenize_corpus` results keyed by "<tokenizer> @ <corpus>.<corpus>...";
# lazily filled from "<cache_dir>/compression_rate.json" and written back after each new entry.
cache: dict = dict()


def tokenize_corpus(
        tokenizer_name: str,
        corpuses: List[str],
        cache_dir: str = "stats"
) -> dict:
    """Tokenize CC-100 sample corpora and collect compression / OOV statistics.

    Results are memoised in the module-level ``cache`` and persisted to
    ``<cache_dir>/compression_rate.json``. When the tokenizer does not round-trip
    some texts, the per-text differences are written to
    ``<cache_dir>/compression_rate/<cache_id>.diff.json``.

    :param tokenizer_name: tokenizer name understood by ``tokenizer_factory``.
    :param corpuses: corpus names of the form ``"cc100/<lang>"``; ``<lang>`` is used
        as the config name of the ``eson/cc100-samples`` dataset.
    :param cache_dir: directory for the cache file and the diff details.
    :return: dict with ``tokenizer``, ``organization``, ``vocab_size``, the raw
        counts ``_n_bytes``/``_n_tokens``/``_n_chars``/``_n_oov_chars``,
        ``oov_ratio``, ``_oov_charset`` (JSON list) and ``lossless``.
    """

    def _char_based_oov(src_text, decoded_text, tokenizer):
        """Find the chars of ``src_text`` that the tokenizer cannot reproduce.

        A char is OOV when it is missing from ``decoded_text`` and encoding then
        decoding it on its own does not give it back.

        :return: (number of OOV char occurrences in ``src_text``,
                  list of OOV chars in order of first appearance)
        """
        decoded_charset = set(decoded_text)
        oov_charset = [
            char for char in dict.fromkeys(src_text)  # unique chars, in order of appearance
            if char not in decoded_charset
            and char != tokenizer.decode(tokenizer.encode(char, add_special_tokens=False))
        ]
        oov_lookup = set(oov_charset)  # O(1) membership instead of scanning the list per char
        n_oov_chars = sum(1 for char in src_text if char in oov_lookup)
        return n_oov_chars, oov_charset

    def _diff_path(src_text, decoded_text):
        """Describe every non-equal opcode between ``src_text`` and ``decoded_text``."""
        s = SequenceMatcher(a=src_text, b=decoded_text)
        changes = []
        for tag, i1, i2, j1, j2 in s.get_opcodes():
            if tag != "equal":
                changes.append('{:7}   text[{}:{}] --> decoded_text[{}:{}] {!r:>8} --> {!r}'.format(
                    tag, i1, i2, j1, j2, src_text[i1:i2], decoded_text[j1:j2]))
        return changes

    def _tokenize(tokenizer, datasets, detail_path=None):
        """Tokenize every ``item["text"]`` of ``datasets`` and aggregate statistics.

        :param tokenizer: object exposing ``encode``/``decode`` (and ``unk_token_id``
            when it has an ``unk_token`` attribute).
        :param datasets: iterables of records with a ``"text"`` field.
        :param detail_path: file to write the per-text round-trip differences to.
        :return: statistics dict (the raw-count part of ``tokenize_corpus``'s result).
        """
        n_bytes = 0
        n_tokens = 0
        n_chars = 0
        n_oov_chars = 0
        diff_details = []
        oov_charset = {}  # dict as an insertion-ordered set -> deterministic JSON in the cache
        unk_token_id = None
        if hasattr(tokenizer, "unk_token"):
            unk_token_id = tokenizer.unk_token_id
        for dataset in datasets:
            for item in dataset:
                text = item["text"]
                n_bytes += get_n_bytes_of_string(text)
                n_chars += len(text)
                ids = tokenizer.encode(text, add_special_tokens=False)
                n_tokens += len(ids)

                # Detect OOV: a lossless tokenizer decodes back to exactly the input text.
                decoded_text = tokenizer.decode(ids)
                if decoded_text == text:
                    continue
                decoded_text_without_unk = tokenizer.decode(
                    [token_id for token_id in ids if token_id != unk_token_id])
                _n_oov_chars, _oov_charset = _char_based_oov(text, decoded_text_without_unk, tokenizer)
                diff_details.append(
                    {
                        "text": text,
                        "decoded_text": decoded_text,
                        "diff": _diff_path(text, decoded_text),
                        "n_oov_chars": _n_oov_chars,
                        'oov_ratio': _n_oov_chars / len(text) if text else 0.0,
                        'oov_charset': json.dumps(_oov_charset, ensure_ascii=False),
                    }
                )
                n_oov_chars += _n_oov_chars
                oov_charset.update(dict.fromkeys(_oov_charset))
        stat = {
            "_n_bytes": n_bytes,
            "_n_tokens": n_tokens,
            "_n_chars": n_chars,
            "_n_oov_chars": n_oov_chars,
            "oov_ratio": n_oov_chars / n_chars if n_chars else 0.0,
            '_oov_charset': json.dumps(list(oov_charset), ensure_ascii=False),
            "lossless": len(diff_details) == 0
        }

        if not diff_details:
            logger.info("no diff found")
        elif detail_path:
            logger.info(f"saving tokenization detail to '{detail_path}'")
            with open(detail_path, "w", encoding="utf-8") as f:
                f.write(json.dumps(diff_details, ensure_ascii=False, indent=2))
        return stat

    # load from cache
    # NOTE(review): the file cache is loaded only once per process (while `cache` is empty),
    # so a later call with a different `cache_dir` still sees the first directory's entries.
    cache_id = f"{tokenizer_name} @ {'.'.join(corpuses)}"
    cache_path = os.path.join(cache_dir, "compression_rate.json")
    if not cache and os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f_tmp:
            cache.update(json.load(f_tmp))
    if cache_id in cache:
        return cache[cache_id]

    # tokenize corpus
    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
    datasets = [load_dataset("eson/cc100-samples", corpus.replace("cc100/", ""), split="train") for corpus in corpuses]

    # basic tokenizer info
    stat = {
        "tokenizer": tokenizer_factory.get_name_with_hyperlink(tokenizer_name),
        "organization": tokenizer_factory.get_tokenizer_config(tokenizer_name).org,
        "vocab_size": len(tokenizer),
    }
    tokenize_detail_dir = os.path.join(cache_dir, "compression_rate")
    os.makedirs(tokenize_detail_dir, exist_ok=True)
    tokenize_detail_path = os.path.join(tokenize_detail_dir, cache_id.replace("/", ".") + ".diff.json")
    stat.update(_tokenize(tokenizer, datasets, detail_path=tokenize_detail_path))

    # save to cache
    len_before = len(cache)
    cache[cache_id] = stat
    len_after = len(cache)
    logger.info(f"saving '{cache_id}' to memory and file cache '{cache_path}': {len_before}->{len_after}")
    with open(cache_path, "w", encoding="utf-8") as f_tmp:
        json.dump(cache, f_tmp, ensure_ascii=False, indent=2)
    return stat


def get_compression_leaderboard(
        corpuses: Optional[List[str]] = None,
        unit: str = "b_tokens/g_bytes",
        tokenizer_filter: Optional[str] = None,
        return_type: Optional[Literal["dict", "dataframe"]] = "dataframe"
) -> Union[pd.DataFrame, dict]:
    """Compute compression statistics of all (or filtered) tokenizers over ``corpuses``.

    :param corpuses: corpus names such as ``"cc100/en"``; defaults to ``["cc100/en"]``.
    :param unit: main unit ``"<a>/<b>"``; the table shows ``"char/token"``, ``unit``
        and its reverse ``"<b>/<a>"``.
    :param tokenizer_filter: case-insensitive substring a tokenizer name must contain.
    :param return_type: ``"dataframe"`` for a sorted table; anything else returns the
        raw dict ``tokenizer name -> merged stat``.
    :return: the leaderboard as a ``pandas.DataFrame`` (empty if no tokenizer
        matched) or as a dict.
    """
    if corpuses is None:
        # Avoid a shared mutable default argument.
        corpuses = ["cc100/en"]
    logger.info(f"corpuses: {corpuses}; unit: {unit}; tokenizer_filter: {tokenizer_filter}")

    if tokenizer_filter is not None:
        tokenizer_names = [tokenizer_name for tokenizer_name in tokenizer_factory.all_tokenizer_names
                           if tokenizer_filter.lower() in tokenizer_name.lower()]
    else:
        tokenizer_names = tokenizer_factory.all_tokenizer_names

    stats = {}
    for tokenizer_name in tokenizer_names:
        stats_by_corpus = {corpus: tokenize_corpus(tokenizer_name, [corpus]) for corpus in corpuses}
        stats[tokenizer_name] = _merge_stats_by_corpus(stats_by_corpus)

    if return_type != "dataframe":
        return stats

    numerator_unit, denominator_unit = unit.split("/")
    reverse_unit = f"{denominator_unit}/{numerator_unit}"
    df = to_dataframe(stats, ["char/token", unit, reverse_unit])
    if df.empty:
        # No tokenizer matched the filter: there are no columns to sort or rename.
        return df
    # Lowest OOV ratio first; among equals, more characters per token (better compression) first.
    df = df.sort_values(["oov_ratio", "char/token"], ascending=[True, False])
    return df.rename(columns={"oov_ratio": f' ⬆️oov_ratio', "char/token": ' ⬇️char/token'})


def main():
    """Log the compression leaderboard as a markdown table.

    Usage: ``python <script> TOKENIZER_FILTER CORPUS[,CORPUS...]`` restricts the run
    to tokenizers whose name contains TOKENIZER_FILTER (case-insensitive) and to the
    given comma-separated corpora (e.g. ``cc100/en,cc100/zh-Hans``). With any other
    number of arguments, all tokenizers are evaluated on ``common_corpuses``.

    Example invocations:
        openai cc100/en,cc100/zh-Hans
        Qwen/Qwen1.5-14B cc100/ja                                 # many OOV chars
        google-bert/bert-base-uncased cc100/ja,cc100/zh-Hans      # many OOV chars
        deepseek-ai/DeepSeek-R1 cc100/de
    """
    if len(sys.argv) == 3:
        # The filter is matched with str.lower(), so it must be a plain string, not a list.
        tokenizer_filter = sys.argv[1]
        corpuses = [corpus.strip() for corpus in sys.argv[2].split(",") if corpus.strip()]
    else:
        tokenizer_filter, corpuses = None, common_corpuses
    df = get_compression_leaderboard(corpuses, tokenizer_filter=tokenizer_filter)
    logger.info(f"\n{df.to_markdown(index=False)}")


if __name__ == "__main__":
    main()