import os
import io
import torch
import json
import base64
import gradio as gr
import numpy as np
from pathlib import Path
from PIL import Image
from plots import get_pre_define_colors
from utils.load_model import load_xclip
from utils.predict import xclip_pred
#! Hugging Face Spaces does not allow loading the model in the main process, so the model must be loaded on demand there; this may not improve the app's speed.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Not running as a Hugging Face demo; loading the model in the main process.")
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
print(f"Device: {DEVICE}")
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
IMAGES_FOLDER = "data/images"
# XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
IMAGE_FILE_LIST = json.load(open("data/jsons/file_list.json", "r"))
IMAGE_GALLERY = [Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB') for file_name in IMAGE_FILE_LIST]
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
SACHIT_COLOR = "#ADD8E6"
# CUB_BOXES = json.load(open("data/jsons/cub_boxes_owlvit_large.json", "r"))
VISIBILITY_DICT = json.load(open("data/jsons/cub_vis_dict_binary.json", 'r'))
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict(zip(ORDERED_PARTS, [True]*12))
# --- Image related functions ---
def img_to_base64(img):
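    """Encode a PIL image (or a NumPy array) as a base64 JPEG string so it can be inlined into the HTML explanations."""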
img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
buffered = io.BytesIO()
img_pil.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue())
return img_str.decode()
def create_blank_image(width=500, height=500, color=(255, 255, 255)):
"""Create a blank image of the given size and color."""
return np.array(Image.new("RGB", (width, height), color))
# Convert RGB colors to hex
def rgb_to_hex(rgb):
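    """Convert an (R, G, B) tuple to a hex color string, e.g. rgb_to_hex((255, 0, 0)) returns '#ff0000'."""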
return f"#{''.join(f'{x:02x}' for x in rgb)}"
def load_part_images(file_name: str) -> dict:
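    """Load the pre-cropped part images for `file_name` and return {part_name: base64 JPEG}; parts with no crop on disk are skipped."""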
part_images = {}
# start_time = time.time()
for part_name in ORDERED_PARTS:
base_name = Path(file_name).stem
part_image_path = os.path.join(IMAGES_FOLDER, "boxes", f"{base_name}_{part_name}.jpg")
if not Path(part_image_path).exists():
continue
image = np.array(Image.open(part_image_path))
part_images[part_name] = img_to_base64(image)
# print(f"Time cost to load 12 images: {time.time() - start_time}")
# This takes less than 0.01 seconds. So the loading time is not the bottleneck.
return part_images
def generate_xclip_explanations(result_dict: dict, visibility: dict, part_mask: dict = dict(zip(ORDERED_PARTS, [1] * 12))):
    """
    Render PEEB's per-part explanation as HTML (one SVG text row per part).
    The result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
        descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
        pred_scores: {part_name1: score_1, part_name2: score_2, ...}
        file_name: str
    """
    descriptions = result_dict['descriptions']
    pred_scores = result_dict['pred_scores']
    image_name = result_dict['file_name']
    part_images = PART_IMAGES_DICT[image_name]
    MAX_LENGTH = 50   # truncate each descriptor so it fits on one row
    exp_length = 400  # width of the explanation block in pixels
    fontsize = 15
    # Start the SVG inside a div. The markup below is a minimal sketch:
    # one colored <text> row per visible, unmasked part, carrying its base64
    # part crop so the UI can show the matched region on hover.
    svg_parts = [f'<div id="container"><svg width="{exp_length}" height="{2 * fontsize * len(ORDERED_PARTS)}" font-size="{fontsize}px">']
    for i, part in enumerate(ORDERED_PARTS):
        if not part_mask.get(part, 1) or not visibility.get(part, True):
            continue
        desc = str(descriptions.get(part, ""))[:MAX_LENGTH]
        score = pred_scores.get(part, 0.0)
        part_img = part_images.get(part, BLANK_OVERLAY)
        svg_parts.append(
            f'<text x="0" y="{2 * fontsize * (i + 1)}" fill="{PART_COLORS[part]}" '
            f'data-image="data:image/jpeg;base64,{part_img}">{part}: {desc} ({score:.2f})</text>'
        )
    svg_parts.append('</svg></div>')
    # Join everything into a single string
    html = "".join(svg_parts)
    return html
def generate_sachit_explanations(result_dict: dict):
    """
    Render a Sachit-style (per-descriptor) explanation as HTML.
    The result_dict needs two keys: 'descriptions' (list of str) and 'scores' (list of float).
    """
    descriptions = result_dict['descriptions']
    scores = result_dict['scores']
    MAX_LENGTH = 50   # truncate each descriptor so it fits on one row
    exp_length = 400  # width of the explanation block in pixels
    fontsize = 15
    # Sort descriptors by score, highest first
    descriptions = sorted(zip(scores, descriptions), key=lambda x: x[0], reverse=True)
    # Start the SVG inside a div (minimal sketch): one row per descriptor, drawn in SACHIT_COLOR.
    svg_parts = [f'<div id="container"><svg width="{exp_length}" height="{2 * fontsize * len(descriptions)}" font-size="{fontsize}px">']
    for i, (score, desc) in enumerate(descriptions):
        svg_parts.append(f'<text x="0" y="{2 * fontsize * (i + 1)}" fill="{SACHIT_COLOR}">{str(desc)[:MAX_LENGTH]} ({score:.2f})</text>')
    svg_parts.append('</svg></div>')
    # Join everything into a single string
    html = "".join(svg_parts)
    return html
# --- Constants created by the functions above ---
BLANK_OVERLAY = img_to_base64(create_blank_image())
PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
PART_IMAGES_DICT = {file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST}
# --- Gradio Functions ---
def update_selected_image(event: gr.SelectData):
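    """Gallery select handler: display the chosen image, its ground-truth class, PEEB's top-1 prediction with its grounded explanation, and pre-fill the descriptor textbox."""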
image_height = 400
index = event.index
image_name = IMAGE_FILE_LIST[index]
current_image.state = image_name
org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
    # Inline the selected image as a base64 <img> (minimal sketch reusing the #container/#canvas CSS defined below)
    img_base64 = f"""
    <div id="container"><img id="canvas" src="data:image/jpeg;base64,{img_to_base64(org_image)}" height="{image_height}"></div>
    """
gt_label = IMAGE2GT[image_name]
gt_class.state = gt_label
# --- for initial value only ---
out_dict = xclip_pred(new_desc=None,
new_part_mask=None,
new_class=None,
org_desc=XCLIP_DESC_PATH,
image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
model=XCLIP,
owlvit_processor=OWLVIT_PRECESSOR,
device=DEVICE,
image_name=current_image.state,
cub_embeds=CUB_DESC_EMBEDS,
cub_idx2name=CUB_IDX2NAME,
descriptors=XCLIP_DESC)
xclip_label = out_dict['pred_class']
clip_pred_scores = out_dict['pred_score']
xclip_part_scores = out_dict['pred_desc_scores']
result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask=dict(zip(ORDERED_PARTS, [1]*12)))
    # --- end of initial value ---
xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
    xclip_pred_markdown = f"""
    ### <span style="color: {xclip_color}">{xclip_label}</span> {clip_pred_scores:.4f}
    """
gt_label = f"""
## {gt_label}
"""
current_predicted_class.state = xclip_label
# Populate the textbox with current descriptions
custom_class_name = "class name: custom"
descs = XCLIP_DESC[xclip_label]
descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
descs = {k: descs[k] for k in ORDERED_PARTS}
custom_text = [custom_class_name] + list(descs.values())
descriptions = ";\n".join(custom_text)
# textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
textbox = gr.Textbox(value=descriptions,
lines=12,
visible=True,
label="XCLIP descriptions",
interactive=True,
info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
show_label=False)
# modified_exp = gr.HTML().update(value="", visible=True)
return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
def on_edit_button_click_xclip():
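    """Reset the textbox to the descriptors of the currently predicted class and hide the previous custom-class explanation."""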
# empty_exp = gr.HTML.update(visible=False)
empty_exp = gr.HTML(visible=False)
# Populate the textbox with current descriptions
descs = XCLIP_DESC[current_predicted_class.state]
descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
descs = {k: descs[k] for k in ORDERED_PARTS}
custom_text = ["class name: custom"] + list(descs.values())
descriptions = ";\n".join(custom_text)
# textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
textbox = gr.Textbox(value=descriptions,
lines=12,
visible=True,
label="XCLIP descriptions",
interactive=True,
info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
show_label=False)
return textbox, empty_exp
def convert_input_text_to_xclip_format(textbox_input: str):
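    """
    Parse the descriptor textbox into the inputs expected by `xclip_pred`.

    The textbox holds one descriptor per line, each line ending with `;`:
    the first line is `class name: {some name}` and the remaining lines are
    `{part name}: {description}`. Illustrative (hypothetical) input with two parts:

        class name: my bird;
        crown: a red crown;
        beak: a short black beak

    returns descriptions_dict = {'crown': 'crown: a red crown', 'beak': 'beak: a short black beak',
    empty strings for the other 10 parts}, part_mask = {'crown': 1, 'beak': 1, 0 for the other
    10 parts}, and new_class_name = 'my bird'.
    """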
# Split the descriptions by newline to get individual descriptions for each part
descriptions_list = textbox_input.split(";\n")
# the first line should be "class name: xxx"
class_name_line = descriptions_list[0]
new_class_name = class_name_line.split(":")[1].strip()
descriptions_list = descriptions_list[1:]
    # construct the description dict with part name as key
descriptions_dict = {}
for desc in descriptions_list:
if desc.strip() == "":
continue
        part_name, _ = desc.split(":", 1)  # split on the first colon only, so descriptions may contain ":"
descriptions_dict[part_name.strip()] = desc
# fill with empty string if the part is not in the descriptions
part_mask = {}
for part in ORDERED_PARTS:
if part not in descriptions_dict:
descriptions_dict[part] = ""
part_mask[part] = 0
else:
part_mask[part] = 1
return descriptions_dict, part_mask, new_class_name
def on_predict_button_click_xclip(textbox_input: str):
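    """Re-run PEEB with the edited extra-class (201st) descriptors and return both the original CUB-200 prediction and the modified CUB-201 prediction, each with its grounded explanation."""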
descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
# Get the new predictions and explanations
out_dict = xclip_pred(new_desc=descriptions_dict,
new_part_mask=part_mask,
new_class=new_class_name,
org_desc=XCLIP_DESC_PATH,
image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
model=XCLIP,
owlvit_processor=OWLVIT_PRECESSOR,
device=DEVICE,
image_name=current_image.state,
cub_embeds=CUB_DESC_EMBEDS,
cub_idx2name=CUB_IDX2NAME,
descriptors=XCLIP_DESC)
xclip_label = out_dict['pred_class']
xclip_pred_score = out_dict['pred_score']
xclip_part_scores = out_dict['pred_desc_scores']
custom_label = out_dict['modified_class']
custom_pred_score = out_dict['modified_score']
custom_part_scores = out_dict['modified_desc_scores']
# construct a result dict to generate xclip explanations
result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
xclip_explanation = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask)
modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])), 'pred_scores': custom_part_scores, 'file_name': current_image.state}
modified_explanation = generate_xclip_explanations(modified_result_dict, VISIBILITY_DICT[current_image.state], part_mask)
xclip_color = "green" if xclip_label.strip() == gt_class.state.strip() else "red"
xclip_pred_markdown = f"""
### {xclip_label} {xclip_pred_score:.4f}
"""
custom_color = "green" if custom_label.strip() == gt_class.state.strip() else "red"
custom_pred_markdown = f"""
### {custom_label} {custom_pred_score:.4f}
"""
# textbox = gr.Textbox.update(visible=False)
textbox = gr.Textbox(visible=False)
# return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation
# modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
modified_exp = gr.HTML(value=modified_explanation, visible=True)
return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
custom_css = """
html, body {
margin: 0;
padding: 0;
}
#container {
position: relative;
width: 400px;
height: 400px;
border: 1px solid #000;
margin: 0 auto; /* This will center the container horizontally */
}
#canvas {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
object-fit: cover;
}
"""
# Define the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
current_image = gr.State("")
current_predicted_class = gr.State("")
gt_class = gr.State("")
with gr.Column():
title_text = gr.Markdown("# Demo | A classifier with Part-based Explainable and Editable Bottleneck (PEEB)")
        gr.Markdown("PEEB is an image classifier, here for birds, pre-trained on Bird-11K and fine-tuned on CUB-200 (see our [NAACL 2024 paper](https://arxiv.org/abs/2403.05297) and [code](https://github.com/anguyen8/peeb/tree/inspect_ddp)).\n This **interactive** demo shows how to run PEEB on an existing image, and how to **edit** a class's textual description to directly change the classifier so that it detects one new bird species (without any re-training).")
        gr.Markdown(
            """
            ### Steps:
            1. **Select an image**. PEEB will then show its grounded explanations and the top-1 predicted label with its `softmax` confidence score.
            2. **Hover the mouse over a text descriptor** to see the image region matched to that descriptor.
            3. **Edit the text under [Extra class]()**, which corresponds to one extra, new class (i.e., 200 + 1 = `201`). Further edits overwrite this class's descriptors.
            4. **Click [Predict]()** to see the grounded explanations and the top-1 label of the newly modified CUB-201 classifier.
            """
        )
# display the gallery of images
with gr.Column():
gr.Markdown("## Select an image to start!")
image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
        gr.Markdown("### Extra-class descriptors: \n The first row should be `class name: {some name};`, the name of your 201st class. \n For the 12 part descriptors, please use `;` to separate the descriptions for each part, and use the format `{part name}: {descriptions}`.")
        gr.Markdown("**Note:** you can delete the row for any given part (e.g., `nape`), and that part will be removed from all 201 classes in the classifier. For example, you can turn PEEB into a classifier that identifies birds using only 5 parts by deleting the rows for the other 7 parts.")
with gr.Row():
with gr.Column():
image_label = gr.Markdown("### Class Name")
org_image = gr.HTML()
with gr.Column():
with gr.Row():
# xclip_predict_button = gr.Button(label="Predict", value="Predict")
xclip_predict_button = gr.Button(value="Predict")
xclip_pred_label = gr.Markdown("### Top-1 class:")
xclip_explanation = gr.HTML()
with gr.Column():
# xclip_edit_button = gr.Button(label="Edit", value="Reset Extra-class descriptors")
xclip_edit_button = gr.Button(value="Reset Descriptions")
custom_pred_label = gr.Markdown(
"### Extra class:"
)
xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
# ai_explanation = gr.Image(type="numpy", visible=True, show_label=False, height=500)
custom_explanation = gr.HTML()
gr.HTML(" ")
image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])
demo.launch()