# import packages
import os
from tqdm import tqdm
import warnings
import json
import torch.nn.functional as F
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
from datetime import datetime
import argparse
import mamba_ssm
import rwkv


RWKV4_TOKENIZER_FILE = "./support/20B_tokenizer.json"


def load_list_from_json(file_path):
    """
    Read a UTF-8 JSON file and return the decoded value.

    The file is expected to contain a JSON array of strings (one sample per
    element), but whatever JSON value it holds is returned unchanged.

    :param file_path: Location of the JSON file to read.
    :return: The decoded JSON value, normally a list of strings.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        decoded = json.load(handle)
    return decoded


def calculate_log_sum(logits, target_token_ids):
    """
    Sum the negative log-likelihood of every next token in a sequence.

    Row ``i`` of ``logits`` is scored against token ``i + 1`` of
    ``target_token_ids`` (next-token shift), so ``n`` tokens contribute
    ``n - 1`` terms; a sequence of length 0 or 1 yields ``0.0``.

    :param logits: 2-D tensor ``(seq_len, vocab_size)`` of raw scores.
    :param target_token_ids: 1-D integer tensor of length ``seq_len`` with the
        token ids, on the same device as ``logits``.
    :return: Python float, the summed negative log-probability (nats).
    """
    shifted_logits = logits[:-1, :]
    # Compute in at least float32: the models here run in fp16, where
    # log_softmax and a ~1000-term sum lose precision (fp16 spacing is 4.0
    # near 5000) and can overflow for long, high-loss sequences.
    compute_dtype = torch.promote_types(shifted_logits.dtype, torch.float32)
    shifted_logits = shifted_logits.to(compute_dtype)
    # gather() requires int64 indices.
    shifted_targets = target_token_ids[1:].long()

    log_probs = F.log_softmax(shifted_logits, dim=-1)

    # squeeze(1) (not squeeze()) keeps a 1-D vector even when exactly one
    # target remains, so the reduction is always over token positions.
    target_neg_log_probs = -log_probs.gather(1, shifted_targets.unsqueeze(1)).squeeze(1)

    return torch.sum(target_neg_log_probs).item()


def print_model_parameters_in_billions(model):
    """
    Print the total number of parameters of ``model`` in billions.

    Every tensor yielded by ``model.parameters()`` is counted, trainable or
    not. The value is printed with three decimals.

    :param model: Object exposing ``parameters()`` (e.g. ``torch.nn.Module``).
    """
    count = 0
    for tensor in model.parameters():
        count += tensor.numel()
    print(f"Model parameters: {count / 1e9:.3f} billion")


def make_log(data_dict, folder_path):
    """
    Save ``data_dict`` as indented JSON in a timestamped file.

    The file is named ``YYYY-MM-DD_HH-MM-SS.json`` and placed in
    ``folder_path``, which is created (with parents) when missing. Logging is
    best-effort: failures are reported with ``print`` and the function
    returns ``None`` without raising.

    :param data_dict: JSON-serialisable object to record.
    :param folder_path: Directory that receives the log file.
    """
    # EAFP instead of exists()+makedirs(): no race if another process
    # creates the directory in between.
    try:
        os.makedirs(folder_path)
        print(f"Directory created at {folder_path}")
    except FileExistsError:
        pass
    except OSError as e:
        print(f"Error creating directory: {e}")
        return

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    file_path = os.path.join(folder_path, f"{timestamp}.json")

    # Serialise before opening the file, so an unserialisable value cannot
    # leave a truncated, half-written log behind.
    try:
        payload = json.dumps(data_dict, indent=4)
    except (TypeError, ValueError) as e:
        print(f"Error saving dictionary: {e}")
        return

    try:
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(payload)
    except OSError as e:
        print(f"Error saving dictionary: {e}")
        return

    print(f"Dictionary saved successfully to {file_path}")


def _load_rwkv_with_tokenizer(path, tokenizer_spec):
    """
    Load an RWKV checkpoint (CUDA, fp16) and the tokenizer of its pipeline.

    Shared implementation of ``load_rwkv`` and ``load_rwkv4pile``, which
    differ only in the tokenizer passed to ``rwkv.utils.PIPELINE``.

    :param path: Checkpoint path passed to ``rwkv.model.RWKV``.
    :param tokenizer_spec: Vocabulary name or tokenizer file for ``PIPELINE``.
    :return: ``(model, tokenizer)`` tuple.
    """
    # The rwkv package README says to set these flags before importing
    # rwkv.model, hence the function-local imports below.
    os.environ['RWKV_JIT_ON'] = '1'
    os.environ["RWKV_CUDA_ON"] = '1'

    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE

    rwkv_model = RWKV(model=path, strategy='cuda fp16')
    rwkv_tokenizer = PIPELINE(rwkv_model, tokenizer_spec).tokenizer

    return rwkv_model, rwkv_tokenizer


def load_rwkv(path):
    """
    Load an RWKV model that uses the ``rwkv_vocab_v20230424`` vocabulary.

    :param path: RWKV checkpoint path.
    :return: ``(model, tokenizer)`` tuple.
    """
    return _load_rwkv_with_tokenizer(path, r"rwkv_vocab_v20230424")


def load_rwkv4pile(path):
    """
    Load an RWKV-4 Pile model with the tokenizer file ``RWKV4_TOKENIZER_FILE``.

    :param path: RWKV checkpoint path.
    :return: ``(model, tokenizer)`` tuple.
    """
    return _load_rwkv_with_tokenizer(path, RWKV4_TOKENIZER_FILE)


def load_hf_model(path, cache_path):
    """
    Load a Hugging Face causal LM and its tokenizer, placing the model on CUDA.

    :param path: Hub model id or local model directory.
    :param cache_path: Optional download cache directory; ``None`` uses the
        default Hugging Face cache.
    :return: ``(model, tokenizer)``; the model is put in eval mode.
    """
    # Only pass cache_dir when set, and apply it to the tokenizer as well, so
    # both downloads land in the same cache.
    cache_kwargs = {} if cache_path is None else {'cache_dir': cache_path}

    hf_tokenizer = AutoTokenizer.from_pretrained(path, **cache_kwargs)
    hf_model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map="cuda",
        trust_remote_code=True,
        **cache_kwargs,
    ).eval()

    print_model_parameters_in_billions(hf_model)

    return hf_model, hf_tokenizer


def load_mamba(path):
    """
    Load a Mamba LM in fp16 on CUDA, paired with the GPT-NeoX-20B tokenizer.

    The tokenizer is fixed to ``EleutherAI/gpt-neox-20b`` whatever ``path``
    is.

    :param path: Checkpoint name or directory for ``MambaLMHeadModel``.
    :return: ``(model, tokenizer)`` tuple.
    """
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = MambaLMHeadModel.from_pretrained(
        path,
        device="cuda",
        dtype=torch.float16,
    )
    # eval_hf_model reads model.device to move its inputs. Set it explicitly
    # (presumably MambaLMHeadModel does not provide it — confirm).
    model.device = torch.device('cuda')

    print_model_parameters_in_billions(model)

    return model, tokenizer


def eval_rwkv(model, tokenizer, texts, chunk_size, v4pile=False):
    """
    Run an RWKV model over each text and return the results for the last one.

    Each text is tokenised and truncated to its first ``chunk_size`` tokens,
    then passed to ``model.forward`` with a ``None`` state. Only the final
    sample's outputs are returned; per-sample loss accumulation (via
    ``calculate_log_sum``) is currently disabled.

    :param model: RWKV model as returned by ``load_rwkv``/``load_rwkv4pile``.
    :param tokenizer: The matching RWKV pipeline tokenizer.
    :param texts: Non-empty sequence of input strings.
    :param chunk_size: Maximum number of tokens fed to the model per text.
    :param v4pile: True for the RWKV-4 Pile tokenizer, whose ``encode``
        result exposes the token ids via ``.ids``.
    :return: ``(logits, logits, token_ids, tokenizer)`` for the last text. The
        logits appear twice so that callers unpacking four values keep working.
    :raises ValueError: If ``texts`` is empty.
    """
    if not texts:
        raise ValueError("eval_rwkv: 'texts' is empty, nothing to evaluate")

    logit = input_chunk = None
    with torch.no_grad():
        for sample in tqdm(texts, total=len(texts)):
            encoded = tokenizer.encode(sample)
            input_seq = encoded.ids if v4pile else encoded

            input_chunk = input_seq[:chunk_size]

            logit = model.forward(input_chunk, None, full_output=True)[0]

            # For a single-token input the output is apparently missing the
            # sequence axis; restore it so it is indexed like longer inputs.
            if len(input_chunk) == 1:
                logit = logit.unsqueeze(0)

    return logit, logit, input_chunk, tokenizer


def eval_hf_model(model, tokenizer, texts, chunk_size):
    """
    Run a Hugging Face-style causal LM over each text; return the last result.

    Each text is tokenised into a batch of one, truncated to its first
    ``chunk_size`` tokens and passed to ``model.forward``. Only the final
    sample's outputs are returned; per-sample loss accumulation (via
    ``calculate_log_sum``) is currently disabled.

    :param model: Model exposing ``.device`` and ``forward(input_ids=...)``
        returning an object with ``.logits``.
    :param tokenizer: Callable tokenizer supporting ``return_tensors='pt'``.
    :param texts: Non-empty sequence of input strings.
    :param chunk_size: Maximum number of tokens fed to the model per text.
    :return: ``(logits, input_ids, tokenizer)`` for the last text, where
        ``logits`` has the batch dimension removed and ``input_ids`` is the
        truncated ``(1, n_tokens)`` tensor fed to the model.
    :raises ValueError: If ``texts`` is empty.
    """
    if not texts:
        raise ValueError("eval_hf_model: 'texts' is empty, nothing to evaluate")

    logit = input_chunk = None
    with torch.no_grad():
        for sample in tqdm(texts, total=len(texts)):
            inputs = tokenizer(sample, return_tensors='pt').to(model.device)

            # Keep the batch axis, truncate along the token axis.
            input_chunk = inputs['input_ids'][:, :chunk_size]

            # Select batch element 0.
            logit = model.forward(input_ids=input_chunk).logits[0, :, :]

    return logit, input_chunk, tokenizer


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()

#     parser.add_argument('--model', type=str, required=True, help='model name or path')
#     parser.add_argument('--model_type', choices=['hf', 'rwkv', 'mamba', 'rwkv4pile'], required=True, help='model type')
#     parser.add_argument('--data', type=str, required=True, help='data path (json file)')
#     parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')
#     parser.add_argument('--model_cache', type=str, help='hugging face model cache')
#     parser.add_argument('--chunk_size', type=int, default=1024, help='chunk size')


def run_get_loss(args):
    """
    Load the texts and model described by ``args`` and evaluate the model.

    :param args: Namespace-like object with attributes:
        ``model`` (name or path), ``model_type`` (``'hf'``, ``'rwkv'``,
        ``'mamba'`` or ``'rwkv4pile'``), ``data`` (JSON file holding a list
        of strings), ``chunk_size`` and optionally ``model_cache`` (HF cache
        directory, default ``None``). When ``data`` is missing or ``None``,
        ``args.texts`` is used as the list of input strings instead.
    :return: The tuple returned by ``eval_hf_model`` (``'hf'``/``'mamba'``)
        or ``eval_rwkv`` (``'rwkv'``/``'rwkv4pile'``).
    :raises NotImplementedError: If ``model_type`` is not supported.
    """
    model_type = args.model_type
    # Validate before the (potentially slow) data and model loading.
    if model_type not in ('hf', 'rwkv', 'mamba', 'rwkv4pile'):
        raise NotImplementedError(f"Unsupported model_type: {model_type!r}")

    # load data
    data_path = getattr(args, 'data', None)
    texts = load_list_from_json(data_path) if data_path is not None else list(args.texts)
    print(f'data size: {len(texts)}')

    # load model (tokenizer is loaded alongside it)
    if model_type == 'hf':
        model, tokenizer = load_hf_model(args.model, getattr(args, 'model_cache', None))
    elif model_type == 'rwkv':
        model, tokenizer = load_rwkv(args.model)
    elif model_type == 'mamba':
        model, tokenizer = load_mamba(args.model)
    else:  # 'rwkv4pile'
        model, tokenizer = load_rwkv4pile(args.model)

    # eval: Mamba is driven through the HF-style evaluator.
    if model_type in ('hf', 'mamba'):
        return eval_hf_model(model=model, tokenizer=tokenizer, texts=texts, chunk_size=args.chunk_size)
    return eval_rwkv(
        model=model,
        tokenizer=tokenizer,
        texts=texts,
        chunk_size=args.chunk_size,
        v4pile=(model_type == 'rwkv4pile'),
    )

from types import SimpleNamespace

if __name__ == '__main__':
    # Example configuration for a quick manual run.
    # NOTE(review): the namespace is built but run_get_loss(args) is never
    # called, so running this file as a script currently does nothing.
    args = SimpleNamespace(
        model='microsoft/phi-2',
        texts=['Hello FreshBench !'],
        model_type='hf',
        data='data.json',
        model_cache=None,
        chunk_size=1024,
    )



# def run_get_loss(input_string, model_type):
#     # load data
#     texts = [input_string]
#     print(f'data size: {len(texts)}')

#     # load model
#     if model_type == 'hf':
#         model, tokenizer = load_hf_model(args.model, args.model_cache)# tokenzier path, model path
#     elif model_type == 'rwkv':
#         model, tokenizer = load_rwkv(args.model)
#     elif model_type == 'mamba':
#         model, tokenizer = load_mamba(args.model)
#     elif model_type == 'rwkv4pile':
#         model, tokenizer = load_rwkv4pile(args.model)
#     else:
#         raise NotImplementedError

#     # eval
#     if model_type in ['hf', 'mamba']:
#         results = eval_hf_model(model=model, tokenizer=tokenizer, texts=texts, chunk_size=args.chunk_size)
#     elif model_type == 'rwkv':
#         results = eval_rwkv(model=model, tokenizer=tokenizer, texts=texts, chunk_size=args.chunk_size)
#     elif model_type == 'rwkv4pile':
#         results = eval_rwkv(model=model, tokenizer=tokenizer, texts=texts, chunk_size=args.chunk_size, v4pile=True)
#     else:
#         raise NotImplementedError

#     results['model_name_or_path'] = args.model
#     results['data_path'] = args.data
#     results['chunk_size'] = args.chunk_size

#     make_log(results, args.log_path)

#     print(json.dumps(results, indent=4, ensure_ascii=False))