# Author: luulinh90s
# Commit: ce41c1d ("update")
import json
import logging
import os
import random
import secrets
import string
import tempfile
import uuid
from datetime import datetime

from flask import Flask, render_template, request, redirect, url_for, send_from_directory
from huggingface_hub import login, HfApi, hf_hub_download
# Configure root logging once at import time: records at INFO and above go
# both to "app.log" in the working directory and to the console (stderr).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment
# variable. The HfApi() upload in save_session_data_to_hf presumably relies
# on the token cached by this login — without it, uploads are expected to fail.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    logger.error("HF_TOKEN not found in environment variables")
else:
    login(token=hf_token)
app = Flask(__name__)
# Never ship a hard-coded secret key: read it from the environment and fall
# back to a random per-process key. With the fallback, anything Flask signs
# with the key (e.g. cookies) does not survive a restart and is not shared
# between worker processes — set SECRET_KEY explicitly if that matters.
_secret_key = os.environ.get("SECRET_KEY")
if not _secret_key:
    logger.warning("SECRET_KEY not set; using a random per-process secret key")
    _secret_key = secrets.token_hex(32)
app.config['SECRET_KEY'] = _secret_key
# File-based session storage: one "<session_id>.json" file per participant
# session. The location can be overridden with the SESSION_DIR env variable.
SESSION_DIR = os.environ.get('SESSION_DIR', '/tmp/sessions')
os.makedirs(SESSION_DIR, exist_ok=True)
# Directory (relative to the working directory) holding the rendered HTML
# visualizations of each explanation method, keyed by the method's display
# name. Each directory is expected to contain TP/TN/FP/FN sub-directories.
VISUALIZATION_DIRS = dict([
    ("No-XAI", "htmls_NO_XAI_mod"),
    ("Dater", "htmls_DATER_mod2"),
    ("Chain-of-Table", "htmls_COT_mod"),
    ("Plan-of-SQLs", "htmls_POS_mod2"),
    ("Text2SQL", "htmls_Text2SQL"),
])
# (display name, short code) pairs for every explanation method. The short
# code prefixes metadata keys ("<CODE>_test-<n>.html") and also appears as
# the suffix of the Tabular_LLMs_human_study_vis_6_<CODE>.json file names.
_METHOD_DIR_CODES = (
    ('No-XAI', 'NO_XAI'),
    ('Dater', 'DATER'),
    ('Chain-of-Table', 'COT'),
    ('Plan-of-SQLs', 'POS'),
    ('Text2SQL', 'Text2SQL'),
)


def get_method_dir(method):
    """Return the short code for a method's display name.

    Args:
        method: display name such as 'Chain-of-Table'.

    Returns:
        The short code (e.g. 'COT'), or None for an unknown method.
    """
    return next((code for name, code in _METHOD_DIR_CODES if method == name), None)
# Display names of every explanation method offered in the study, in the
# same order as the keys of VISUALIZATION_DIRS.
METHODS = [
    "No-XAI",
    "Dater",
    "Chain-of-Table",
    "Plan-of-SQLs",
    "Text2SQL",
]
def generate_session_id():
    """Return a fresh random identifier for a participant session.

    The value is the canonical 36-character string form of a version-4 UUID.
    """
    new_id = uuid.uuid4()
    return f"{new_id}"
def save_session_data(session_id, data):
file_path = os.path.join(SESSION_DIR, f'{session_id}.json')
with open(file_path, 'w') as f:
json.dump(data, f)
logger.info(f"Session data saved for session {session_id}")
def load_session_data(session_id):
file_path = os.path.join(SESSION_DIR, f'{session_id}.json')
if os.path.exists(file_path):
with open(file_path, 'r') as f:
return json.load(f)
return None
def save_session_data_to_hf(session_id, data):
try:
username = data.get('username', 'unknown')
seed = data.get('seed', 'unknown')
start_time = data.get('start_time', datetime.now().isoformat())
file_name = f'{username}_seed{seed}_{start_time}_{session_id}_session.json'
file_name = "".join(c for c in file_name if c.isalnum() or c in ['_', '-', '.'])
json_data = json.dumps(data, indent=4)
temp_file_path = f"/tmp/{file_name}"
with open(temp_file_path, 'w') as f:
f.write(json_data)
api = HfApi()
repo_path = "session_data_foward_simulation"
api.upload_file(
path_or_fileobj=temp_file_path,
path_in_repo=f"{repo_path}/{file_name}",
repo_id="luulinh90s/Tabular-LLM-Study-Data",
repo_type="space",
)
os.remove(temp_file_path)
logger.info(f"Session data saved for session {session_id} in Hugging Face Data Space")
except Exception as e:
logger.exception(f"Error saving session data for session {session_id}: {e}")
#
# def load_samples():
# common_samples = []
# categories = ["TP", "TN", "FP", "FN"]
#
# for category in categories:
# files = set(os.listdir(f'htmls_NO_XAI_mod/{category}'))
# for method in ["Dater", "Chain-of-Table", "Plan-of-SQLs", "Text2SQL"]:
# method_dir = VISUALIZATION_DIRS[method]
# files &= set(os.listdir(f'{method_dir}/{category}'))
#
# for file in files:
# common_samples.append({'category': category, 'file': file})
#
# logger.info(f"Found {len(common_samples)} common samples across all methods")
# return common_samples
def load_samples(method, metadata):
    """Collect every visualization file available for *method*.

    Scans ``<VISUALIZATION_DIRS[method]>/<category>`` for each outcome
    category (TP, TN, FP, FN). Each file is paired with its entry in
    *metadata*, looked up under ``"<get_method_dir(method)>_test-<index>.html"``.
    ``<index>`` is the part of the file name between the first '-' and the
    next '.'.

    Args:
        method: display name of the method; must be a key of VISUALIZATION_DIRS.
        metadata: mapping loaded from the method's metadata JSON file.

    Returns:
        A list of ``{'category', 'file', 'metadata'}`` dicts in a
        deterministic order. ``'metadata'`` is ``{}`` when the file has no entry.

    Raises:
        KeyError: if *method* is not in VISUALIZATION_DIRS.
        OSError: if a category directory cannot be listed.
    """
    method_dir = VISUALIZATION_DIRS[method]
    key_prefix = get_method_dir(method)
    samples = []
    for category in ("TP", "TN", "FP", "FN"):
        # Sort the listing: os.listdir order is arbitrary, and iterating a set
        # of strings varies with hash randomization. Either would make the
        # seeded random selection done by the caller non-reproducible.
        for file in sorted(os.listdir(f'{method_dir}/{category}')):
            name_parts = file.split('-')
            if len(name_parts) < 2:
                # e.g. stray files such as .DS_Store; indexing would raise.
                logger.warning(f"Skipping unexpected file name in {method_dir}/{category}: {file}")
                continue
            index = name_parts[1].split('.')[0]
            samples.append({
                'category': category,
                'file': file,
                'metadata': metadata.get(f"{key_prefix}_test-{index}.html", {}),
            })
    logger.info(f"Found {len(samples)} samples for method {method}")
    return samples
def select_balanced_samples(samples, per_group=5):
    """Pick a class-balanced random subset of *samples*.

    Draws *per_group* samples from the TP+FP group and *per_group* from the
    TN+FN group, then shuffles them together. It uses the module-level
    ``random`` generator, so the result depends on the global random state
    (index() seeds it from the participant's seed).

    If either group has fewer than *per_group* samples, it falls back to
    returning all *samples* when there are at most ``2 * per_group`` of them,
    and otherwise a uniform random subset of that size.

    Args:
        samples: list of dicts with at least a ``'category'`` key.
        per_group: number of samples to draw from each group (default 5).

    Returns:
        The selected samples, or ``[]`` if an unexpected error occurs (logged).
    """
    total = 2 * per_group
    try:
        tp_fp_samples = [s for s in samples if s['category'] in ('TP', 'FP')]
        tn_fn_samples = [s for s in samples if s['category'] in ('TN', 'FN')]
        if len(tp_fp_samples) < per_group or len(tn_fn_samples) < per_group:
            logger.warning(f"Not enough samples in each category. TP+FP: {len(tp_fp_samples)}, TN+FN: {len(tn_fn_samples)}")
            return samples if len(samples) <= total else random.sample(samples, total)
        # Same RNG call order as before (sample, sample, shuffle), so a given
        # seed still yields the same selection.
        selected_samples = random.sample(tp_fp_samples, per_group) + random.sample(tn_fn_samples, per_group)
        random.shuffle(selected_samples)
        logger.info(f"Selected {total} balanced samples: {per_group} from TP+FP, {per_group} from TN+FN")
        return selected_samples
    except Exception:
        logger.exception("Error selecting balanced samples")
        return []
# @app.route('/')
# def introduction():
# return render_template('introduction.html')
@app.route('/attribution')
def attribution():
    """Render the static attribution page."""
    template_name = 'attribution.html'
    return render_template(template_name)
#
# @app.route('/index', methods=['GET', 'POST'])
# def index():
# if request.method == 'POST':
# username = request.form.get('username')
# seed = request.form.get('seed')
# method = request.form.get('method')
# if not username or not seed or not method:
# return render_template('index.html', error="Please fill in all fields and select a method.")
# if method not in ['Chain-of-Table', 'Plan-of-SQLs', 'Dater', 'Text2SQL']:
# return render_template('index.html', error="Invalid method selected.")
# try:
# seed = int(seed)
# random.seed(seed)
# all_samples = load_samples()
# selected_samples = select_balanced_samples(all_samples)
# if len(selected_samples) == 0:
# return render_template('index.html', error="No common samples were found")
# start_time = datetime.now().isoformat()
# session_id = generate_session_id()
# session_data = {
# 'username': username,
# 'seed': str(seed),
# 'method': method,
# 'selected_samples': selected_samples,
# 'current_index': 0,
# 'responses': [],
# 'start_time': start_time,
# 'session_id': session_id
# }
# save_session_data(session_id, session_data)
# logger.info(f"Session data stored for user {username}, method {method}, session_id {session_id}")
#
# # Redirect to explanation for all methods
# return redirect(url_for('explanation', session_id=session_id))
# except Exception as e:
# logger.exception(f"Error in index route: {e}")
# return render_template('index.html', error="An error occurred. Please try again.")
# return render_template('index.html', show_no_xai=False)
@app.route('/index', methods=['GET', 'POST'])
def index():
    """Show the sign-up form (GET) or start a new study session (POST).

    On POST, the form fields ``username``, ``seed`` and ``method`` are
    validated. The function then:

    - seeds the global ``random`` generator with ``seed``;
    - loads the method's metadata JSON;
    - selects a balanced set of samples;
    - creates a new file-backed session;
    - redirects to the method's explanation page.

    Validation problems and unexpected errors re-render ``index.html`` with
    an ``error`` message.
    """
    if request.method != 'POST':
        return render_template('index.html', show_no_xai=False)

    username = request.form.get('username')
    seed = request.form.get('seed')
    method = request.form.get('method')
    if not username or not seed or not method:
        return render_template('index.html', error="Please fill in all fields and select a method.")
    # Validate against the module-level METHODS list instead of a duplicated literal.
    if method not in METHODS:
        return render_template('index.html', error="Invalid method selected.")
    try:
        seed = int(seed)
    except ValueError:
        return render_template('index.html', error="Seed must be an integer.")

    try:
        random.seed(seed)
        # Metadata files are named after the method's short code
        # (e.g. Chain-of-Table -> ..._COT.json), as returned by get_method_dir.
        json_file = f'Tabular_LLMs_human_study_vis_6_{get_method_dir(method)}.json'
        with open(json_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        all_samples = load_samples(method, metadata)
        selected_samples = select_balanced_samples(all_samples)
        if not selected_samples:
            return render_template('index.html', error="No common samples were found")
        start_time = datetime.now().isoformat()
        session_id = generate_session_id()
        session_data = {
            'username': username,
            'seed': str(seed),
            'method': method,
            'selected_samples': selected_samples,
            'current_index': 0,
            'responses': [],
            'start_time': start_time,
            'session_id': session_id
        }
        save_session_data(session_id, session_data)
        logger.info(f"Session data stored for user {username}, method {method}, session_id {session_id}")
        # Redirect to explanation for all methods
        return redirect(url_for('explanation', session_id=session_id))
    except Exception as e:
        logger.exception(f"Error in index route: {e}")
        return render_template('index.html', error="An error occurred. Please try again.")
@app.route('/explanation/<session_id>')
def explanation(session_id):
    """Show the method-specific introduction page for a session.

    Unknown sessions, sessions without a method, and unknown methods are
    logged and redirected to /index. No-XAI has no explanation to introduce,
    so it goes straight to the experiment.
    """
    session_data = load_session_data(session_id)
    if not session_data:
        logger.error(f"No session data found for session ID: {session_id}")
        return redirect(url_for('index'))
    method = session_data.get('method')
    if not method:
        logger.error(f"No method found in session data for session ID: {session_id}")
        return redirect(url_for('index'))
    if method == 'No-XAI':
        # index() accepts No-XAI and redirects every method here. Without this
        # branch those participants would hit the invalid-method branch and be
        # bounced back to /index indefinitely.
        # NOTE(review): if a dedicated No-XAI intro template exists, render it here.
        return redirect(url_for('experiment', session_id=session_id))
    intro_templates = {
        'Chain-of-Table': 'cot_intro.html',
        'Plan-of-SQLs': 'pos_intro.html',
        'Dater': 'dater_intro.html',
        'Text2SQL': 'text2sql_intro.html',
    }
    template = intro_templates.get(method)
    if template is None:
        logger.error(f"Invalid method '{method}' for session ID: {session_id}")
        return redirect(url_for('index'))
    return render_template(template, session_id=session_id)
#
# @app.route('/experiment/<session_id>', methods=['GET', 'POST'])
# def experiment(session_id):
# try:
# session_data = load_session_data(session_id)
# if not session_data:
# return redirect(url_for('index'))
#
# selected_samples = session_data['selected_samples']
#
# method = session_data['method']
# current_index = session_data['current_index']
#
# if current_index >= len(selected_samples):
# return redirect(url_for('completed', session_id=session_id))
#
# sample = selected_samples[current_index]
# visualization_dir = VISUALIZATION_DIRS[method]
# visualization_path = f"{visualization_dir}/{sample['category']}/{sample['file']}"
#
# statement = """
# Please note that in select row function, starting index is 0 for Chain-of-Table and 1 for Dater and Index * represents the selection for all rows.
# """
#
# return render_template('experiment.html',
# sample_id=current_index,
# statement=statement,
# visualization=url_for('send_visualization', filename=visualization_path),
# session_id=session_id,
# method=method)
# except Exception as e:
# logger.exception(f"An error occurred in the experiment route: {e}")
# return "An error occurred", 500
@app.route('/experiment/<session_id>', methods=['GET', 'POST'])
def experiment(session_id):
    """Render the current sample of a session for the participant to judge.

    Unknown sessions redirect to /index. Once every selected sample has been
    answered, the participant is redirected to the completed page. Unexpected
    errors are logged and answered with a plain 500.
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            return redirect(url_for('index'))
        selected_samples = session_data['selected_samples']
        method = session_data['method']
        current_index = session_data['current_index']
        if current_index >= len(selected_samples):
            return redirect(url_for('completed', session_id=session_id))
        sample = selected_samples[current_index]
        visualization_dir = VISUALIZATION_DIRS[method]
        visualization_path = f"{visualization_dir}/{sample['category']}/{sample['file']}"
        # Metadata attached by load_samples ({} when the file had no entry).
        metadata = sample.get('metadata', {})
        logger.info(f"Sample metadata for session {session_id}, method {method}, index {current_index}: {metadata}")
        if method == 'Text2SQL':
            # Text2SQL samples are rendered with an empty statement.
            # NOTE(review): presumably the statement is shown inside the
            # visualization itself — confirm.
            statement = ""
        else:
            statement = metadata.get('statement')
            if statement is None:
                # Previously a KeyError turned this into a 500 for the participant.
                logger.warning(f"No statement in metadata for session {session_id}, method {method}, index {current_index}")
                statement = ""
        return render_template('experiment.html',
                               sample_id=current_index,
                               statement=statement,
                               visualization=url_for('send_visualization', filename=visualization_path),
                               session_id=session_id,
                               method=method,
                               metadata=metadata)  # Pass metadata to the template
    except Exception as e:
        logger.exception(f"An error occurred in the experiment route: {e}")
        return "An error occurred", 500
@app.route('/')
def root():
    """Send visitors landing on the site root to the consent form."""
    consent_url = url_for('consent')
    return redirect(consent_url)
@app.route('/consent', methods=['GET', 'POST'])
def consent():
    """Show the consent form; submitting it (POST) means the user agreed."""
    if request.method != 'POST':
        return render_template('consent.html')
    # Any POST counts as agreement; continue to the study introduction.
    return redirect(url_for('introduction'))
@app.route('/introduction')
def introduction():
    """Render the study introduction page (reached after giving consent)."""
    page = 'introduction.html'
    return render_template(page)
@app.route('/subjective/<session_id>', methods=['GET', 'POST'])
def subjective(session_id):
    """Collect the participant's subjective feedback after the last sample.

    GET renders the feedback form. POST stores the ``understanding`` form
    field under ``subjective_feedback`` in the session and continues to the
    completed page. Unknown sessions are logged and redirected to /index.
    """
    if request.method != 'POST':
        return render_template('subjective.html', session_id=session_id)
    session_data = load_session_data(session_id)
    if not session_data:
        logger.error(f"No session data found for session: {session_id}")
        return redirect(url_for('index'))
    session_data['subjective_feedback'] = request.form.get('understanding')
    save_session_data(session_id, session_data)
    return redirect(url_for('completed', session_id=session_id))
@app.route('/feedback', methods=['POST'])
def feedback():
    """Record the participant's prediction for the current sample and advance.

    Expects the form fields ``session_id`` and ``prediction``. After the last
    sample, redirects to the subjective-feedback page; otherwise redirects to
    the next sample.

    Returns 400 when a form field is missing and 500 on unexpected errors.
    """
    try:
        session_id = request.form.get('session_id')
        prediction = request.form.get('prediction')
        if session_id is None or prediction is None:
            logger.error("Feedback submitted without session_id or prediction")
            return "Missing session_id or prediction", 400
        session_data = load_session_data(session_id)
        if not session_data:
            logger.error(f"No session data found for session: {session_id}")
            return redirect(url_for('index'))
        total = len(session_data['selected_samples'])
        current_index = session_data['current_index']
        if current_index >= total:
            # Every sample is already answered (e.g. the last form was
            # re-submitted via the back button). Recording it would store a
            # sample_id with no matching sample and break scoring in completed().
            logger.warning(f"Ignoring extra prediction for session {session_id}: all {total} samples answered")
            return redirect(url_for('subjective', session_id=session_id))
        session_data['responses'].append({
            'sample_id': current_index,
            'user_prediction': prediction
        })
        session_data['current_index'] = current_index + 1
        save_session_data(session_id, session_data)
        logger.info(f"Prediction saved for session {session_id}, sample {current_index}")
        if current_index + 1 >= total:
            return redirect(url_for('subjective', session_id=session_id))
        return redirect(url_for('experiment', session_id=session_id))
    except Exception as e:
        logger.exception(f"Error in feedback route: {e}")
        return "An error occurred", 500
@app.route('/completed/<session_id>')
def completed(session_id):
    """Score a finished session, upload it to the Hub and show the results.

    Each response is compared with the model's own prediction (forward
    simulation). This yields the accuracy and the percentage of TRUE / FALSE
    answers, each rounded to two decimals.

    The scores and an ``end_time`` are added to the session, which is then
    uploaded with save_session_data_to_hf. The local session file is removed
    afterwards unless the upload explicitly reported failure.

    Unknown sessions redirect to /index; an unknown method yields 400 and
    unexpected errors yield 500.
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            logger.error(f"No session data found for session: {session_id}")
            return redirect(url_for('index'))
        session_data['end_time'] = datetime.now().isoformat()
        responses = session_data['responses']
        selected_samples = session_data['selected_samples']
        method = session_data['method']
        method_code = get_method_dir(method)
        if method_code is None:
            return "Invalid method", 400
        # Metadata files are named after the method's short code (e.g. ..._COT.json).
        json_file = f'Tabular_LLMs_human_study_vis_6_{method_code}.json'
        with open(json_file, 'r', encoding='utf-8') as f:
            ground_truth = json.load(f)

        correct_predictions = 0
        true_predictions = 0
        false_predictions = 0
        for response in responses:
            sample_id = response['sample_id']
            if not 0 <= sample_id < len(selected_samples):
                # Guards against stray responses (e.g. a re-submitted last form).
                logger.warning(f"Response refers to unknown sample {sample_id} in session {session_id}")
                continue
            user_prediction = response['user_prediction'].upper()
            visualization_file = selected_samples[sample_id]['file']
            index = visualization_file.split('-')[1].split('.')[0]
            ground_truth_key = f"{method_code}_test-{index}.html"
            logger.info(f"ground_truth_key: {ground_truth_key}")
            if ground_truth_key not in ground_truth:
                logger.warning(f"Missing key in ground truth: {ground_truth_key}")
                continue
            # TODO: Important Note ->
            # Using model prediction as we are doing forward simulation
            # Please use ground_truth[ground_truth_key]['answer'].upper() if running verification task
            model_prediction = ground_truth[ground_truth_key]['prediction'].upper()
            if user_prediction == model_prediction:
                correct_predictions += 1
            if user_prediction == "TRUE":
                true_predictions += 1
            elif user_prediction == "FALSE":
                false_predictions += 1

        # Skipped responses still count in the denominator, i.e. as incorrect.
        total = len(responses)
        accuracy = round((correct_predictions / total) * 100, 2) if total else 0
        true_percentage = round((true_predictions / total) * 100, 2) if total else 0
        false_percentage = round((false_predictions / total) * 100, 2) if total else 0
        session_data['accuracy'] = accuracy
        session_data['true_percentage'] = true_percentage
        session_data['false_percentage'] = false_percentage

        # Save all the data to Hugging Face at the end. Only an explicit False
        # signals a failed upload. In that case keep the local file, so the data
        # is not lost and reloading this page retries the upload.
        uploaded = save_session_data_to_hf(session_id, session_data)
        if uploaded is not False:
            try:
                os.remove(os.path.join(SESSION_DIR, f'{session_id}.json'))
            except FileNotFoundError:
                pass  # already removed by a concurrent request for this page
        else:
            logger.warning(f"Keeping local session file for {session_id}: upload failed")
        return render_template('completed.html',
                               accuracy=accuracy,
                               true_percentage=true_percentage,
                               false_percentage=false_percentage)
    except Exception as e:
        logger.exception(f"An error occurred in the completed route: {e}")
        return "An error occurred", 500
@app.route('/visualizations/<path:filename>')
def send_visualization(filename):
    """Serve a file addressed relative to the current working directory.

    Returns 403 for paths that resolve outside the working directory and 404
    when no regular file exists at the resolved path.
    """
    logger.info(f"Attempting to serve file: {filename}")
    base_dir = os.getcwd()
    file_path = os.path.normpath(os.path.join(base_dir, filename))
    # A plain startswith() prefix test also accepts sibling directories such
    # as "<cwd>2/..." reached via "../". commonpath compares whole path
    # components instead.
    try:
        inside_base = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_base = False
    if not inside_base:
        return "Access denied", 403
    if not os.path.isfile(file_path):
        return "File not found", 404
    directory = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    logger.info(f"Serving file from directory: {directory}, filename: {file_name}")
    return send_from_directory(directory, file_name)
@app.route('/visualizations/<path:filename>')
def send_examples(filename):
    """Serve *filename* via send_from_directory with an empty ('') base directory.

    NOTE(review): this registers the same URL rule as send_visualization,
    which is defined earlier. Werkzeug should match the earlier rule, so this
    view is likely unreachable by URL, and url_for('send_examples', ...) builds
    a URL that send_visualization answers. Confirm, then give it its own path
    or remove it.
    """
    base_directory = ''
    return send_from_directory(base_directory, filename)
if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must not be on by default for a server bound to all interfaces.
    # Opt in with FLASK_DEBUG=1 for local development. PORT overrides 7860.
    app.run(host="0.0.0.0",
            port=int(os.environ.get("PORT", "7860")),
            debug=os.environ.get("FLASK_DEBUG") == "1")