import uuid
from flask import Flask, render_template, request, redirect, url_for, send_from_directory
import json
import random
import os
import string
import logging
from datetime import datetime
from huggingface_hub import login, HfApi, hf_hub_download

# Configure root logging: INFO and above go both to app.log (a path relative
# to the working directory) and to the console.
_log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [logging.FileHandler("app.log"), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_log_format, handlers=_log_handlers)
logger = logging.getLogger(__name__)

# Authenticate with the Hugging Face Hub using HF_TOKEN from the environment.
# HfApi() is later constructed without an explicit token, so uploads
# presumably rely on this login.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    logger.error("HF_TOKEN not found in environment variables")
else:
    login(token=hf_token)

app = Flask(__name__)
# Never ship a hard-coded signing key. Read it from the environment and fall
# back to a random per-process key.
# NOTE(review): this file keeps its state in SESSION_DIR, not Flask's cookie
# session. If a template relies on `session`/`flash`, set SECRET_KEY explicitly
# so restarts and multiple workers share it.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY') or os.urandom(32).hex()

# File-based session storage: one JSON file per session ID.
SESSION_DIR = '/tmp/sessions'
os.makedirs(SESSION_DIR, exist_ok=True)

# Folder of pre-rendered HTML visualizations for each explanation method,
# keyed by the method's display name. Individual samples are addressed as
# <folder>/<category>/<file> (see load_samples and experiment).
VISUALIZATION_DIRS = {
    "No-XAI": "htmls_NO_XAI_mod",
    "Dater": "htmls_DATER_mod2",
    "Chain-of-Table": "htmls_COT_mod",
    "Plan-of-SQLs": "htmls_POS_mod2",
    "Text2SQL": "htmls_Text2SQL",
}

# Short method codes; completed() uses them to build ground-truth keys.
def get_method_dir(method):
    """Return the short code for a method's display name, or None if unknown.

    >>> get_method_dir('Plan-of-SQLs')
    'POS'
    """
    known_codes = (
        ('No-XAI', 'NO_XAI'),
        ('Dater', 'DATER'),
        ('Chain-of-Table', 'COT'),
        ('Plan-of-SQLs', 'POS'),
        ('Text2SQL', 'Text2SQL'),
    )
    return next((code for name, code in known_codes if method == name), None)

# Display names of every explanation method, in the same order as
# VISUALIZATION_DIRS (dicts preserve insertion order), so the two cannot drift.
METHODS = list(VISUALIZATION_DIRS)

def generate_session_id():
    """Return a new random session ID (a canonical UUID4 string)."""
    return str(uuid.uuid4())

def _session_file_path(session_id):
    """Return the JSON file path for *session_id*, or None if it is not a canonical UUID.

    Session IDs arrive from URLs and form fields (``/feedback`` reads one from
    ``request.form``), so they are untrusted. Only the exact format produced by
    generate_session_id is accepted. This stops values such as
    ``"../../somewhere"`` from reading or writing files outside SESSION_DIR.
    """
    try:
        canonical = str(uuid.UUID(str(session_id)))
    except (TypeError, ValueError):
        return None
    if canonical != session_id:
        return None
    return os.path.join(SESSION_DIR, f'{canonical}.json')

def save_session_data(session_id, data):
    """Write *data* as JSON to the session's file in SESSION_DIR.

    Raises:
        ValueError: if *session_id* is not a canonical UUID string.
    """
    file_path = _session_file_path(session_id)
    if file_path is None:
        raise ValueError(f"Invalid session id: {session_id!r}")
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f)
    logger.info(f"Session data saved for session {session_id}")

def load_session_data(session_id):
    """Return the stored session dict, or None if the ID is invalid or has no file."""
    file_path = _session_file_path(session_id)
    if file_path is None or not os.path.exists(file_path):
        return None
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_session_data_to_hf(session_id, data):
    """Upload a session record as JSON to the study's Hugging Face Space.

    The record is written to ``/tmp/<file_name>`` and uploaded to
    ``session_data_foward_simulation/<file_name>`` in the Space
    ``luulinh90s/Tabular-LLM-Study-Data``. The file name combines username,
    seed, start time and session ID, keeping only alphanumerics, '_', '-' and
    '.'. Failures are logged and never raised.

    Args:
        session_id: the session's ID, embedded in the file name.
        data: JSON-serialisable session dict. ``username``, ``seed`` and
            ``start_time`` feed the file name when present.

    Returns:
        True if the upload succeeded, False otherwise.
    """
    temp_file_path = None
    uploaded = False
    try:
        username = data.get('username', 'unknown')
        seed = data.get('seed', 'unknown')
        start_time = data.get('start_time', datetime.now().isoformat())
        file_name = f'{username}_seed{seed}_{start_time}_{session_id}_session.json'
        file_name = "".join(c for c in file_name if c.isalnum() or c in ['_', '-', '.'])

        # Serialise before opening the file so a failure leaves no empty file.
        json_data = json.dumps(data, indent=4)
        temp_file_path = f"/tmp/{file_name}"
        with open(temp_file_path, 'w', encoding='utf-8') as f:
            f.write(json_data)

        api = HfApi()
        repo_path = "session_data_foward_simulation"
        api.upload_file(
            path_or_fileobj=temp_file_path,
            path_in_repo=f"{repo_path}/{file_name}",
            repo_id="luulinh90s/Tabular-LLM-Study-Data",
            repo_type="space",
        )
        uploaded = True
        os.remove(temp_file_path)
        logger.info(f"Session data saved for session {session_id} in Hugging Face Data Space")
    except Exception as e:
        logger.exception(f"Error saving session data for session {session_id}: {e}")
        if not uploaded and temp_file_path and os.path.exists(temp_file_path):
            # The caller may delete the SESSION_DIR copy next, so this temp file
            # can be the only remaining record. Keep it and log its location
            # for manual recovery or re-upload.
            logger.error(f"Upload failed; session {session_id} data kept at {temp_file_path}")
    return uploaded

def load_samples(methods=("Dater", "Chain-of-Table", "Plan-of-SQLs")):
    """Collect the sample files that exist for every compared method.

    For each outcome category (TP, TN, FP, FN), the file names in the No-XAI
    folder are intersected with those in each folder of *methods*.

    The result is sorted, so after ``random.seed(seed)`` the selection depends
    only on the seed. Iterating the ``set`` of names directly would give an
    order that changes between interpreter runs (string hash randomisation),
    which makes seeded selections irreproducible.

    NOTE(review): Text2SQL is not in the default *methods*, so a Text2SQL
    session may pick files that are missing from its folder. Confirm that
    htmls_Text2SQL holds the same file names.

    Args:
        methods: display names (keys of VISUALIZATION_DIRS) whose folders must
            also contain the file.

    Returns:
        List of ``{'category': str, 'file': str}`` dicts.

    Raises:
        OSError: if a category folder is missing or unreadable.
        KeyError: if a name in *methods* is not in VISUALIZATION_DIRS.
    """
    common_samples = []
    for category in ("TP", "TN", "FP", "FN"):
        files = set(os.listdir(os.path.join(VISUALIZATION_DIRS["No-XAI"], category)))
        for method in methods:
            files &= set(os.listdir(os.path.join(VISUALIZATION_DIRS[method], category)))
        common_samples.extend({'category': category, 'file': name} for name in sorted(files))

    logger.info(f"Found {len(common_samples)} common samples across all methods")
    return common_samples

def select_balanced_samples(samples, per_group=5):
    """Pick a class-balanced, shuffled subset of *samples*.

    ``per_group`` samples are drawn from the TP+FP group and ``per_group``
    from the TN+FN group, then shuffled together. All randomness uses the
    module-level ``random`` functions, so callers control reproducibility with
    ``random.seed``.

    If either group has fewer than ``per_group`` samples, balance is dropped.
    All samples are returned when there are at most ``2 * per_group``;
    otherwise a random ``2 * per_group`` of them are returned.

    Args:
        samples: list of dicts with a ``'category'`` key (TP/TN/FP/FN).
        per_group: samples per group. The default of 5 selects 10 samples.

    Returns:
        The selected list, or ``[]`` if an unexpected error occurs (logged).
    """
    try:
        tp_fp_samples = [s for s in samples if s['category'] in ['TP', 'FP']]
        tn_fn_samples = [s for s in samples if s['category'] in ['TN', 'FN']]
        total = 2 * per_group

        if len(tp_fp_samples) < per_group or len(tn_fn_samples) < per_group:
            logger.warning(f"Not enough samples in each category. TP+FP: {len(tp_fp_samples)}, TN+FN: {len(tn_fn_samples)}")
            return samples if len(samples) <= total else random.sample(samples, total)

        # TP+FP is drawn before TN+FN. Keep this order so an existing seed keeps
        # selecting the same samples.
        selected_samples = random.sample(tp_fp_samples, per_group) + random.sample(tn_fn_samples, per_group)
        random.shuffle(selected_samples)

        logger.info(f"Selected {len(selected_samples)} balanced samples: {per_group} from TP+FP, {per_group} from TN+FN")
        return selected_samples
    except Exception:
        logger.exception("Error selecting balanced samples")
        return []

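# Superseded: '/' now redirects to the consent page (see root below), and the
# introduction page is served from '/introduction'.
# @app.route('/')
# def introduction():
#     return render_template('introduction.html')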

@app.route('/attribution')
def attribution():
    """Render the attribution page."""
    page = 'attribution.html'
    return render_template(page)

@app.route('/index', methods=['GET', 'POST'])
def index():
    """Show the study set-up form and, on POST, start a new session.

    A valid POST (username, integer seed and a supported method) seeds the
    RNG, selects the samples, writes a new session file and redirects to the
    method's explanation page. Invalid input re-renders the form with an error.
    """
    if request.method != 'POST':
        return render_template('index.html', show_no_xai=False)

    username = request.form.get('username')
    seed = request.form.get('seed')
    method = request.form.get('method')
    if not username or not seed or not method:
        return render_template('index.html', error="Please fill in all fields and select a method.")
    # 'No-XAI' is not accepted here, matching the form rendered with show_no_xai=False.
    if method not in ['Chain-of-Table', 'Plan-of-SQLs', 'Dater', 'Text2SQL']:
        return render_template('index.html', error="Invalid method selected.")

    # Report a non-numeric seed precisely instead of falling through to the
    # generic error message below.
    try:
        seed = int(seed)
    except ValueError:
        return render_template('index.html', error="Seed must be a whole number.")

    try:
        # NOTE(review): this seeds the process-wide RNG. Concurrent requests could
        # interleave draws and break reproducibility; consider random.Random(seed).
        random.seed(seed)
        all_samples = load_samples()
        selected_samples = select_balanced_samples(all_samples)
        if len(selected_samples) == 0:
            return render_template('index.html', error="No common samples were found")
        start_time = datetime.now().isoformat()
        session_id = generate_session_id()
        session_data = {
            'username': username,
            'seed': str(seed),
            'method': method,
            'selected_samples': selected_samples,
            'current_index': 0,
            'responses': [],
            'start_time': start_time,
            'session_id': session_id
        }
        save_session_data(session_id, session_data)
        logger.info(f"Session data stored for user {username}, method {method}, session_id {session_id}")

        # Every method starts with its explanation/introduction page.
        return redirect(url_for('explanation', session_id=session_id))
    except Exception as e:
        logger.exception(f"Error in index route: {e}")
        return render_template('index.html', error="An error occurred. Please try again.")

@app.route('/explanation/<session_id>')
def explanation(session_id):
    """Render the method-specific introduction page for a session.

    Unknown sessions, or sessions without a recognised method, are logged and
    sent back to the index page.
    """
    intro_templates = {
        'Chain-of-Table': 'cot_intro.html',
        'Plan-of-SQLs': 'pos_intro.html',
        'Dater': 'dater_intro.html',
        'Text2SQL': 'text2sql_intro.html',
    }

    record = load_session_data(session_id)
    if not record:
        logger.error(f"No session data found for session ID: {session_id}")
        return redirect(url_for('index'))

    chosen = record.get('method')
    if not chosen:
        logger.error(f"No method found in session data for session ID: {session_id}")
        return redirect(url_for('index'))

    template = intro_templates.get(chosen)
    if template is None:
        logger.error(f"Invalid method '{chosen}' for session ID: {session_id}")
        return redirect(url_for('index'))
    return render_template(template, session_id=session_id)

@app.route('/experiment/<session_id>', methods=['GET', 'POST'])
def experiment(session_id):
    """Render the participant's current sample, or go to the results once all are answered.

    Any unexpected error is logged and answered with a plain 500 response.
    """
    try:
        record = load_session_data(session_id)
        if not record:
            return redirect(url_for('index'))

        samples = record['selected_samples']
        method = record['method']
        position = record['current_index']

        if position >= len(samples):
            return redirect(url_for('completed', session_id=session_id))

        current = samples[position]
        folder = VISUALIZATION_DIRS[method]
        relative_path = f"{folder}/{current['category']}/{current['file']}"

        note = """
Please note that in select row function, starting index is 0 for Chain-of-Table and 1 for Dater and Index * represents the selection for all rows.
        """
        visualization_url = url_for('send_visualization', filename=relative_path)

        return render_template('experiment.html',
                               sample_id=position,
                               statement=note,
                               visualization=visualization_url,
                               session_id=session_id,
                               method=method)
    except Exception as e:
        logger.exception(f"An error occurred in the experiment route: {e}")
        return "An error occurred", 500

@app.route('/')
def root():
    """Send visitors arriving at the site root to the consent form first."""
    consent_url = url_for('consent')
    return redirect(consent_url)

@app.route('/consent', methods=['GET', 'POST'])
def consent():
    """Show the consent form; a POST moves the participant on to the introduction."""
    if request.method != 'POST':
        return render_template('consent.html')
    # Submitting the form is treated as agreement.
    return redirect(url_for('introduction'))

@app.route('/introduction')
def introduction():
    """Render the study introduction page."""
    template_name = 'introduction.html'
    return render_template(template_name)

@app.route('/subjective/<session_id>', methods=['GET', 'POST'])
def subjective(session_id):
    """Collect the participant's post-study 'understanding' answer.

    GET renders the questionnaire. POST stores the ``understanding`` form field
    in the session file as ``subjective_feedback`` and continues to the results.
    """
    if request.method != 'POST':
        return render_template('subjective.html', session_id=session_id)

    answer = request.form.get('understanding')

    record = load_session_data(session_id)
    if not record:
        logger.error(f"No session data found for session: {session_id}")
        return redirect(url_for('index'))

    record['subjective_feedback'] = answer
    save_session_data(session_id, record)
    return redirect(url_for('completed', session_id=session_id))

@app.route('/feedback', methods=['POST'])
def feedback():
    """Record the participant's prediction for the current sample and move on.

    Expects form fields ``session_id`` and ``prediction``. Redirects to the
    next sample, or to the subjective questionnaire after the last one.
    Returns 400 when a field is missing and 500 on unexpected errors.
    """
    try:
        session_id = request.form.get('session_id')
        prediction = request.form.get('prediction')
        if session_id is None or prediction is None:
            logger.error("Feedback request is missing 'session_id' or 'prediction'")
            return "Missing session_id or prediction", 400

        session_data = load_session_data(session_id)
        if not session_data:
            logger.error(f"No session data found for session: {session_id}")
            return redirect(url_for('index'))

        current_index = session_data['current_index']
        total_samples = len(session_data['selected_samples'])

        if current_index >= total_samples:
            # Every sample already has an answer (e.g. the experiment form was
            # resubmitted via the browser's Back button). Recording it would add
            # a response whose sample_id is past the end of selected_samples,
            # which the completed view cannot score.
            logger.warning(f"Ignoring extra prediction for completed session {session_id}")
            return redirect(url_for('subjective', session_id=session_id))

        # NOTE(review): only session_id and prediction are read from the form,
        # so the answer is always attributed to current_index. A resubmitted
        # earlier page would be recorded against the wrong sample.
        session_data['responses'].append({
            'sample_id': current_index,
            'user_prediction': prediction,
        })
        session_data['current_index'] = current_index + 1
        save_session_data(session_id, session_data)
        logger.info(f"Prediction saved for session {session_id}, sample {current_index}")

        if session_data['current_index'] >= total_samples:
            return redirect(url_for('subjective', session_id=session_id))
        return redirect(url_for('experiment', session_id=session_id))
    except Exception as e:
        logger.exception(f"Error in feedback route: {e}")
        return "An error occurred", 500

# Results page for a finished session (any method, including Text2SQL).
@app.route('/completed/<session_id>')
def completed(session_id):
    """Score a finished session, upload it to Hugging Face and render the summary.

    Accuracy is the percentage of responses whose prediction equals the model
    prediction stored in the method's ground-truth JSON file (a forward
    simulation study). TRUE/FALSE answers are counted only for responses that
    have a ground-truth entry, but all three percentages divide by the total
    number of responses. Responses that cannot be matched count as incorrect.

    Returns a redirect to the index for unknown sessions, 400 for an unknown
    method and 500 on unexpected errors.
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            logger.error(f"No session data found for session: {session_id}")
            return redirect(url_for('index'))

        session_data['end_time'] = datetime.now().isoformat()
        responses = session_data['responses']
        method = session_data['method']
        selected_samples = session_data.get('selected_samples', [])

        # Ground-truth files use the same short method code as the ground-truth
        # keys, e.g. "Plan-of-SQLs" -> Tabular_LLMs_human_study_vis_6_POS.json.
        method_dir = get_method_dir(method)
        if method_dir is None:
            return "Invalid method", 400
        json_file = f'Tabular_LLMs_human_study_vis_6_{method_dir}.json'

        with open(json_file, 'r', encoding='utf-8') as f:
            ground_truth = json.load(f)

        correct_predictions = 0
        true_predictions = 0
        false_predictions = 0

        for response in responses:
            sample_id = response['sample_id']
            user_prediction = response['user_prediction']

            # A stray response can point past the selected samples. Skip it
            # rather than failing the whole page with an IndexError, which would
            # also prevent the session from being uploaded.
            if not 0 <= sample_id < len(selected_samples):
                logger.warning(f"Response references unknown sample_id {sample_id} in session {session_id}; skipping")
                continue

            visualization_file = selected_samples[sample_id]['file']
            # The parsing expects names like "<prefix>-<index>.<ext>".
            name_parts = visualization_file.split('-')
            if len(name_parts) < 2:
                logger.warning(f"Cannot extract sample index from file name: {visualization_file}")
                continue
            sample_index = name_parts[1].split('.')[0]

            ground_truth_key = f"{method_dir}_test-{sample_index}.html"
            logger.info(f"ground_truth_key: {ground_truth_key}")

            if ground_truth_key not in ground_truth:
                logger.warning(f"Missing key in ground truth: {ground_truth_key}")
                continue

            # TODO: Important Note ->
            # Using model prediction as we are doing forward simulation
            # Please use ground_truth[ground_truth_key]['answer'].upper() if running verification task
            model_prediction = ground_truth[ground_truth_key]['prediction'].upper()
            normalized = user_prediction.upper()
            if normalized == model_prediction:
                correct_predictions += 1

            if normalized == "TRUE":
                true_predictions += 1
            elif normalized == "FALSE":
                false_predictions += 1

        total = len(responses)
        accuracy = round((correct_predictions / total) * 100, 2) if total else 0
        true_percentage = round((true_predictions / total) * 100, 2) if total else 0
        false_percentage = round((false_predictions / total) * 100, 2) if total else 0

        session_data['accuracy'] = accuracy
        session_data['true_percentage'] = true_percentage
        session_data['false_percentage'] = false_percentage

        # Save all the data to Hugging Face at the end.
        # NOTE(review): save_session_data_to_hf logs upload errors instead of
        # raising, so the local file below is removed even if the upload failed.
        save_session_data_to_hf(session_id, session_data)

        try:
            os.remove(os.path.join(SESSION_DIR, f'{session_id}.json'))
        except FileNotFoundError:
            # Presumably a concurrent request (e.g. a refresh) already finished it.
            logger.warning(f"Session file for {session_id} was already removed")

        return render_template('completed.html',
                               accuracy=accuracy,
                               true_percentage=true_percentage,
                               false_percentage=false_percentage)
    except Exception as e:
        logger.exception(f"An error occurred in the completed route: {e}")
        return "An error occurred", 500

@app.route('/visualizations/<path:filename>')
def send_visualization(filename):
    """Serve a file addressed relative to the current working directory.

    Args:
        filename: path taken from the URL (untrusted).

    Returns:
        The file response; 403 for paths outside the working directory or for
        log/Python files; 404 when no such file exists.
    """
    logger.info(f"Attempting to serve file: {filename}")
    base_dir = os.getcwd()
    file_path = os.path.normpath(os.path.join(base_dir, filename))

    # Compare whole path components. A plain startswith(base_dir) check also
    # accepts siblings such as "<cwd>-backup/..." reached through "../".
    try:
        inside_base = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:  # e.g. paths on different drives on Windows
        inside_base = False
    if not inside_base:
        return "Access denied", 403

    # The working directory may also hold app.log (the logging FileHandler uses
    # a relative path, and session IDs are logged) and Python sources. Never
    # serve those.
    # NOTE(review): completed() opens the ground-truth JSON files by relative
    # path, so they are likely servable here too, exposing the model predictions
    # participants are asked to simulate. Consider restricting this route to
    # the visualization folders.
    if file_path.lower().endswith(('.log', '.py', '.pyc')):
        return "Access denied", 403

    if not os.path.isfile(file_path):
        return "File not found", 404

    directory, file_name = os.path.split(file_path)
    logger.info(f"Serving file from directory: {directory}, filename: {file_name}")
    return send_from_directory(directory, file_name)

@app.route('/visualizations/<path:filename>')
def send_examples(filename):
    """Alias endpoint kept so ``url_for('send_examples', ...)`` keeps working.

    Its URL rule is identical to send_visualization's, which is registered
    earlier, so requests to /visualizations/... are expected to be dispatched
    there. Delegating keeps both endpoints on the same path checks.
    """
    return send_visualization(filename)

if __name__ == "__main__":
    # Werkzeug's interactive debugger can execute arbitrary code, so it must
    # not be on by default while listening on all interfaces. Opt in for local
    # development with FLASK_DEBUG=1.
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=7860, debug=debug_mode)