import gradio as gr
import pandas as pd
import numpy as np
import re
from prompt import algebric_prompt, python_prompt
from utils import generate_response, run_code


def generate_algebric_template(question):
    """Replace the numbers in *question* with placeholder letters.

    Numbers are located with a regular expression and substituted from left
    to right with ``A``, ``B``, ... ``Z``. Every occurrence receives its own
    placeholder, so a value that appears twice yields two variables, and a
    number that happens to be a substring of another one (``5`` inside
    ``15``) can no longer corrupt it.

    Only 26 placeholders exist; any further numbers are left in the text
    unchanged instead of raising ``IndexError``.

    Args:
        question: The problem statement as plain text.

    Returns:
        A ``(template, var_map)`` tuple where ``template`` is the rewritten
        question and ``var_map`` maps each placeholder letter to the
        ``float`` value it replaced, in order of appearance.
    """
    var_names = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
    pattern = re.compile(r"[-+]?\d*\.\d+|\d+")
    var_map = {}

    def substitute(match):
        # Out of placeholder letters: keep the literal number in the text.
        if len(var_map) >= len(var_names):
            return match.group(0)
        var_name = var_names[len(var_map)]
        var_map[var_name] = float(match.group(0))
        return var_name

    # Substituting per match (rather than str.replace per value) keeps the
    # replacement positional: each match is rewritten exactly once.
    template = pattern.sub(substitute, question)
    return template, var_map


def generate_algebric_expression(question, variables, param, token):
    """Ask the model for an algebraic expression answering *question*.

    The stripped question is inserted into ``algebric_prompt`` and the
    resulting query is sent through ``generate_response``. The text after the
    last ``#Ques: <question>`` marker in the response is taken, and from it
    the part following the last ``"Answer = "``.

    Args:
        question: Templated question text.
        variables: Accepted for interface compatibility; not used here.
        param: Forwarded unchanged to ``generate_response``.
        token: Forwarded unchanged to ``generate_response``.

    Returns:
        The extracted expression string.
    """
    cleaned = question.strip()
    prompt_text = algebric_prompt.format(question=cleaned).strip() + "\n"
    completion = generate_response(prompt_text, param, token)

    # Keep only what follows the (last) echo of the question.
    marker = f"#Ques: {cleaned}"
    after_question = completion.split(marker)[-1].strip()

    # The whole string is returned when no "Answer = " prefix is present.
    _, _, answer = after_question.rpartition("Answer = ")
    return answer


def generate_python_code(question, equation, param, token):
    """Ask the model for Python code implementing *equation*.

    ``python_prompt`` is filled with the stripped question and expression,
    the query is sent through ``generate_response``, and everything after the
    last ``# Function for above expression is:`` marker is returned.

    Args:
        question: Templated question text.
        equation: Algebraic expression to turn into code.
        param: Forwarded unchanged to ``generate_response``.
        token: Forwarded unchanged to ``generate_response``.

    Returns:
        The stripped text following the marker (the whole stripped response
        when the marker is absent).
    """
    prompt_text = python_prompt.format(
        question=question.strip(),
        expression=equation.strip(),
    ).strip() + "\n"
    completion = generate_response(prompt_text, param, token)
    _, _, tail = completion.rpartition("# Function for above expression is:")
    return tail.strip()


def _score_candidate(expression, code, variables, samples):
    """Count the samples on which *expression* and *code* give the same result.

    Args:
        expression: Algebraic expression using the placeholder names.
        code: Source text whose first line is expected to look like
            ``def solution(A, B):``.
        variables: Placeholder names, in substitution order.
        samples: List of dicts mapping each placeholder name to an int.

    Returns:
        The number of samples for which ``run_code`` returned equal values
        for the expression and for the function call. Evaluation stops at the
        first failure and the matches counted so far are kept.
    """
    # Candidates that call input() are never executed; they simply score 0.
    if "input(" in expression or "input(" in code:
        return 0

    # Parameter names come from the first line: drop the leading "(" and the
    # trailing "):" of the signature, then split on commas.
    signature = code.split("\n")[0].split("def solution")[-1][1:-2]
    parameters = [name.strip() for name in signature.split(",") if name.strip()]

    matches = 0
    try:
        for sample in samples:
            # running expression
            exp = expression
            for name in variables:
                exp = exp.replace(name, str(sample[name]))
            exp_ans = run_code("print(" + exp + ")")

            # running code
            call_args = ",".join(f"{name}={sample[name]}" for name in parameters)
            code_ans = run_code(f"{code}\nprint(solution({call_args}))")

            matches += int(exp_ans == code_ans)
    except Exception as ex:
        # Best effort: generated text is often malformed (unknown parameter
        # names, code that does not run). Keep the partial score but report
        # the failure instead of hiding it.
        print("candidate evaluation failed:", ex)
    return matches


def run(question, random_candidates, hps, token):
    """Generate one candidate per entry of *hps* and return the best one.

    The question is templated with ``generate_algebric_template``. For every
    entry of ``hps`` an expression and matching code are generated. Each
    candidate is then scored by evaluating both on ``random_candidates``
    random integer assignments (values in ``[1, 100)``) and counting how
    often the two results agree. Exactly one score is recorded per candidate,
    so the selected index always refers to an existing candidate.

    Args:
        question: Original question text.
        random_candidates: Number of random assignments used for scoring.
        hps: Sequence of generation parameters; each entry is forwarded to
            the generation helpers.
        token: Forwarded to the generation helpers.

    Returns:
        ``((expression, code), var_map)`` for the highest-scoring candidate
        (the first one on ties).

    Raises:
        ValueError: If ``hps`` is empty.
    """
    if len(hps) == 0:
        raise ValueError("hps must contain at least one generation parameter")

    question, var_map = generate_algebric_template(question)
    variables = list(var_map)

    # generating the random candidates for arguments
    draws = np.random.randint(1, 100, (max(random_candidates, 0), len(variables)))
    samples = [dict(zip(variables, (int(v) for v in row))) for row in draws]

    # accumulating results
    candidates = []
    scores = []
    for hp in hps:
        expression = generate_algebric_expression(question, variables, hp, token)
        code = generate_python_code(question, expression, hp, token)
        candidates.append((expression, code))
        scores.append(_score_candidate(expression, code, variables, samples))

    top_candidate = candidates[int(np.argmax(scores))]
    return top_candidate, var_map


def solve_mp(question, token):
    """Solve *question* via a generated expression and generated code.

    ``run`` is called with five random samples and the generation parameters
    ``[0.9, 0.95]`` to pick a candidate. The expression is then evaluated
    with the question's original numbers, and the code is called with the
    same numbers as keyword arguments.

    Whole numbers are passed to the code as ints; decimal values are passed
    as floats so they are not truncated and the code sees the same inputs as
    the expression.

    Args:
        question: Original question text.
        token: Forwarded to ``run``.

    Returns:
        ``(exp_op, code_op, code)``. ``exp_op`` is ``None`` when the
        expression could not be evaluated. When the code could not be
        evaluated the result is ``(None, None, code)``. The returned ``code``
        includes the appended ``print(solution(...))`` call whenever that
        call could be built.
    """
    hps = [0.9, 0.95]
    (expression, code), var_map = run(question, 5, hps, token)
    exp_op = None
    code_op = None

    try:
        # expression output
        for k, v in var_map.items():
            expression = expression.replace(k, str(v))
        expression = "print(" + expression + ")"
        print(expression)

        # Expressions that call input() are never executed.
        if "input(" in expression:
            raise ValueError("expression calls input()")

        exp_op = run_code(expression)
    except Exception:
        print("expression cannot be executed", expression)

    try:
        # code output
        print(code)
        # Parameter names come from the first line, expected to look like
        # "def solution(A, B):" -- drop "(" and "):" then split on commas.
        signature = code.split("\n")[0].split("def solution")[-1][1:-2]
        parameters = [name.strip() for name in signature.split(",") if name.strip()]

        call_parts = []
        for name in parameters:
            value = var_map[name]
            # Keep integral inputs as ints; do not truncate real decimals.
            literal = int(value) if float(value).is_integer() else value
            call_parts.append(f"{name}={literal}")
        code += f"\nprint(solution({','.join(call_parts)}))"

        if "input(" in code:
            print("code cannot be executed")
            raise ValueError("code calls input()")
        code_op = run_code(code)
    except Exception:
        return None, None, code

    return exp_op, code_op, code