import os
import random
import string

import numpy as np
import torch
from torch.utils.data import Dataset

class TokenClfDataset(Dataset):
    """Dataset that converts raw texts into fixed-length token-id tensors.

    Each item is tokenized with ``tokenizer`` and wrapped in the model
    family's CLS/SEP special tokens. It is then truncated or right-padded
    to exactly ``max_len`` tokens.

    Args:
        texts: indexable sequence of input strings.
        max_len: length of every returned sequence, special tokens included
            (original notes: 256 for phobert, 512 for xlm-roberta).
        tokenizer: object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)`` -- presumably a Hugging Face
            tokenizer matching ``model_name``; confirm with callers.
        model_name: selects the special-token strings by substring match on
            "m_bert", "xlm-roberta" (including "xlm-roberta-large") or
            "phobert". For any other name no special tokens are set, and
            ``__getitem__`` will fail with AttributeError.
    """

    def __init__(
        self,
        texts,
        max_len=512,  # 256 (phobert), 512 (xlm-roberta)
        tokenizer=None,
        model_name="m_bert",
    ):
        self.len = len(texts)
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model_name = model_name
        if "m_bert" in model_name:
            # BERT WordPiece special tokens.
            self.cls_token = "[CLS]"
            self.sep_token = "[SEP]"
            self.unk_token = "[UNK]"
            self.pad_token = "[PAD]"
            self.mask_token = "[MASK]"
        elif "xlm-roberta" in model_name or "phobert" in model_name:
            # RoBERTa-style special tokens. These are identical for
            # xlm-roberta, xlm-roberta-large and phobert.
            self.bos_token = "<s>"
            self.eos_token = "</s>"
            self.sep_token = "</s>"
            self.cls_token = "<s>"
            self.unk_token = "<unk>"
            self.pad_token = "<pad>"
            self.mask_token = "<mask>"
        # Unknown model names are not rejected here; raising was deliberately
        # left disabled (see the class docstring).

    def __getitem__(self, index):
        """Encode ``texts[index]`` into ids and an attention mask.

        Returns:
            dict with "ids" and "mask": 1-D ``torch.long`` tensors of length
            ``max_len``. The mask is 1 for real tokens (including CLS/SEP)
            and 0 for padding.
        """
        text = self.texts[index]
        tokens = [self.cls_token] + self.tokenizer.tokenize(text) + [self.sep_token]

        if len(tokens) > self.max_len:
            if self.max_len >= 2:
                # Truncate the content but keep the closing SEP/EOS token so
                # the sequence stays well-formed for the encoder.
                tokens = tokens[: self.max_len - 1] + [self.sep_token]
            else:
                tokens = tokens[: self.max_len]

        n_real = len(tokens)
        n_pad = self.max_len - n_real
        tokens = tokens + [self.pad_token] * n_pad

        # Build the mask by position, not by comparing strings: a literal pad
        # token produced from the input text is real content, not padding.
        attn_mask = [1] * n_real + [0] * n_pad

        ids = self.tokenizer.convert_tokens_to_ids(tokens)

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(attn_mask, dtype=torch.long),
        }

    def __len__(self):
        """Return the number of texts captured at construction time."""
        return self.len


def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Note: assigning ``PYTHONHASHSEED`` here only affects interpreters started
    afterwards (e.g. subprocesses). The current process's hash seed is fixed
    at interpreter startup.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def is_begin_of_new_word(token, model_name, force_tokens, token_map):
    """Decide whether ``token`` starts a new word for the given tokenizer family.

    Args:
        token: a single sub-word token string.
        model_name: tokenizer family, matched by substring ("m_bert",
            "xlm-roberta" / "xlm-roberta-large", "phobert").
        force_tokens: container of tokens that always start a new word.
        token_map: mapping whose *values* (placeholder tokens) are also
            always treated as word starts.

    Returns:
        True if the token begins a new word and False if it continues the
        previous one. Returns None when ``model_name`` matches no known family
        (raising was deliberately left disabled).
    """
    added_tokens = set(token_map.values())

    if "m_bert" in model_name:
        # WordPiece marks continuation pieces with a leading "##". Remove
        # exactly that prefix. str.lstrip("##") strips *every* leading "#",
        # so the continuation piece "###" would wrongly become "".
        bare = token[2:] if token.startswith("##") else token
        if bare in force_tokens or bare in added_tokens:
            return True
        return not token.startswith("##")

    if "xlm-roberta" in model_name:  # also covers "xlm-roberta-large"
        # NOTE: `token in string.punctuation` is a substring test, so runs of
        # adjacent punctuation such as "()" match as well as single characters.
        if (
            token in string.punctuation
            or token in force_tokens
            or token in added_tokens
        ):
            return True
        # SentencePiece prefixes word-initial pieces with "▁" (U+2581).
        return token.startswith("▁")

    if "phobert" in model_name:
        if (
            token in string.punctuation
            or token in force_tokens
            or token in added_tokens
        ):
            return True
        # NOTE(review): in PhoBERT's BPE a trailing "@@" usually means the
        # *next* token continues this word. A single-token check can only
        # approximate that, so confirm this rule matches how callers merge
        # tokens into words.
        return not token.endswith("@@")

    return None

def replace_added_token(token, token_map):
    """Undo placeholder substitution inside ``token``.

    ``token_map`` maps each original string to the placeholder that stood in
    for it. Every occurrence of a placeholder in ``token`` is replaced by its
    original. The mappings are applied one after another in the mapping's
    iteration order, so later replacements see the results of earlier ones.
    """
    restored = token
    for original in token_map:
        placeholder = token_map[original]
        restored = restored.replace(placeholder, original)
    return restored

def get_pure_token(token, model_name):
    """Return ``token`` with its tokenizer-specific sub-word marker removed.

    Args:
        token: a single sub-word token string.
        model_name: tokenizer family, matched by substring ("m_bert",
            "xlm-roberta" / "xlm-roberta-large", "phobert").

    Returns:
        The token without its marker. Returns None when ``model_name`` matches
        no known family (raising was deliberately left disabled).
    """
    if "m_bert" in model_name:
        # Drop exactly one "##" continuation prefix. lstrip("##") would strip
        # every leading "#", turning "#" or "###" into "".
        return token[2:] if token.startswith("##") else token
    if "xlm-roberta" in model_name:  # also covers "xlm-roberta-large"
        # Remove the leading SentencePiece word marker "▁" (U+2581). lstrip is
        # kept, so any run of leading markers is removed, as before.
        return token.lstrip("▁")
    if "phobert" in model_name:
        # Drop exactly one trailing "@@". rstrip("@@") would also eat genuine
        # trailing "@" characters, e.g. turning "@" into "".
        return token[:-2] if token.endswith("@@") else token
    return None