import traceback
import pandas as pd
from tqdm import tqdm
import re
import os
# Restrict the visible GPUs before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2'
import torch
import random
import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

class GLM3_6B_API:
    '''
    Wrapper around a locally deployed ChatGLM3-6B model.
    Set model_name_or_path to your local checkpoint path.
    '''
    def __init__(self) -> None:
        self.model_name_or_path = "your_model_path"
        self.init = True
    
    def chat(self, prompt) -> str:
        # Retry up to five times; the model and tokenizer are loaded lazily on the first call.
        for _ in range(5):
            if self.init:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name_or_path,
                    trust_remote_code=True,
                    device_map="auto",
                    torch_dtype=(
                        torch.bfloat16
                        if torch.cuda.is_bf16_supported()
                        else torch.float32
                    ),
                ).eval()
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name_or_path,
                    trust_remote_code=True,
                    use_fast=True,
                    add_bos_token=False,
                    add_eos_token=False,
                    padding_side="left",
                )
                self.init = False
            try:
                print(prompt)
                response, history = self.model.chat(self.tokenizer, prompt, history=[], do_sample=False)
                print(response)
                return response
            except Exception:
                # Log the failure and retry after a short pause.
                traceback.print_exc()
                time.sleep(5)
                continue
        return None
    
glm3_6b = GLM3_6B_API()
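# Module-level singleton: the heavyweight model load is deferred until the
# first chat() call, and chat() returns None when all retries fail, so
# callers must handle a None result.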

def parse_num(res, min_score, max_score):
    """
    Extract a score within [min_score, max_score] from an evaluation result.
    Input: a string.
    Output: an in-range score (as a string) or -1 on failure.
    Strategy:
      1. If no numbers appear at all, return -1.
      2. If a fraction such as "4/5" appears, try its numerator first.
      3. If a phrase such as "4 out of 5" appears, try the preceding number.
      4. Otherwise, return the first number anywhere in the string that falls
         within the range.
      5. If no number falls within the range, return -1.
    """
    all_nums = re.findall(r"-?\d+(?:\.\d+)?", res)
    if len(all_nums) == 0:
        print("this res doesn't have num! \n", res)
        return -1

    # 1) Fractions such as "4/5": try the numerator.
    for match in re.finditer(r"\b(\d+(?:\.\d+)?)/\d+\b", res):
        num = match.group(1)
        if min_score <= float(num) <= max_score:
            return num

    # 2) Phrases such as "4 out of 5": try the preceding number.
    for match in re.finditer(r"\b(\d+(?:\.\d+)?)\s+out\s+of\s+\d+\b", res):
        num = match.group(1)
        if min_score <= float(num) <= max_score:
            return num

    # 3) Fall back to the first number anywhere in the string that is in range.
    for num in all_nums:
        if min_score <= float(num) <= max_score:
            return num

    print("this res doesn't have right num! ", res)
    return -1
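
# Example behaviour (illustrative sketch, not executed):
#   parse_num("I would rate this answer 4/5.", 1, 5)  -> "4"
#   parse_num("Score: 3 out of 5", 1, 5)              -> "3"
#   parse_num("No score given.", 1, 5)                -> -1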

# taskId -> prompt template file.
PROMPT_FILES = {
    0: "prompt/prompt_Dialog.txt",
    1: "prompt/prompt_Story.txt",
    2: "prompt/prompt_Xsum.txt",
    3: "prompt/prompt_NFCATS.txt",
}

def get_prompt(taskId):
    """
    Find the corresponding prompt template based on the taskId.
    Returns an empty string for unknown task ids.
    """
    path = PROMPT_FILES.get(taskId)
    if path is None:
        return ""
    with open(path, encoding='utf-8') as f:
        return f.read().strip()
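
# The template files are expected to contain the placeholders
# "{{question_text}}" and "{{answer_text}}", which get_model_score fills in.
# A hypothetical prompt/prompt_Dialog.txt might look like:
#   Rate the following dialogue response on a scale of 1 to 5.
#   Question: {{question_text}}
#   Answer: {{answer_text}}
#   Score: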

def get_model_score(taskId, question, answer, model):
    """
    Pointwise 5-level grading as an example.
    """
    prompt = get_prompt(taskId)
    prompt = prompt.replace("{{question_text}}", question)
    prompt = prompt.replace("{{answer_text}}", answer)
    result = model.chat(prompt)
    # chat() returns None on failure, and parse_num may return a decimal
    # string, so go through float before converting to int.
    score = -1 if result is None else int(float(parse_num(result, 1, 5)))
    if score == -1:
        # Fall back to a random in-range score when no score can be parsed.
        score = random.randint(1, 5)
    return score
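
# Example (sketch): get_model_score(0, "How are you?", "Fine, thanks.", glm3_6b)
# fills the dialogue template, queries the model once, and returns an int in 1..5.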

def get_rank(data):
    """
    Rank the scores in descending order; tied scores share the smallest
    (best) rank. For example, [1, 1, 2] ranks to [2, 2, 1].
    """
    series = pd.Series(data)
    ranks = series.rank(method='min', ascending=False)
    return list(map(int, ranks.tolist()))

def get_output(path, model):
    """
    Score the test set at the specified path and append the results to
    output/baseline1_chatglm3_6B.txt. Each question is assumed to have
    exactly 7 candidate answers stored in consecutive rows.
    """
    df = pd.read_csv(path)
    row_labels = df.index

    # taskId,taskName,questionId,question,answerId,answer,score,rank
    model_scores = []
    with open("output/baseline1_chatglm3_6B.txt", 'a') as f:
        for row in tqdm(row_labels):
            taskId = df.loc[row, "taskId"]
            questionId = df.loc[row, "questionId"]
            question = df.loc[row, "question"]
            answer = df.loc[row, "answer"]

            model_score = get_model_score(taskId, question, answer, model)

            model_scores.append(model_score)

            # Once all 7 answers for a question are scored, rank them and
            # write one "taskId questionId answerId score rank" line each.
            if len(model_scores) == 7:
                ranks = get_rank(model_scores)
                for i in range(7):
                    answerId = i
                    f.write(f"{taskId} {questionId} {answerId} {model_scores[i]} {ranks[i]}\n")
                model_scores = []
    
if __name__ == '__main__':
    paths = ['test/test_dialog.csv', 'test/test_NFCATS.csv', 'test/test_story.csv', 'test/test_Xsum.csv']
    for path in paths:
        get_output(path, glm3_6b)