Spaces:
Runtime error
Runtime error
fix: updating explanation component
Browse files- backend/controller.py +3 -3
- explanation/markup.py +26 -22
- main.py +8 -6
backend/controller.py
CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
|
|
6 |
|
7 |
# internal imports
|
8 |
from model import godel
|
9 |
-
from explanation import
|
10 |
|
11 |
|
12 |
# main interference function that that calls chat functions depending on selections
|
@@ -28,9 +28,9 @@ def interference(
|
|
28 |
if xai_selection in ("SHAP", "Visualizer"):
|
29 |
match xai_selection.lower():
|
30 |
case "shap":
|
31 |
-
xai =
|
32 |
case "visualizer":
|
33 |
-
xai =
|
34 |
case _:
|
35 |
# use Gradio warning to display error message
|
36 |
gr.Warning(f"""
|
|
|
6 |
|
7 |
# internal imports
|
8 |
from model import godel
|
9 |
+
from explanation import interpret_shap as sint, visualize as viz
|
10 |
|
11 |
|
12 |
# main interference function that that calls chat functions depending on selections
|
|
|
28 |
if xai_selection in ("SHAP", "Visualizer"):
|
29 |
match xai_selection.lower():
|
30 |
case "shap":
|
31 |
+
xai = sint
|
32 |
case "visualizer":
|
33 |
+
xai = viz
|
34 |
case _:
|
35 |
# use Gradio warning to display error message
|
36 |
gr.Warning(f"""
|
explanation/markup.py
CHANGED
@@ -9,7 +9,7 @@ from utils import formatting as fmt
|
|
9 |
|
10 |
|
11 |
def markup_text(input_text: list, text_values: ndarray, variant: str):
|
12 |
-
|
13 |
|
14 |
# Flatten the explanations values
|
15 |
if variant == "shap":
|
@@ -21,39 +21,43 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
|
|
21 |
|
22 |
# Separate the threshold calculation for negative and positive values
|
23 |
if variant == "visualizer":
|
24 |
-
|
|
|
|
|
25 |
else:
|
26 |
-
neg_thresholds = np.linspace(
|
27 |
-
1
|
28 |
-
]
|
29 |
-
|
30 |
-
|
31 |
|
32 |
marked_text = []
|
33 |
|
34 |
# Function to determine the bucket for a given value
|
35 |
for text, value in zip(input_text, text_values):
|
36 |
-
bucket =
|
37 |
-
for i, threshold in
|
38 |
-
if value
|
39 |
bucket = i
|
40 |
marked_text.append((text, str(bucket)))
|
41 |
|
|
|
|
|
42 |
return marked_text
|
43 |
|
44 |
|
45 |
def color_codes():
|
46 |
return {
|
47 |
-
# 1-5: Strong Light
|
48 |
-
"
|
49 |
-
"
|
50 |
-
"3": "#
|
51 |
-
"
|
52 |
-
"
|
53 |
-
|
54 |
-
"
|
55 |
-
"
|
56 |
-
"
|
57 |
-
"
|
58 |
-
"
|
59 |
}
|
|
|
9 |
|
10 |
|
11 |
def markup_text(input_text: list, text_values: ndarray, variant: str):
|
12 |
+
bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
|
13 |
|
14 |
# Flatten the explanations values
|
15 |
if variant == "shap":
|
|
|
21 |
|
22 |
# Separate the threshold calculation for negative and positive values
|
23 |
if variant == "visualizer":
|
24 |
+
neg_thresholds = np.linspace(
|
25 |
+
0, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
|
26 |
+
)[1:]
|
27 |
else:
|
28 |
+
neg_thresholds = np.linspace(
|
29 |
+
min_val, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
|
30 |
+
)[1:]
|
31 |
+
pos_thresholds = np.linspace(0, max_val, num=(len(bucket_tags) - 1) // 2 + 1)[1:]
|
32 |
+
thresholds = np.concatenate([neg_thresholds, [0], pos_thresholds])
|
33 |
|
34 |
marked_text = []
|
35 |
|
36 |
# Function to determine the bucket for a given value
|
37 |
for text, value in zip(input_text, text_values):
|
38 |
+
bucket = "-5"
|
39 |
+
for i, threshold in zip(bucket_tags, thresholds):
|
40 |
+
if value >= threshold:
|
41 |
bucket = i
|
42 |
marked_text.append((text, str(bucket)))
|
43 |
|
44 |
+
print(thresholds)
|
45 |
+
print(marked_text)
|
46 |
return marked_text
|
47 |
|
48 |
|
49 |
def color_codes():
|
50 |
return {
|
51 |
+
# 1-5: Strong Light Sky Blue to Lighter Sky Blue
|
52 |
+
"-5": "#3251a8", # Strong Light Sky Blue
|
53 |
+
"-4": "#5A7FB2", # Slightly Lighter Sky Blue
|
54 |
+
"-3": "#8198BC", # Intermediate Sky Blue
|
55 |
+
"-2": "#A8B1C6", # Light Sky Blue
|
56 |
+
"-1": "#E6F0FF", # Very Light Sky Blue
|
57 |
+
"0": "#FFFFFF", # White
|
58 |
+
"+1": "#FFE6F0", # Lighter Pink
|
59 |
+
"+2": "#DF8CA3", # Slightly Stronger Pink
|
60 |
+
"+3": "#D7708E", # Intermediate Pink
|
61 |
+
"+4": "#CF5480", # Deep Pink
|
62 |
+
"+5": "#A83273", # Strong Magenta
|
63 |
}
|
main.py
CHANGED
@@ -75,7 +75,7 @@ with gr.Blocks(
|
|
75 |
|
76 |
""")
|
77 |
# row with columns for the different settings
|
78 |
-
with gr.Row(equal_height=True
|
79 |
# column that takes up 3/5 of the row
|
80 |
with gr.Column(scale=3):
|
81 |
# textbox to enter the system prompt
|
@@ -108,13 +108,15 @@ with gr.Blocks(
|
|
108 |
# accordion to display the normalized input explanation
|
109 |
with gr.Accordion(label="Input Explanation", open=False):
|
110 |
gr.Markdown("""
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
XAI method you selected.
|
115 |
""")
|
116 |
xai_text = gr.HighlightedText(
|
117 |
-
color_map=coloring,
|
|
|
|
|
|
|
118 |
)
|
119 |
# out of the box chatbot component
|
120 |
# see documentation: https://www.gradio.app/docs/chatbot
|
|
|
75 |
|
76 |
""")
|
77 |
# row with columns for the different settings
|
78 |
+
with gr.Row(equal_height=True):
|
79 |
# column that takes up 3/5 of the row
|
80 |
with gr.Column(scale=3):
|
81 |
# textbox to enter the system prompt
|
|
|
108 |
# accordion to display the normalized input explanation
|
109 |
with gr.Accordion(label="Input Explanation", open=False):
|
110 |
gr.Markdown("""
|
111 |
+
The explanations are based on 10 buckets that range between the
|
112 |
+
lowest negative value (1 to 5) and the highest positive attribution value (6 to 10).
|
113 |
+
**The legend show the color for each bucket.**
|
|
|
114 |
""")
|
115 |
xai_text = gr.HighlightedText(
|
116 |
+
color_map=coloring,
|
117 |
+
label="Input Explanation",
|
118 |
+
show_legend=True,
|
119 |
+
show_label=False,
|
120 |
)
|
121 |
# out of the box chatbot component
|
122 |
# see documentation: https://www.gradio.app/docs/chatbot
|