added attention visualization and qa model
Browse files- app.py +31 -8
- attention_viz.py +227 -0
- custom_bart/bart_attention.py +1 -1
app.py
CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
|
|
2 |
import matplotlib.pyplot as plt
|
3 |
|
4 |
from inference import RelationsInference
|
5 |
-
from
|
|
|
6 |
|
7 |
#prep
|
8 |
import nltk
|
@@ -16,28 +17,50 @@ examples = [["What's the meaning of life?", "eli5", "constraint"],
|
|
16 |
["boat, water, bird", "commongen", "constraint"],
|
17 |
["What flows under a bridge?", "commonsense_qa", "constraint"]]
|
18 |
|
19 |
-
|
20 |
model_path='MrVicente/commonsense_bart_commongen',
|
21 |
kg_type=KGType.CONCEPTNET,
|
22 |
model_type=Model_Type.RELATIONS,
|
23 |
max_length=32
|
24 |
)
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
#############################
|
27 |
# Helper
|
28 |
#############################
|
29 |
|
30 |
def infer_bart(context, task_type, decoding_type_str):
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
return response[0]
|
33 |
|
34 |
|
35 |
-
def plot_attention(layer, head):
|
36 |
fig = plt.figure()
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
return fig
|
42 |
|
43 |
|
|
|
2 |
import matplotlib.pyplot as plt
|
3 |
|
4 |
from inference import RelationsInference
|
5 |
+
from attention_viz import AttentionVisualizer
|
6 |
+
from utils import KGType, Model_Type, Data_Type
|
7 |
|
8 |
#prep
|
9 |
import nltk
|
|
|
17 |
["boat, water, bird", "commongen", "constraint"],
|
18 |
["What flows under a bridge?", "commonsense_qa", "constraint"]]
|
19 |
|
20 |
+
# Checkpoint served for the "commongen" task.
commongen_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32,
)

# Checkpoint served for the "eli5" task; note the larger max_length.
qa_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_absqa',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=128,
)

# Single shared plotting helper, created once at import time.
att_viz = AttentionVisualizer(device='cpu')
|
34 |
#############################
|
35 |
# Helper
|
36 |
#############################
|
37 |
|
38 |
def infer_bart(context, task_type, decoding_type_str):
    """Generate text for ``context`` with the model that serves ``task_type``.

    Args:
        context: Input text typed by the user.
        task_type: Value convertible to ``Data_Type``; only COMMONGEN and
            ELI5 are handled here.
        decoding_type_str: ``'default'`` selects plain generation for
            COMMONGEN; any other value selects the constrained generation
            path. Ignored for ELI5.

    Returns:
        The first element of the generated response.

    Raises:
        NotImplementedError: If the task is neither COMMONGEN nor ELI5.
    """
    task = Data_Type(task_type)
    if task == Data_Type.ELI5:
        outputs = qa_bart.generate_based_on_context(context, use_kg=False)
    elif task != Data_Type.COMMONGEN:
        raise NotImplementedError()
    elif decoding_type_str == 'default':
        outputs = commongen_bart.generate_based_on_context(context, use_kg=False)
    else:
        # The constrained path is handed a one-element batch.
        outputs = commongen_bart.generate_contrained_based_on_context([context], use_kg=True)
    response, _, _ = outputs
    return response[0]
|
49 |
|
50 |
|
51 |
+
def plot_attention(context, task_type, layer, head):
    """Draw the attention-line plot for one layer/head and return its figure.

    Args:
        context: Input text to visualise.
        task_type: Value convertible to ``Data_Type``; only COMMONGEN and
            ELI5 have a model attached.
        layer: Index of the layer to plot. Coerced to ``int`` because UI
            widgets may deliver numeric values as floats.
        head: Index of the attention head to plot (coerced like ``layer``).

    Returns:
        The matplotlib figure the visualizer drew on.

    Raises:
        NotImplementedError: If no model is available for ``task_type``.
    """
    task = Data_Type(task_type)
    if task == Data_Type.COMMONGEN:
        model = commongen_bart
    elif task == Data_Type.ELI5:
        model = qa_bart
    else:
        raise NotImplementedError(
            f'Attention visualization is not available for task {task_type!r}'
        )

    _, examples, relations = model.prepare_context_for_visualization(context)
    att_viz.plot_attn_lines_concepts_ids('Input text importance visualized',
                                         examples,
                                         int(layer), int(head),
                                         relations)
    # The visualizer opens its own figure through pyplot, so a figure created
    # here beforehand would stay empty. Return the figure that is current
    # after drawing instead.
    # NOTE(review): relies on a non-interactive backend where plt.show() keeps
    # the figure alive - confirm for the deployment environment.
    return plt.gcf()
|
65 |
|
66 |
|
attention_viz.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#############################
|
2 |
+
# Imports
|
3 |
+
#############################
|
4 |
+
|
5 |
+
# Python modules
|
6 |
+
|
7 |
+
# Remote modules
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
# Local modules
|
13 |
+
|
14 |
+
#############################
|
15 |
+
# Constants
|
16 |
+
#############################
|
17 |
+
|
18 |
+
class AttentionVisualizer:
    """Matplotlib plots for inspecting attention scores.

    All methods draw through the global ``pyplot`` state. Heat-map methods
    show or save their figure; the attention-line methods additionally
    return the figure they created so callers can embed it.
    """

    def __init__(self, device):
        # Kept for callers; none of the plotting methods in this class read it.
        self.device = device

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_rows(weights):
        """Return a float copy of ``weights`` with every row scaled to sum to 1.

        Rows whose sum is zero are returned as all zeros instead of NaN, so
        the values stay usable as matplotlib alpha values.
        """
        weights = np.array(weights, dtype=float)
        totals = weights.sum(axis=-1, keepdims=True)
        return np.divide(weights, totals,
                         out=np.zeros_like(weights), where=totals != 0)

    @staticmethod
    def _hide_separators(attn):
        """Zero the first and last columns of ``attn`` and renormalise rows."""
        weights = np.array(attn, dtype=float)
        weights[:, 0] = 0
        weights[:, -1] = 0
        return AttentionVisualizer._normalize_rows(weights)

    @staticmethod
    def _first_example_offset(examples, idx, word_height):
        """Vertical shift that centres the first example against the second.

        Only the first example is shifted, and only when a second example
        exists to compare against; otherwise the offset is 0.
        """
        if idx == 0 and len(examples) > 1:
            length_gap = len(examples[0]["words"]) - len(examples[1]["words"])
            return length_gap * word_height / 2
        return 0

    # ------------------------------------------------------------------
    # Heat maps
    # ------------------------------------------------------------------

    def visualize_token2token_scores(self, all_tokens,
                                     scores_mat,
                                     useful_indeces,
                                     x_label_name='Head',
                                     apply_normalization=True):
        """Show one token-by-token heat map per entry of ``scores_mat``.

        Args:
            all_tokens: Token labels; filtered by ``useful_indeces``.
            scores_mat: Iterable of 2-D score matrices. The subplot grid is
                fixed at 4x4, so at most 16 entries can be drawn. Entries
                must be NumPy arrays when ``apply_normalization`` is True,
                because they are passed to ``torch.from_numpy``.
            useful_indeces: Indices of the rows/columns (and tokens) to keep.
            x_label_name: Prefix of each subplot's x label.
            apply_normalization: When True, each entry is replaced by its
                norm over a trailing length-1 axis (i.e. its magnitude).
        """
        fig = plt.figure(figsize=(20, 20))
        fontdict = {'fontsize': 10}

        all_tokens = np.array(all_tokens)[useful_indeces]
        for idx, scores in enumerate(scores_mat):
            if apply_normalization:
                scores = torch.from_numpy(scores)
                shape = scores.shape
                scores = scores.reshape((shape[0], shape[1], 1))
                scores = torch.linalg.norm(scores, dim=2)
            scores_np = np.array(scores)
            scores_np = scores_np[useful_indeces, :]
            scores_np = scores_np[:, useful_indeces]

            ax = fig.add_subplot(4, 4, idx + 1)
            im = ax.imshow(scores_np, cmap='viridis')

            ax.set_xticks(range(len(all_tokens)))
            ax.set_yticks(range(len(all_tokens)))
            ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
            ax.set_yticklabels(all_tokens, fontdict=fontdict)
            ax.set_xlabel('{} {}'.format(x_label_name, idx + 1))

            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()

    def visualize_matrix(self,
                         scores_mat,
                         label_name='heads_layers'):
        """Save a layers-by-heads heat map to ``figs/<label_name>.png``.

        Args:
            scores_mat: 2-D sequence; rows are labelled ``layer-i`` and
                columns ``head-i`` (both 1-based).
            label_name: Used as the x label and as the output file stem.
        """
        import os

        scores_np = np.array(scores_mat)
        # Only this figure is created; an extra unused figure would stay
        # registered with pyplot and leak.
        fig, ax = plt.subplots()
        im = ax.imshow(scores_np, cmap='viridis')

        fontdict = {'fontsize': 10}

        n_heads = len(scores_mat[0])
        n_layers = len(scores_mat)
        ax.set_xticks(range(n_heads))
        ax.set_yticks(range(n_layers))

        x_labels = [f'head-{i}' for i in range(1, n_heads + 1)]
        y_labels = [f'layer-{i}' for i in range(1, n_layers + 1)]

        ax.set_xticklabels(x_labels, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(y_labels, fontdict=fontdict)
        ax.set_xlabel('{}'.format(label_name))

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.tight_layout()
        # savefig does not create missing directories.
        os.makedirs('figs', exist_ok=True)
        fig.savefig(f'figs/{label_name}.png', dpi=fig.dpi)

    def visualize_token2head_scores(self, all_tokens, scores_mat):
        """Show one heat map per entry of ``scores_mat`` with tokens on x.

        Args:
            all_tokens: Token labels for the x axis.
            scores_mat: Iterable of 2-D score matrices, one per subplot. The
                grid is fixed at 6x3, so at most 18 entries can be drawn.
        """
        fig = plt.figure(figsize=(30, 50))
        fontdict = {'fontsize': 20}

        for idx, scores in enumerate(scores_mat):
            scores_np = np.array(scores)
            ax = fig.add_subplot(6, 3, idx + 1)
            im = ax.matshow(scores_np, cmap='viridis')

            n_rows = len(scores)
            ax.set_xticks(range(len(all_tokens)))
            ax.set_yticks(range(n_rows))

            ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
            # One label per y tick: the tick count is the number of rows, so
            # the labels must be derived from the rows as well.
            ax.set_yticklabels(range(n_rows), fontdict=fontdict)
            ax.set_xlabel('Layer {}'.format(idx + 1))

            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # Attention-line plots
    # ------------------------------------------------------------------

    def plot_attn_lines(self, data, heads):
        """Plots attention maps for the given example and attention heads.

        Args:
            data: Mapping with ``"attns"`` (indexed ``[layer][head]``) and
                ``"tokens"`` (list of token strings).
            heads: Iterable of ``(layer, head)`` pairs, drawn side by side.
        """
        width = 3
        example_sep = 3
        word_height = 1
        pad = 0.1
        yoffset = 1

        # Work on a copy so the caller's token list is not modified when the
        # first token is replaced by a placeholder.
        words = list(data["tokens"])
        if words:
            words[0] = "..."
        n_words = len(words)

        for ei, (layer, head) in enumerate(heads):
            xoffset = ei * width * example_sep
            attn = self._normalize_rows(data["attns"][layer][head])

            for position, word in enumerate(words):
                y = yoffset - position * word_height
                plt.text(xoffset + 0, y, word, ha="right", va="center")
                plt.text(xoffset + width, y, word, ha="left", va="center")

            # Index 0 (the placeholder token) gets no lines.
            for src in range(1, n_words):
                for dst in range(1, n_words):
                    plt.plot([xoffset + pad, xoffset + width - pad],
                             [yoffset - word_height * src,
                              yoffset - word_height * dst],
                             color="blue", linewidth=1,
                             alpha=float(attn[src, dst]))

    def plot_attn_lines_concepts(self, title, examples, layer, head, color_words,
                                 color_from=True, width=3, example_sep=3,
                                 word_height=1, pad=0.1, hide_sep=False):
        """Plot attention lines, highlighting the words in ``color_words``.

        Args:
            title: Figure title.
            examples: Sequence of dicts with ``'words'`` (tokens) and
                ``'attentions'`` (indexed ``[layer][head]``).
            layer: Layer index into each example's attentions.
            head: Head index into each example's attentions.
            color_words: Words to highlight in red.
            color_from: If True, highlight by the source (left) word,
                otherwise by the target (right) word.
            width: Horizontal distance between the two word columns.
            example_sep: Spacing factor between consecutive examples.
            word_height: Vertical distance between consecutive words.
            pad: Gap between a word and the start/end of its lines.
            hide_sep: If True, zero the first and last attention columns
                and renormalise each row.

        Returns:
            The matplotlib figure that was drawn.
        """
        fig = plt.figure(figsize=(4, 4))
        for idx, example in enumerate(examples):
            yoffset = self._first_example_offset(examples, idx, word_height)
            xoffset = idx * width * example_sep

            attn = example["attentions"][layer][head]
            if hide_sep:
                attn = self._hide_separators(attn)

            words = example["words"]
            n_words = len(words)
            for position, word in enumerate(words):
                for x, from_word in ((xoffset, True), (xoffset + width, False)):
                    color = "k"
                    if from_word == color_from and word in color_words:
                        color = "#cc0000"
                    plt.text(x, yoffset - (position * word_height), word,
                             ha="right" if from_word else "left", va="center",
                             color=color)

            for src in range(n_words):
                for dst in range(n_words):
                    anchor = words[src if color_from else dst]
                    color = "r" if anchor in color_words else "b"
                    plt.plot([xoffset + pad, xoffset + width - pad],
                             [yoffset - word_height * src,
                              yoffset - word_height * dst],
                             color=color, linewidth=1,
                             # float() also accepts 0-d torch tensors.
                             alpha=float(attn[src, dst]))
        plt.axis("off")
        plt.title(title)
        plt.show()
        return fig

    def plot_attn_lines_concepts_ids(self, title, examples, layer, head,
                                     relations_total, width=3, example_sep=3,
                                     word_height=1, pad=0.1, hide_sep=False):
        """Plot attention lines coloured by a per-example relation matrix.

        Args:
            title: Figure title.
            examples: Sequence of dicts with ``'words'`` (tokens) and
                ``'attentions'`` (indexed ``[layer][head]``).
            layer: Layer index into each example's attentions.
            head: Head index into each example's attentions.
            relations_total: One relation matrix per example, indexed
                ``[i, j]``; entries must support ``.item()``. A positive
                entry marks a relation between word ``i`` and word ``j``.
            width: Horizontal distance between the two word columns.
            example_sep: Spacing factor between consecutive examples.
            word_height: Vertical distance between consecutive words.
            pad: Gap between a word and the start/end of its lines.
            hide_sep: If True, zero the first and last attention columns
                and renormalise each row.

        Returns:
            The matplotlib figure that was drawn.
        """
        plt.clf()
        fig = plt.figure(figsize=(10, 5))
        for idx, example in enumerate(examples):
            yoffset = self._first_example_offset(examples, idx, word_height)
            xoffset = idx * width * example_sep

            attn = example["attentions"][layer][head]
            if hide_sep:
                attn = self._hide_separators(attn)

            words = example["words"]
            n_words = len(words)
            example_rel = relations_total[idx]

            for position, word in enumerate(words):
                for x, from_word in ((xoffset, True), (xoffset + width, False)):
                    color = "k"
                    if from_word:
                        # Left column: red when the word has an outgoing relation.
                        if any(example_rel[position, k] > 0
                               for k in range(n_words)):
                            color = "r"
                    elif any(example_rel[k, position] > 0
                             for k in range(n_words)):
                        # Right column: green when the word has an incoming relation.
                        color = "g"
                    plt.text(x, yoffset - (position * word_height), word,
                             ha="right" if from_word else "left", va="center",
                             color=color)

            for src in range(n_words):
                for dst in range(n_words):
                    color = "k"
                    related = example_rel[src, dst].item() > 0
                    if related and src <= dst:
                        color = "r"
                    # Checked second on purpose: on the diagonal green wins.
                    if related and src >= dst:
                        color = "g"
                    plt.plot([xoffset + pad, xoffset + width - pad],
                             [yoffset - word_height * src,
                              yoffset - word_height * dst],
                             color=color, linewidth=1,
                             # float() also accepts 0-d torch tensors.
                             alpha=float(attn[src, dst]))
        plt.axis("off")
        plt.title(title)
        plt.show()
        return fig
|
custom_bart/bart_attention.py
CHANGED
@@ -94,7 +94,7 @@ class BartCustomAttention(nn.Module):
|
|
94 |
# TODO
|
95 |
print('oh no')
|
96 |
relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
|
97 |
-
print(relation_inputs.shape, ' | ', (bsz, tgt_len, tgt_len))
|
98 |
assert relation_inputs.shape == (bsz, tgt_len, tgt_len)
|
99 |
|
100 |
# (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds)
|
|
|
94 |
# TODO
|
95 |
print('oh no')
|
96 |
relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
|
97 |
+
#print(relation_inputs.shape, ' | ', (bsz, tgt_len, tgt_len))
|
98 |
assert relation_inputs.shape == (bsz, tgt_len, tgt_len)
|
99 |
|
100 |
# (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds)
|