# Repository page header captured with this file (not code):
# SalazarPevelll / be / f291f4a / raw / history blame / 24.2 kB
from flask import request, Flask, jsonify, make_response
from flask_cors import CORS, cross_origin
from PIL import Image, ImageDraw, ImageFont
import base64
import os
import sys
import json
import pickle
import io
import numpy as np
import gc
import shutil
from utils import update_epoch_projection, initialize_backend, add_line,getCriticalChangeIndices, getConfChangeIndices, getContraVisChangeIndices, getContraVisChangeIndicesSingle
import time
# Flask application object that every @app.route handler below registers on.
app = Flask(__name__)
# Enable CORS for the whole app; supports_credentials=True allows credentialed
# cross-origin requests.
cors = CORS(app, supports_credentials=True)
app.config['CORS_HEADERS'] = 'Content-Type'
# Path handed to add_line() together with an action tag and a username.
# NOTE(review): presumably an append-only usage log - confirm in utils.add_line.
API_result_path = "./admin_API_result.csv"
@app.route('/updateProjection', methods=["POST", "GET"])
@cross_origin()
def update_projection():
    """Return the 2-D projection and its display data for one iteration.

    The JSON request body must contain ``path``, ``vis_method``, ``setting``,
    ``iteration``, ``predicates``, ``username`` and ``isContraVis``.  A backend
    context is built with ``initialize_backend`` and the actual work is done
    by ``update_epoch_projection``.

    Returns:
        A 200 JSON response with the embedding, the grid, the decision view
        as a base64 PNG data URL, label metadata, training/testing indices,
        evaluation, predictions, selected points, properties and the
        highlighted point indices.
    """
    res = request.get_json()
    start_time = time.time()
    CONTENT_PATH = os.path.normpath(res['path'])
    VIS_METHOD = res['vis_method']
    SETTING = res["setting"]
    iteration = int(res['iteration'])
    predicates = res["predicates"]
    # Not used below; still read so a request without it fails as before.
    username = res['username']
    isContraVis = res['isContraVis']

    context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
    # The iteration sent by the client is used directly as the epoch.
    EPOCH = iteration
    (embedding_2d, grid, decision_view, label_name_dict, label_color_list,
     label_list, max_iter, training_data_index, testing_data_index, eval_new,
     prediction_list, selected_points, properties,
     highlightedPointIndices) = update_epoch_projection(
        context, EPOCH, predicates, isContraVis)

    # Always normalise to a plain list.  The old code only converted
    # non-empty results, so an empty ndarray reached jsonify() unconverted.
    highlightedPointIndices = np.asarray(highlightedPointIndices).tolist()

    elapsed_time = time.time() - start_time
    print("updateprojection", elapsed_time)
    return make_response(jsonify({
        'result': embedding_2d,
        'grid_index': grid.tolist(),
        'grid_color': 'data:image/png;base64,' + decision_view,
        'label_name_dict': label_name_dict,
        'label_color_list': label_color_list,
        'label_list': label_list,
        'maximum_iteration': max_iter,
        'training_data': training_data_index,
        'testing_data': testing_data_index,
        'evaluation': eval_new,
        'prediction_list': prediction_list,
        "selectedPoints": selected_points.tolist(),
        "properties": properties.tolist(),
        "highlightedPointIndices": highlightedPointIndices,
    }), 200)
@app.route('/highlightCriticalChange', methods=["POST", "GET"])
@cross_origin()
def highlight_critical_change():
    """Return the indices produced by ``getCriticalChangeIndices``.

    The JSON body supplies ``path``, ``vis_method``, ``setting``,
    ``iteration``, ``last_iteration`` and ``username``.  The result is sent
    back under the ``predChangeIndices`` key with status 200.
    """
    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    current = int(payload['iteration'])
    previous = int(payload['last_iteration'])
    # Value is unused; the lookup is kept so a missing key still fails.
    _username = payload['username']

    backend = initialize_backend(content_path, vis_method, setting)
    changed = getCriticalChangeIndices(backend, current, previous)

    body = {"predChangeIndices": changed.tolist()}
    return make_response(jsonify(body), 200)
@app.route('/contraVisHighlight', methods=["POST", "GET"])
@cross_origin()
def contravis_highlight():
    """Return the indices produced by ``getContraVisChangeIndices``.

    The JSON body supplies ``path``, ``vis_method``, ``setting``,
    ``iterationLeft``, ``iterationRight``, ``method`` and ``username``.
    The result is returned as-is under ``contraVisChangeIndices``.
    """
    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    left_iteration = int(payload['iterationLeft'])
    right_iteration = int(payload['iterationRight'])
    method = payload['method']
    # Value is unused; the lookup is kept so a missing key still fails.
    _username = payload['username']

    backend = initialize_backend(content_path, vis_method, setting)
    change_indices = getContraVisChangeIndices(
        backend, left_iteration, right_iteration, method)
    print(len(change_indices))

    body = {"contraVisChangeIndices": change_indices}
    return make_response(jsonify(body), 200)
@app.route('/contraVisHighlightSingle', methods=["POST", "GET"])
@cross_origin()
def contravis_highlight_single():
    """Return the six index lists from ``getContraVisChangeIndicesSingle``.

    The JSON body supplies ``path``, ``vis_method``, ``setting``,
    ``iterationLeft``, ``iterationRight``, ``method``, ``selectedPointLeft``,
    ``selectedPointRight`` and ``username``.  The elapsed wall-clock time of
    the request is printed before the response is built.
    """
    started = time.time()
    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    left_iteration = int(payload['iterationLeft'])
    right_iteration = int(payload['iterationRight'])
    method = payload['method']
    left_selected = payload['selectedPointLeft']
    right_selected = payload['selectedPointRight']
    # Value is unused; the lookup is kept so a missing key still fails.
    _username = payload['username']

    backend = initialize_backend(content_path, vis_method, setting)
    # Explicit six-way unpack: a helper returning another arity still raises.
    left, right, left_left, left_right, right_left, right_right = \
        getContraVisChangeIndicesSingle(
            backend, left_iteration, right_iteration, method,
            left_selected, right_selected)

    print(time.time() - started)
    body = {
        "contraVisChangeIndicesLeft": left,
        "contraVisChangeIndicesRight": right,
        "contraVisChangeIndicesLeftLeft": left_left,
        "contraVisChangeIndicesLeftRight": left_right,
        "contraVisChangeIndicesRightLeft": right_left,
        "contraVisChangeIndicesRightRight": right_right,
    }
    return make_response(jsonify(body), 200)
@app.route('/highlightConfChange', methods=["POST", "GET"])
@cross_origin()
def highlight_conf_change():
    """Return the indices produced by ``getConfChangeIndices``.

    The JSON body supplies ``path``, ``vis_method``, ``setting``,
    ``iteration``, ``last_iteration``, ``confChangeInput`` (converted to
    float) and ``username``.  The result is returned under
    ``confChangeIndices``.
    """
    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    current = int(payload['iteration'])
    previous = int(payload['last_iteration'])
    threshold = float(payload['confChangeInput'])
    print(threshold)
    # Value is unused; the lookup is kept so a missing key still fails.
    _username = payload['username']

    backend = initialize_backend(content_path, vis_method, setting)
    changed = getConfChangeIndices(backend, current, previous, threshold)
    print(changed)

    body = {"confChangeIndices": changed.tolist()}
    return make_response(jsonify(body), 200)
@app.route('/query', methods=["POST"])
@cross_origin()
def filter():
    """Return the sample indices that satisfy every predicate at one iteration.

    The JSON body supplies ``content_path``, ``vis_method``, ``setting``,
    ``iteration``, ``predicates`` and ``username``.  The candidate set is the
    training indices selected by ``context.get_epoch_index`` plus all testing
    indices; it is intersected with the result of each known predicate
    (``label``, ``type``, ``confidence``).  Unknown predicate keys do not
    restrict the selection.

    NOTE: the function name shadows the builtin ``filter``; it is kept because
    Flask derives the endpoint name from it.
    """
    start_time = time.time()
    res = request.get_json()
    CONTENT_PATH = os.path.normpath(res['content_path'])
    VIS_METHOD = res['vis_method']
    SETTING = res["setting"]
    iteration = int(res['iteration'])
    predicates = res["predicates"]
    username = res['username']

    sys.path.append(CONTENT_PATH)
    try:
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
        # TODO: fix when active learning
        # NOTE(review): p and s look like the epoch period and start epoch -
        # confirm against the data provider.
        data_provider = context.strategy.data_provider
        EPOCH = int((iteration - 1) * data_provider.p + data_provider.s)

        training_config = context.strategy.config["TRAINING"]
        training_data_number = training_config["train_num"]
        testing_data_number = training_config["test_num"]
        total = training_data_number + testing_data_number

        current_index = context.get_epoch_index(EPOCH)
        selected_points = np.concatenate(
            (np.arange(training_data_number)[current_index],
             np.arange(training_data_number, total, 1)),
            axis=0)

        for key, value in predicates.items():
            if key == "label":
                tmp = np.array(context.filter_label(value, EPOCH))
            elif key == "type":
                tmp = np.array(context.filter_type(value, EPOCH))
            elif key == "confidence":
                tmp = np.array(context.filter_conf(value[0], value[1], EPOCH))
            else:
                # Unknown predicate: match everything.
                tmp = np.arange(total)
            selected_points = np.intersect1d(selected_points, tmp)
    finally:
        # Undo the sys.path change even when the backend raises; otherwise
        # every failing request would leave one more entry behind.
        sys.path.remove(CONTENT_PATH)

    add_line(API_result_path, ['SQ', username])
    elapsed_time = time.time() - start_time
    print("query", elapsed_time)
    return make_response(jsonify({"selectedPoints": selected_points.tolist()}), 200)
# base64
@app.route('/spriteImage', methods=["POST","GET"])
@cross_origin()
def sprite_image():
    """Return one sprite PNG as a base64 data URL.

    Query parameters: ``path`` (content directory), ``index`` (sprite number)
    and ``username``.  The file ``<path>/sprites/<index>.png`` is read, the
    call is recorded through ``add_line`` with the ``SI`` tag, and the image
    is returned under ``imgUrl``.
    """
    args = request.args
    raw_path = args.get("path")
    raw_index = args.get("index")
    username = args.get("username")

    content_dir = os.path.normpath(raw_path)
    print('index', raw_index)
    sprite_no = int(raw_index)
    sprite_file = os.path.join(content_dir, "sprites", f"{sprite_no}.png")

    with open(sprite_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode()

    add_line(API_result_path, ['SI', username])
    body = {"imgUrl": 'data:image/png;base64,' + encoded}
    return make_response(jsonify(body), 200)
@app.route('/spriteText', methods=["POST","GET"])
@cross_origin()
def sprite_text():
    """Return the text stored for one sample at one iteration.

    Query parameters: ``path`` (content directory), ``index`` (sample number)
    and ``iteration``.  The text is read from
    ``<path>/Model/Epoch_<iteration>/labels/text_<index>.txt`` and returned
    under the ``texts`` key.

    When the file does not exist the response carries an empty string.  The
    previous code left ``sprite_texts`` unbound in that case and crashed with
    ``UnboundLocalError`` before a response could be built.
    """
    path = request.args.get("path")
    index = request.args.get("index")
    iteration = request.args.get("iteration")

    CONTENT_PATH = os.path.normpath(path)
    idx = int(index)
    start = time.time()
    text_save_dir_path = os.path.join(
        CONTENT_PATH, f"Model/Epoch_{iteration}/labels", f"text_{idx}.txt")

    # Default for the "file missing" branch.
    sprite_texts = ""
    try:
        with open(text_save_dir_path, 'r') as text_f:
            sprite_texts = text_f.read()
    except FileNotFoundError:
        print("File does not exist:", text_save_dir_path)

    print(sprite_texts)
    response_data = {
        "texts": sprite_texts
    }
    end = time.time()
    print("processTime", end - start)
    return make_response(jsonify(response_data), 200)
# img_stream = ''
# with open(text_save_dir_path, 'rb') as img_f:
# img_stream = img_f.read()
# img_stream = base64.b64encode(img_stream).decode()
# img_stream = base64.b64encode(img_io.getvalue()).decode()
# Return the base64-encoded image as JSON
# return make_response(jsonify({"imgUrl": 'data:image/png;base64,' + img_stream}), 200)
# @app.route('/spriteList', methods=["POST"])
# @cross_origin()
# def sprite_list_image():
# data = request.get_json()
# indices = data["index"]
# path = data["path"]
# CONTENT_PATH = os.path.normpath(path)
# length = len(indices)
# urlList = {}
# start_time = time.time()
# for i in range(length):
# idx = indices[i]
# pic_save_dir_path = os.path.join(CONTENT_PATH, "sprites", "{}.png".format(idx))
# img_stream = ''
# with open(pic_save_dir_path, 'rb') as img_f:
# img_stream = img_f.read()
# img_stream = base64.b64encode(img_stream).decode()
# urlList[idx] = 'data:image/png;base64,' + img_stream
# # urlList.append('data:image/png;base64,' + img_stream)
# end_time = time.time()
# elapsed_time = end_time - start_time
# print("Spritelist", elapsed_time)
# return make_response(jsonify({"urlList":urlList}), 200)
@app.route('/spriteList', methods=["POST"])
@cross_origin()
def sprite_list_image():
    """Return several sprite PNGs as base64 data URLs in one response.

    The JSON body supplies ``index`` (the sprite numbers) and ``path`` (the
    content directory).  Each ``<path>/sprites/<idx>.png`` is read and stored
    in the ``urlList`` mapping keyed by its index.  The time spent reading is
    printed.
    """
    payload = request.get_json()
    wanted = payload["index"]
    sprite_dir = os.path.join(os.path.normpath(payload["path"]), "sprites")

    url_by_index = {}
    started = time.time()
    for sprite_no in wanted:
        sprite_file = os.path.join(sprite_dir, f"{sprite_no}.png")
        with open(sprite_file, 'rb') as handle:
            encoded = base64.b64encode(handle.read()).decode()
        url_by_index[sprite_no] = 'data:image/png;base64,' + encoded
    print("Spritelist", time.time() - started)

    return make_response(jsonify({"urlList": url_by_index}), 200)
@app.route('/al_query', methods=["POST"])
@cross_origin()
def al_query():
    """Return active-learning suggestions sorted by descending score.

    The JSON body supplies ``content_path``, ``vis_method``, ``setting``,
    ``iteration``, ``strategy``, ``budget``, ``accIndices``, ``rejIndices``,
    ``username`` and ``isRecommend``.  ``context.al_query`` yields indices,
    labels and scores; all three are reordered so the highest score comes
    first.  The call is recorded through ``add_line`` as ``Recommend`` or
    ``Feedback`` depending on ``isRecommend``.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]
    # TODO fix iteration, align with frontend
    iteration = data["iteration"]
    strategy = data["strategy"]
    budget = int(data["budget"])
    acc_idxs = data["accIndices"]
    rej_idxs = data["rejIndices"]
    user_name = data["username"]
    isRecommend = data["isRecommend"]

    sys.path.append(CONTENT_PATH)
    try:
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING, dense=True)
        # TODO add new sampling rule
        indices, labels, scores = context.al_query(
            iteration, budget, strategy,
            np.array(acc_idxs).astype(np.int64),
            np.array(rej_idxs).astype(np.int64))
    finally:
        # Undo the sys.path change even when the backend raises.
        sys.path.remove(CONTENT_PATH)

    # Negating the scores makes argsort return a descending order.
    sort_i = np.argsort(-scores)
    indices = indices[sort_i]
    labels = labels[sort_i]
    scores = scores[sort_i]

    action = 'Recommend' if isRecommend else 'Feedback'
    add_line(API_result_path, [action, user_name])
    return make_response(jsonify({"selectedPoints": indices.tolist(), "scores": scores.tolist(), "suggestLabels": labels.tolist()}), 200)
@app.route('/anomaly_query', methods=["POST"])
@cross_origin()
def anomaly_query():
    """Return suggested abnormal samples (descending score) and a clean list.

    The JSON body supplies ``content_path``, ``vis_method``, ``setting``,
    ``budget``, ``strategy``, ``accIndices``, ``rejIndices``, ``username``
    and ``isRecommend``.  The accepted/rejected indices are saved through the
    context, ``suggest_abnormal`` provides indices, scores and labels, and
    ``suggest_normal`` provides the ``cleanList``.  The call is recorded
    through ``add_line`` as ``Recommend`` or ``Feedback``.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]
    budget = int(data["budget"])
    strategy = data["strategy"]
    acc_idxs = data["accIndices"]
    rej_idxs = data["rejIndices"]
    user_name = data["username"]
    isRecommend = data["isRecommend"]

    sys.path.append(CONTENT_PATH)
    try:
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
        # NOTE(review): al_train calls save_acc_and_rej with a leading
        # `iteration` argument; confirm which signature is the right one.
        context.save_acc_and_rej(acc_idxs, rej_idxs, user_name)
        accepted = np.array(acc_idxs).astype(np.int64)
        rejected = np.array(rej_idxs).astype(np.int64)
        indices, scores, labels = context.suggest_abnormal(strategy, accepted, rejected, budget)
        clean_list, _ = context.suggest_normal(strategy, accepted, rejected, 1)
    finally:
        # Undo the sys.path change even when the backend raises.
        sys.path.remove(CONTENT_PATH)

    # Negating the scores makes argsort return a descending order.
    sort_i = np.argsort(-scores)
    indices = indices[sort_i]
    labels = labels[sort_i]
    scores = scores[sort_i]

    action = 'Recommend' if isRecommend else 'Feedback'
    add_line(API_result_path, [action, user_name])
    return make_response(jsonify({"selectedPoints": indices.tolist(), "scores": scores.tolist(), "suggestLabels": labels.tolist(), "cleanList": clean_list.tolist()}), 200)
@app.route('/al_train', methods=["POST"])
@cross_origin()
def al_train():
    """Run one active-learning training round and return the new projection.

    The JSON body supplies ``content_path``, ``vis_method``, ``setting``,
    ``accIndices``, ``rejIndices``, ``iteration`` and ``username``.  After
    saving the feedback, the context trains (``al_train``), the new maximum
    iteration is visualised (``vis_train``), and a new node
    ``{'value': <new iteration>, 'name': 'iteration', 'pid': <iteration>}``
    is appended to ``<content_path>/iteration_structure.json``.

    Fixes over the previous version:
      * ``del config`` referred to a name never assigned in this function
        and always raised ``UnboundLocalError``; it is removed.
      * ``sys.path`` is restored even when training fails.
      * the redundant ``close()`` inside the ``with`` block is gone.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]
    acc_idxs = data["accIndices"]
    rej_idxs = data["rejIndices"]
    iteration = data["iteration"]
    user_name = data["username"]

    sys.path.append(CONTENT_PATH)
    try:
        # default setting al_train is light version, we only save the last epoch
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
        context.save_acc_and_rej(iteration, acc_idxs, rej_idxs, user_name)
        context.al_train(iteration, acc_idxs)
        NEW_ITERATION = context.get_max_iter()
        context.vis_train(NEW_ITERATION, iteration)

        # update iteration projection
        # NOTE(review): update_projection unpacks 14 values from this helper
        # and passes a 4th argument (isContraVis).  The starred target accepts
        # either arity; confirm the 3-argument call is still valid.
        (embedding_2d, grid, decision_view, label_name_dict, label_color_list,
         label_list, _, training_data_index, testing_data_index, eval_new,
         prediction_list, selected_points, properties,
         *_extra) = update_epoch_projection(context, NEW_ITERATION, dict())

        # Append the new iteration to the iteration tree on disk.
        res_json_path = os.path.join(CONTENT_PATH, "iteration_structure.json")
        with open(res_json_path, encoding='utf8') as fp:
            json_data = json.load(fp)
        json_data.append({'value': NEW_ITERATION, 'name': 'iteration', 'pid': iteration})
        print('json_data', json_data)
        with open(res_json_path, 'w', encoding='utf8') as fp:
            json.dump(json_data, fp)
    finally:
        sys.path.remove(CONTENT_PATH)

    gc.collect()
    add_line(API_result_path, ['al_train', user_name])
    return make_response(jsonify({
        'result': embedding_2d,
        # update_projection sends grid.tolist(); normalise here as well so an
        # ndarray grid does not break JSON encoding (lists pass unchanged).
        'grid_index': np.asarray(grid).tolist(),
        'grid_color': 'data:image/png;base64,' + decision_view,
        'label_name_dict': label_name_dict,
        'label_color_list': label_color_list,
        'label_list': label_list,
        'maximum_iteration': NEW_ITERATION,
        'training_data': training_data_index,
        'testing_data': testing_data_index,
        'evaluation': eval_new,
        'prediction_list': prediction_list,
        "selectedPoints": selected_points.tolist(),
        "properties": properties.tolist(),
    }), 200)
def clear_cache(con_paths):
    """Drop cached iterations above 2 for every content path.

    For each path in ``con_paths.values()``:
      * every ``Model`` sub-directory whose name contains ``Iteration_`` and
        reduces to an integer greater than 2 is deleted;
      * if at least one such iteration directory was seen,
        ``iteration_structure.json`` is rewritten keeping only the entries
        whose ``value`` is below 3.

    Directory names that contain ``Iteration_`` but do not reduce to an
    integer (for example ``Iteration_backup``) are skipped; previously they
    aborted the whole clean-up with ``ValueError``.

    Args:
        con_paths: mapping whose values are content directory paths.
    """
    for CONTENT_PATH in con_paths.values():
        target_path = os.path.join(CONTENT_PATH, "Model")
        has_iterations = False
        for entry in os.listdir(target_path):
            if "Iteration_" not in entry:
                continue
            try:
                number = int(entry.replace("Iteration_", ""))
            except ValueError:
                # Not an iteration directory; leave it alone.
                continue
            has_iterations = True
            if number > 2:
                shutil.rmtree(os.path.join(target_path, entry))

        if not has_iterations:
            continue

        iter_structure_path = os.path.join(CONTENT_PATH, "iteration_structure.json")
        with open(iter_structure_path, "r") as f:
            structure = json.load(f)
        kept = [item for item in structure if item["value"] < 3]
        with open(iter_structure_path, "w") as f:
            json.dump(kept, f)
    print("Successfully remove cache data!")
@app.route('/login', methods=["POST"])
@cross_origin()
def login():
    """Echo the posted ``content_path`` back under both content-path keys.

    No credentials are checked here: the username/password handling and the
    cache clean-up are not active in this handler.
    """
    payload = request.get_json()
    content_path = payload["content_path"]
    body = {
        "normal_content_path": content_path,
        "unormaly_content_path": content_path,
    }
    return make_response(jsonify(body), 200)
@app.route('/boundingbox_record', methods=["POST"])
@cross_origin()
def record_bb():
    """Record a ``boundingbox`` action for the posted ``username``.

    Responds with an empty JSON object and status 200.
    """
    payload = request.get_json()
    who = payload['username']
    add_line(API_result_path, ['boundingbox', who])
    return make_response(jsonify({}), 200)
@app.route('/all_result_list', methods=["POST"])
@cross_origin()
def get_res():
    """Return embeddings, grids and background images for every epoch.

    The JSON body supplies ``content_path``, ``vis_method``, ``setting`` and
    ``username``.  Epochs run from ``EPOCH_START`` to ``EPOCH_END`` in steps
    of ``EPOCH_PERIOD`` (read from the strategy config).  For each epoch the
    cached ``embedding.npy`` / ``grid.pkl`` are used when they and
    ``bgimg.png`` exist in the checkpoint directory; otherwise the projection
    is recomputed with ``update_epoch_projection``.  Results are keyed by the
    1-based epoch counter as a string.

    Fixes over the previous version:
      * ``del config`` referred to a name never assigned in this function
        and always raised ``UnboundLocalError``; it is removed.
      * ``sys.path`` is restored even when the loop fails.
      * the recomputed projection is unpacked with a starred target, so it no
        longer depends on the helper returning exactly 13 values.
      * grids are normalised to plain lists before JSON encoding.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]
    username = data["username"]

    predicates = dict()  # placeholder
    results = dict()
    imglist = dict()
    gridlist = dict()

    sys.path.append(CONTENT_PATH)
    try:
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
        config = context.strategy.config
        EPOCH_START = config["EPOCH_START"]
        EPOCH_PERIOD = config["EPOCH_PERIOD"]
        EPOCH_END = config["EPOCH_END"]
        epoch_num = (EPOCH_END - EPOCH_START) // EPOCH_PERIOD + 1

        for i in range(1, epoch_num + 1):
            EPOCH = (i - 1) * EPOCH_PERIOD + EPOCH_START
            # NOTE(review): every other call passes VIS_METHOD and SETTING,
            # and reads data_provider via `.strategy`; confirm this
            # single-argument form and the direct attribute are intended.
            trustvis = initialize_backend(CONTENT_PATH)

            # detect whether we have query before
            mode = trustvis.data_provider.mode
            fname = "Epoch" if mode in ("normal", "abnormal") else "Iteration"
            checkpoint_path = context.strategy.data_provider.checkpoint_path(EPOCH)
            bgimg_path = os.path.join(checkpoint_path, "bgimg.png")
            embedding_path = os.path.join(checkpoint_path, "embedding.npy")
            grid_path = os.path.join(checkpoint_path, "grid.pkl")

            cached = (os.path.exists(bgimg_path)
                      and os.path.exists(embedding_path)
                      and os.path.exists(grid_path))
            if cached:
                path = os.path.join(trustvis.data_provider.model_path,
                                    "{}_{}".format(fname, EPOCH))
                results[str(i)] = np.load(os.path.join(path, "embedding.npy")).tolist()
                # SECURITY NOTE: pickle.load executes arbitrary code from the
                # file, and the directory comes from the request's
                # content_path.  Only safe for trusted content directories.
                with open(os.path.join(path, "grid.pkl"), "rb") as f:
                    grid = pickle.load(f)
            else:
                embedding_2d, grid, *_rest = update_epoch_projection(
                    trustvis, EPOCH, predicates)
                results[str(i)] = embedding_2d
            # update_projection sends grid.tolist(); do the same here so an
            # ndarray grid does not break jsonify (lists pass unchanged).
            gridlist[str(i)] = np.asarray(grid).tolist()

            # read background img
            with open(bgimg_path, 'rb') as img_f:
                img_stream = base64.b64encode(img_f.read()).decode()
            imglist[str(i)] = 'data:image/png;base64,' + img_stream
    finally:
        sys.path.remove(CONTENT_PATH)

    gc.collect()
    add_line(API_result_path, ['animation', username])
    return make_response(jsonify({"results": results, "bgimgList": imglist, "grid": gridlist}), 200)
@app.route('/get_itertaion_structure', methods=["POST", "GET"])
@cross_origin()
def get_tree():
    """Return the parsed ``iteration_structure.json`` of a content directory.

    The directory comes from the ``path`` query parameter; the file content
    is returned under the ``structure`` key.
    """
    content_dir = request.args.get("path")
    structure_file = os.path.join(content_dir, "iteration_structure.json")
    with open(structure_file, encoding='utf8') as handle:
        structure = json.load(handle)
    return make_response(jsonify({"structure": structure}), 200)
def check_port_inuse(port, host):
    """Return True when a TCP connection to ``(host, port)`` succeeds.

    A connection attempt is made with a 1-second timeout.  Any socket error
    (refused, timed out, unreachable, socket creation failure) counts as
    "not in use".

    Args:
        port: TCP port number to probe.
        host: host name or IP address to connect to.

    Returns:
        True if something accepted the connection, False otherwise.
    """
    # Local import: the module only imports socket inside its __main__ block,
    # so this function raised NameError when the file was imported.
    import socket

    try:
        # The context manager closes the socket on every path.  The old
        # `finally: if s:` raised when socket.socket() itself failed, because
        # `s` was never bound.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(1)
            s.connect((host, port))
            return True
    except socket.error:
        return False
if __name__ == "__main__":
    # Imported at module scope here, so check_port_inuse() can see it when
    # the file runs as a script.
    import socket

    hostname = socket.gethostname()
    ip_address = socket.gethostbyname(hostname)

    # Start at 5000 and walk upward until nothing answers on the port.
    port = 5000
    while check_port_inuse(port, ip_address):
        port += 1

    app.run(host=ip_address, port=int(port))