Spaces:
Sleeping
Sleeping
maximuspowers
committed on
Commit
•
d77f6b0
1
Parent(s):
d00edf9
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import plotly.graph_objects as go
|
3 |
+
import json
|
4 |
+
import gradio as gr
|
5 |
+
from nltk.corpus import words
|
6 |
+
import nltk
|
7 |
+
|
8 |
+
|
9 |
+
# load precomputed artifacts produced offline:
#   - vocab_embeddings.npy: one embedding vector per vocabulary token
#     (loaded but not referenced elsewhere in this file)
#   - vocab_attention_scores.json: per-token dicts of attention scores keyed
#     by BIO label ('B-GEN', 'I-GEN', ..., 'O')
#   - vocab_tokens.json: the token strings, index-aligned with the above
vocab_embeddings = np.load('vocab_embeddings.npy')
with open('vocab_attention_scores.json', 'r') as f:
    vocab_attention_scores = json.load(f)
with open('vocab_tokens.json', 'r') as f:
    vocab_tokens = json.load(f)
|
15 |
+
|
16 |
+
# attention scores to numpy arrays, one array per BIO label
def _label_scores(label):
    # pull a single label's score out of every vocab entry, preserving order
    return np.array([entry[label] for entry in vocab_attention_scores])

b_gen_attention = _label_scores('B-GEN')
i_gen_attention = _label_scores('I-GEN')
b_unfair_attention = _label_scores('B-UNFAIR')
i_unfair_attention = _label_scores('I-UNFAIR')
b_stereo_attention = _label_scores('B-STEREO')
i_stereo_attention = _label_scores('I-STEREO')
o_attention = _label_scores('O')  # use actual O scores
|
24 |
+
|
25 |
+
# drop tokens that are neither dictionary English words nor '##' subwords
nltk.download('words')
english_words = set(words.words())

filtered_indices = []
for idx, tok in enumerate(vocab_tokens):
    # keep real words plus WordPiece continuation pieces ('##...')
    if tok in english_words or tok.startswith("##"):
        filtered_indices.append(idx)
filtered_tokens = [vocab_tokens[idx] for idx in filtered_indices]
|
31 |
+
|
32 |
+
# restrict every per-label score array to the kept vocabulary rows;
# one shared index array instead of repeated list indexing
_keep = np.asarray(filtered_indices, dtype=int)
b_gen_attention_filtered = b_gen_attention[_keep]
i_gen_attention_filtered = i_gen_attention[_keep]
b_unfair_attention_filtered = b_unfair_attention[_keep]
i_unfair_attention_filtered = i_unfair_attention[_keep]
b_stereo_attention_filtered = b_stereo_attention[_keep]
i_stereo_attention_filtered = i_stereo_attention[_keep]
o_attention_filtered = o_attention[_keep]
|
39 |
+
|
40 |
+
# keep the 500 highest-scoring 'O' tokens around for visual comparison
# (indices are into the *filtered* arrays/tokens, ascending by score)
top_500_o_indices = np.argsort(o_attention_filtered)[-500:]
top_500_o_tokens = [filtered_tokens[idx] for idx in top_500_o_indices]
o_attention_filtered_top_500 = np.take(o_attention_filtered, top_500_o_indices)
|
44 |
+
|
45 |
+
# tool tip for tokens
def create_hover_text(tokens, b_gen, i_gen, b_unfair, i_unfair, b_stereo, i_stereo, o_val):
    """Build one HTML hover string per token showing all seven label scores.

    All score sequences are index-aligned with `tokens`; scores are rendered
    with three decimal places. Returns a list of strings, one per token.
    """
    return [
        (
            f"Token: {tok}<br>"
            f"B-GEN: {bg:.3f}, I-GEN: {ig:.3f}<br>"
            f"B-UNFAIR: {bu:.3f}, I-UNFAIR: {iu:.3f}<br>"
            f"B-STEREO: {bs:.3f}, I-STEREO: {ist:.3f}<br>"
            f"O: {o:.3f}"
        )
        for tok, bg, ig, bu, iu, bs, ist, o in zip(
            tokens, b_gen, i_gen, b_unfair, i_unfair, b_stereo, i_stereo, o_val
        )
    ]
|
57 |
+
|
58 |
+
# plotting helper: top-100 tokens for each entity dimension
def select_top_100(*data_arrays):
    """Trim each non-None score array to the union of per-array top-100 indices.

    Returns the trimmed arrays (None entries stay None) followed by the
    matching token strings taken from the module-level `filtered_tokens`.
    """
    per_array_top = [np.argsort(arr)[-100:] for arr in data_arrays if arr is not None]

    # union of all the top-100 index sets, sorted and de-duplicated
    keep = np.unique(np.concatenate(per_array_top))

    trimmed = [arr[keep] if arr is not None else None for arr in data_arrays]
    kept_tokens = [filtered_tokens[i] for i in keep]

    return (*trimmed, kept_tokens)
|
73 |
+
|
74 |
+
# plots for 1, 2, and 3 dimensions
def create_plot(selected_dimensions):
    """Create a 1D, 2D, or 3D scatter plot of per-token attention.

    selected_dimensions: ordered list drawn from {'Generalization',
    'Unfairness', 'Stereotype'}; its length picks the dimensionality and its
    order maps entity classes onto the x/y/z axes.

    Returns a plotly Figure with one colored trace for the top classified
    (GUS) tokens and one grey trace for the top-500 'O' tokens. An empty
    selection yields an empty Figure.
    """
    # combined B-/I- attention per entity class; all three arrays are
    # index-aligned with filtered_tokens
    attention_map = {
        'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
        'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
        'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
    }

    # BUGFIX: previously an empty selection raised IndexError at
    # selected_dimensions[0]; return an empty figure instead
    if not selected_dimensions:
        return go.Figure()

    # map the selected dimensions onto the axes in order
    x_data, y_data, z_data = None, None, None
    x_data = attention_map[selected_dimensions[0]]
    if len(selected_dimensions) > 1:
        y_data = attention_map[selected_dimensions[1]]
    if len(selected_dimensions) > 2:
        z_data = attention_map[selected_dimensions[2]]

    # union of per-axis top-100 indices, kept explicitly so hover text below
    # can be aligned with the plotted points. BUGFIX: the old code passed the
    # *unsliced* filtered arrays to create_hover_text, so tooltips showed
    # scores belonging to the wrong tokens.
    per_axis_top = [np.argsort(d)[-100:] for d in (x_data, y_data, z_data) if d is not None]
    combined_indices = np.unique(np.concatenate(per_axis_top))
    tokens_plotted = [filtered_tokens[i] for i in combined_indices]
    x_data = x_data[combined_indices]
    if y_data is not None:
        y_data = y_data[combined_indices]
    if z_data is not None:
        z_data = z_data[combined_indices]

    # project the top-500 'O' tokens onto the same axes (zeros for unused axes)
    o_x = attention_map[selected_dimensions[0]][top_500_o_indices]
    if len(selected_dimensions) > 1:
        o_y = attention_map[selected_dimensions[1]][top_500_o_indices]
    else:
        o_y = np.zeros_like(o_x)
    if len(selected_dimensions) > 2:
        o_z = attention_map[selected_dimensions[2]][top_500_o_indices]
    else:
        o_z = np.zeros_like(o_x)

    # hover text for classified (GUS) tokens, aligned via combined_indices
    classified_hover_text = create_hover_text(
        tokens_plotted,
        b_gen_attention_filtered[combined_indices], i_gen_attention_filtered[combined_indices],
        b_unfair_attention_filtered[combined_indices], i_unfair_attention_filtered[combined_indices],
        b_stereo_attention_filtered[combined_indices], i_stereo_attention_filtered[combined_indices],
        o_attention_filtered[combined_indices]
    )

    # hover text for O tokens
    o_hover_text = create_hover_text(
        top_500_o_tokens,
        b_gen_attention_filtered[top_500_o_indices], i_gen_attention_filtered[top_500_o_indices],
        b_unfair_attention_filtered[top_500_o_indices], i_unfair_attention_filtered[top_500_o_indices],
        b_stereo_attention_filtered[top_500_o_indices], i_stereo_attention_filtered[top_500_o_indices],
        o_attention_filtered_top_500
    )

    fig = go.Figure()

    if z_data is not None:
        # 3D scatter plot
        fig.add_trace(go.Scatter3d(
            x=x_data,
            y=y_data,
            z=z_data,
            mode='markers',
            marker=dict(
                size=6,
                color=x_data,  # color based on the x-axis data
                colorscale='Viridis',
                opacity=0.85,
            ),
            text=classified_hover_text,
            hoverinfo='text',
            name='Classified Tokens'
        ))
        # add top 500 O tags to the plot too
        fig.add_trace(go.Scatter3d(
            x=o_x,
            y=o_y,
            z=o_z,
            mode='markers',
            marker=dict(
                size=6,
                color='grey',
                opacity=0.5,
            ),
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens'
        ))
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            scene=dict(
                xaxis=dict(title=f"{selected_dimensions[0]} Attention"),
                yaxis=dict(title=f"{selected_dimensions[1]} Attention"),
                zaxis=dict(title=f"{selected_dimensions[2]} Attention"),
            ),
            margin=dict(l=0, r=0, b=0, t=40),
        )
    elif y_data is not None:
        # 2D scatter plot
        fig.add_trace(go.Scatter(
            x=x_data,
            y=y_data,
            mode='markers',
            marker=dict(
                size=6,
                color=x_data,  # color based on the x-axis data
                colorscale='Viridis',
                opacity=0.85,
            ),
            text=classified_hover_text,
            hoverinfo='text',
            name='Classified Tokens'
        ))
        # add top 500 O tags to the plot too
        fig.add_trace(go.Scatter(
            x=o_x,
            y=o_y,
            mode='markers',
            marker=dict(
                size=6,
                color='grey',
                opacity=0.5,
            ),
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens'
        ))
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            xaxis_title=f"{selected_dimensions[0]} Attention",
            yaxis_title=f"{selected_dimensions[1]} Attention",
            margin=dict(l=0, r=0, b=0, t=40),
        )
    else:
        # 1D scatter plot (points spread along x at y=0)
        fig.add_trace(go.Scatter(
            x=x_data,
            y=np.zeros_like(x_data),
            mode='markers',
            marker=dict(
                size=6,
                color=x_data,
                colorscale='Viridis',
                opacity=0.85,
            ),
            text=classified_hover_text,
            hoverinfo='text',
            name='GUS Tokens'
        ))
        fig.add_trace(go.Scatter(
            x=o_x,
            y=np.zeros_like(o_x),
            mode='markers',
            marker=dict(
                size=6,
                color='grey',
                opacity=0.5,
            ),
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens'
        ))
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            xaxis_title=f"{selected_dimensions[0]} Attention",
            margin=dict(l=0, r=0, b=0, t=40),
        )

    return fig
|
251 |
+
|
252 |
+
def get_top_tokens_for_entities(selected_dimensions):
    """Collect the 10 highest-attention tokens for each selected entity class.

    Unknown dimension names are silently skipped. Returns a dict mapping each
    recognized dimension to a list of (token, score) pairs in ascending score
    order.
    """
    entity_map = {
        'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
        'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
        'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
    }

    top_tokens_info = {}
    for dim in selected_dimensions:
        scores = entity_map.get(dim)
        if scores is None:
            continue
        best = np.argsort(scores)[-10:]  # indices of the 10 largest scores
        best_tokens = [filtered_tokens[i] for i in best]
        top_tokens_info[dim] = list(zip(best_tokens, scores[best]))

    return top_tokens_info
|
269 |
+
|
270 |
+
def update_gradio(selected_dimensions):
    """Gradio callback: rebuild the attention plot and the top-token report.

    Returns a (plotly Figure, formatted multi-line string) pair matching the
    (plot_output, top_tokens_output) components.
    """
    fig = create_plot(selected_dimensions)

    report_lines = []
    for entity, tokens_info in get_top_tokens_for_entities(selected_dimensions).items():
        report_lines.append(f"\nTop tokens for {entity}:")
        report_lines.extend(
            f"Token: {token}, Attention Score: {score:.3f}"
            for token, score in tokens_info
        )

    # every line (including the section headers) ends with a newline
    return fig, "".join(line + "\n" for line in report_lines)
|
282 |
+
|
283 |
+
|
284 |
+
def render_gradio_interface():
    """Assemble the Gradio Blocks UI: a dimension selector, the plot, and a
    text box of top tokens. Returns the (unlaunched) Blocks app."""
    with gr.Blocks() as interface:
        with gr.Column():
            dim_selector = gr.CheckboxGroup(
                choices=["Generalization", "Unfairness", "Stereotype"],
                label="Select Dimensions to Plot",
                value=["Generalization", "Unfairness", "Stereotype"],  # defaults to 3D
            )

            plot_view = gr.Plot(label="Token Attention Visualization")
            tokens_box = gr.Textbox(label="Top Tokens for Each Entity Class", lines=10)

            # re-render whenever the selection changes
            dim_selector.change(
                fn=update_gradio,
                inputs=[dim_selector],
                outputs=[plot_view, tokens_box],
            )

        # populate the outputs once on page load with the full 3D view
        interface.load(
            fn=lambda: update_gradio(["Generalization", "Unfairness", "Stereotype"]),
            inputs=None,
            outputs=[plot_view, tokens_box],
        )

    return interface
|
309 |
+
|
310 |
+
# build the app at import time and start the web server; launch() blocks.
# NOTE(review): this looks like a Hugging Face Spaces app.py, which is run
# as the entry-point script — confirm before adding a __main__ guard.
interface = render_gradio_interface()
interface.launch()
|