LennardZuendorf committed
Commit 67a34bd
1 Parent(s): 6ff516d

feat/fix: fixing attention bug, fixing other mistral bugs

explanation/attention.py CHANGED

@@ -2,7 +2,8 @@
 
 
 # internal imports
-from utils import formatting as fmt
+from utils import formatting as fmt, modelling as mdl
+from model import mistral
 from .markup import markup_text
 
 
@@ -10,36 +11,52 @@ from .markup import markup_text
 # and marked text based on attention
 def chat_explained(model, prompt):
 
-    model.set_config({"return_dict": True})
-
     # get encoded input
-    encoder_input_ids = model.TOKENIZER(
+    input_ids = model.TOKENIZER(
         prompt, return_tensors="pt", add_special_tokens=True
     ).input_ids
-    # generate output together with attentions of the model
-    decoder_input_ids = model.MODEL.generate(
-        encoder_input_ids, output_attentions=True, generation_config=model.CONFIG
-    )
+
+    # generate output of the model
+    decoder_ids = model.MODEL.generate(input_ids, generation_config=model.CONFIG)
 
     # get input and output text as list of strings
-    encoder_text = fmt.format_tokens(
-        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
-    )
-    decoder_text = fmt.format_tokens(
-        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
+    input_text = fmt.format_tokens(model.TOKENIZER.convert_ids_to_tokens(input_ids[0]))
+    output_text = fmt.format_tokens(
+        model.TOKENIZER.convert_ids_to_tokens(decoder_ids[0])
     )
 
-    averaged_attention = fmt.avg_attention(decoder_input_ids)
+    # checking if model is mistral
+    if type(model.MODEL) == type(mistral.MODEL):
+
+        # get attention values for the input vectors
+        attention_output = model.MODEL(input_ids, output_attentions=True).attentions
+
+        # averaging attention across layers and heads
+        attention_output = mdl.format_mistral_attention(attention_output)
+        averaged_attention = fmt.avg_attention(attention_output, model="mistral")
+
+    # attention visualization for godel
+    else:
+        # get attention values for the input and output vectors
+        # using already generated input and output
+        attention_output = model.MODEL(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_ids,
+            output_attentions=True,
+        )
+
+        # averaging attention across layers
+        averaged_attention = fmt.avg_attention(attention_output, model="godel")
 
     # format response text for clean output
-    response_text = fmt.format_output_text(decoder_text)
+    response_text = fmt.format_output_text(output_text)
     # setting placeholder for iFrame graphic
     graphic = (
         "<div style='text-align: center; font-family:arial;'><h4>Attention"
         " Visualization doesn't support an interactive graphic.</h4></div>"
     )
     # creating marked text using markup_text function and attention
-    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
+    marked_text = markup_text(input_text, averaged_attention, variant="visualizer")
 
     # returning response, graphic and marked text array
     return response_text, graphic, marked_text, None
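
Taken together, the new Mistral branch computes per-token attention as in this minimal, self-contained sketch against a generic Hugging Face causal LM (the model/tokenizer loading here is illustrative; `model.MODEL`/`model.TOKENIZER` in the diff are the repo's own wrapper):

```python
# Sketch of the Mistral attention path, assuming a standard Hugging Face
# causal LM; names and loading are illustrative, not the repo's API.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def averaged_self_attention(model_name: str, prompt: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # one attention tensor per layer, each (batch, heads, seq_len, seq_len)
    attentions = model(input_ids, output_attentions=True).attentions

    # drop the batch dim and stack -> (layers, heads, seq_len, seq_len),
    # mirroring mdl.format_mistral_attention
    stacked = torch.stack([a.squeeze(0) for a in attentions])

    # average over layers, heads and query positions -> one score per token,
    # mirroring fmt.avg_attention(..., model="mistral")
    return stacked.mean(dim=(0, 1, 2)).detach().cpu().numpy()
```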
explanation/interpret_captum.py CHANGED

@@ -4,6 +4,7 @@ import torch
 
 # internal imports
 from utils import formatting as fmt
+from .plotting import plot_seq
 from .markup import markup_text
 
 
@@ -26,7 +27,7 @@ def cpt_extract_seq_att(attr):
 def chat_explained(model, prompt):
     model.set_config({})
 
-    # creating llm attribution class with KernelSHAP and Mistal Model, Tokenizer
+    # creating llm attribution class with KernelSHAP and Mistral Model, Tokenizer
     llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
 
     # generation attribution
@@ -48,7 +49,11 @@ def chat_explained(model, prompt):
     graphic = """<div style='text-align: center; font-family:arial;'><h4>
     Intepretation with Captum doesn't support an interactive graphic.</h4></div>
     """
+    # create the explanation marked text array
     marked_text = markup_text(input_tokens, values, variant="captum")
 
+    # creating sequence attribution plot
+    plot = plot_seq(cpt_extract_seq_att(attribution_result), "KernelSHAP")
+
     # return response, graphic and marked_text array
-    return response_text, graphic, marked_text, None
+    return response_text, graphic, marked_text, plot
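
For orientation, the Captum flow this file wraps looks roughly as follows (Captum ≥ 0.7; `attribute_prompt` is a hypothetical stand-in, not the repo's function):

```python
# Rough sketch of the KernelSHAP attribution flow (Captum >= 0.7);
# model/tokenizer stand in for the repo's model.MODEL / model.TOKENIZER.
from captum.attr import KernelShap, LLMAttribution, TextTokenInput

def attribute_prompt(model, tokenizer, prompt: str):
    llm_attribution = LLMAttribution(KernelShap(model), tokenizer)

    # wrap the prompt so Captum can mask it token by token
    inp = TextTokenInput(prompt, tokenizer)

    # the result carries per-token scores, the input tokens, and the
    # generated target text used for response_text and markup_text
    attribution_result = llm_attribution.attribute(inp)
    return attribution_result.seq_attr, attribution_result.input_tokens
```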
explanation/interpret_shap.py CHANGED

@@ -6,6 +6,7 @@ import torch
 
 # internal imports
 from utils import formatting as fmt
+from .plotting import plot_seq
 from .markup import markup_text
 
 # global variables
@@ -14,7 +15,7 @@ TEXT_MASKER = None
 
 
 # function to extract summarized sequence wise attribution
-def extract_seq_att(shap_values):
+def shap_extract_seq_att(shap_values):
 
     # extracting summed up shap values
     values = fmt.flatten_attribution(shap_values.values[0], 1)
@@ -78,5 +79,8 @@ def chat_explained(model, prompt):
     # create the response text
     response_text = fmt.format_output_text(shap_values.output_names)
 
+    # creating sequence attribution plot
+    plot = plot_seq(shap_extract_seq_att(shap_values), "PartitionSHAP")
+
     # return response, graphic and marked_text array
-    return response_text, graphic, marked_text, None
+    return response_text, graphic, marked_text, plot
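
As a reminder of what the renamed helper condenses, a hypothetical stand-in (assuming shap's `Explanation` layout for text models, where `values[0]` is shaped `(n_input_tokens, n_output_tokens)`):

```python
# Hypothetical stand-in for shap_extract_seq_att: summing the SHAP matrix
# over the output axis yields one attribution score per input token.
import numpy as np

def sum_per_input_token(shap_values):
    values = np.sum(shap_values.values[0], axis=1)
    tokens = shap_values.data[0]          # input tokens as strings
    return list(zip(tokens, values))      # (token, score) pairs for plot_seq
```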
explanation/markup.py CHANGED

@@ -25,12 +25,12 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     min_val, max_val = np.min(text_values), np.max(text_values)
 
     # separate the threshold calculation for negative and positive values
-    # visualization negative thresholds are all 0 since attetion always positive
+    # visualization negative thresholds are all 0 since attention always positive
     if variant == "visualizer":
         neg_thresholds = np.linspace(
             0, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
         )[1:]
-    # standart config for 5 negative buckets
+    # standard config for 5 negative buckets
     else:
         neg_thresholds = np.linspace(
             min_val, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
@@ -45,16 +45,19 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
 
     # looping over each text snippet and attribution value
     for text, value in zip(input_text, text_values):
-        # setting inital bucket at lowest
-        bucket = "-5"
 
-        # looping over all bucket and their threshold
-        for i, threshold in zip(bucket_tags, thresholds):
-            # updating assigned bucket if value is above threshold
-            if value >= threshold:
-                bucket = i
-        # finally adding text and bucket assignment to list of tuples
-        marked_text.append((text, str(bucket)))
+        # validating text and skipping empty text/special tokens
+        if text not in ("", fmt.SPECIAL_TOKENS):
+            # setting initial bucket at lowest
+            bucket = "-5"
+
+            # looping over all bucket and their threshold
+            for i, threshold in zip(bucket_tags, thresholds):
+                # updating assigned bucket if value is above threshold
+                if value >= threshold:
+                    bucket = i
+            # finally adding text and bucket assignment to list of tuples
+            marked_text.append((text, str(bucket)))
 
     # returning list of marked text snippets as list of tuples
     return marked_text
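
The bucketing that markup_text performs, condensed into a standalone sketch. The bucket tags and the token filter are assumptions; note that the committed check `text not in ("", fmt.SPECIAL_TOKENS)` tests membership against the list object itself, so the element-wise filter below is the presumably intended behavior:

```python
# Standalone sketch of markup_text's bucketing; tags and filter are assumed.
import numpy as np

SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[PAD]", "<s>", "</s>"]  # abbreviated

def bucketize(tokens, values, positive_only=False):
    tags = [str(i) for i in range(-5, 6) if i != 0]  # "-5".."-1", "1".."5"
    lo = 0.0 if positive_only else float(np.min(values))
    neg = np.linspace(lo, 0, num=6, endpoint=False)[1:]     # 5 negative cut-offs
    pos = np.linspace(0, float(np.max(values)), num=6)[1:]  # 5 positive cut-offs
    thresholds = np.concatenate((neg, pos))

    marked = []
    for text, value in zip(tokens, values):
        # element-wise filter: skip empty strings and known special tokens
        if text and text not in SPECIAL_TOKENS:
            bucket = "-5"
            for tag, threshold in zip(tags, thresholds):
                # highest threshold the value clears wins
                if value >= threshold:
                    bucket = tag
            marked.append((text, bucket))
    return marked

print(bucketize(["money", "<s>", "buys"], np.array([0.9, 0.0, 0.1]), True))
```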
explanation/plotting.py CHANGED

@@ -5,7 +5,7 @@ import numpy as np
 import matplotlib.pyplot as plt
 
 
-def plot_seq(seq_values: list, method_model: tuple = ("", "")):
+def plot_seq(seq_values: list, method: str = ""):
 
     # Separate the tokens and their corresponding importance values
     tokens, importance = zip(*seq_values)
@@ -45,7 +45,7 @@ def plot_seq(seq_values: list, method_model: tuple = ("", "")):
     )
 
     plt.axhline(0, color="black", linewidth=1)
-    plt.title(f"Input Token Attribution with {method_model[0]} on {method_model[1]}")
+    plt.title(f"Input Token Attribution with {method}")
     plt.xlabel("Input Tokens", labelpad=0.5)
     plt.ylabel("Attribution")
     plt.xticks(x_positions, tokens, rotation=45)
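
A call against the new signature might look like this (the token/score pairs are invented):

```python
# plot_seq now takes the method name directly instead of a (method, model) tuple
plot = plot_seq([("money", 0.42), ("buy", -0.13), ("happiness", 0.31)], "PartitionSHAP")
```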
main.py CHANGED

@@ -155,7 +155,7 @@ with gr.Blocks(
                 The explanations are based on 10 buckets that range between the
                 lowest negative value (1 to 5) and the highest positive attribution value (6 to 10).
                 **The legend shows the color for each bucket.**
-
+
                 *HINT*: This works best in light mode.
                 """)
             xai_text = gr.HighlightedText(
@@ -210,12 +210,34 @@
         gr.Examples(
             label="Example Questions",
             examples=[
-                ["Does money buy happiness?", "", "Mistral", "SHAP"],
-                ["Does money buy happiness?", "", "Mistral", "Attention"],
+                ["Does money buy happiness?", "", "", "Mistral", "None"],
+                ["Does money buy happiness?", "", "", "Mistral", "SHAP"],
+                ["Does money buy happiness?", "", "", "Mistral", "Attention"],
+                [
+                    "Does money buy happiness?",
+                    "",
+                    (
+                        "Respond from the perspective of a billionaire enjoying"
+                        " life in Dubai"
+                    ),
+                    "Mistral",
+                    "None",
+                ],
+                [
+                    "Does money buy happiness?",
+                    "",
+                    (
+                        "Respond from the perspective of a billionaire enjoying"
+                        " life in Dubai"
+                    ),
+                    "Mistral",
+                    "SHAP",
+                ],
             ],
             inputs=[
                 user_prompt,
                 knowledge_input,
+                system_prompt,
                 model_selection,
                 xai_selection,
             ],
@@ -227,32 +249,21 @@
             label="Example Questions",
             examples=[
                 [
-                    "How does a black hole form in space?",
+                    "Does money buy happiness?",
                     (
                         "Black holes are created when a massive star's core"
                         " collapses after a supernova, forming an object with"
                         " gravity so intense that even light cannot escape."
                     ),
+                    "",
                     "GODEL",
                     "SHAP",
                 ],
-                [
-                    (
-                        "Explain the importance of the Rosetta Stone in"
-                        " understanding ancient languages."
-                    ),
-                    (
-                        "The Rosetta Stone, an ancient Egyptian artifact, was"
-                        " key in decoding hieroglyphs, featuring the same text"
-                        " in three scripts: hieroglyphs, Demotic, and Greek."
-                    ),
-                    "GODEL",
-                    "Attention",
-                ],
             ],
             inputs=[
                 user_prompt,
                 knowledge_input,
+                system_prompt,
                 model_selection,
                 xai_selection,
             ],
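
The extra empty string in every example row lines up with the newly wired `system_prompt` component; a minimal Gradio sketch of the wiring (component definitions assumed):

```python
# Minimal sketch of the updated example wiring; component definitions are
# assumed, only the five-element rows and inputs list mirror the diff.
import gradio as gr

with gr.Blocks() as demo:
    user_prompt = gr.Textbox(label="User Prompt")
    knowledge_input = gr.Textbox(label="Knowledge")
    system_prompt = gr.Textbox(label="System Prompt")
    model_selection = gr.Radio(["Mistral", "GODEL"], label="Model")
    xai_selection = gr.Radio(["None", "SHAP", "Attention"], label="XAI")

    gr.Examples(
        label="Example Questions",
        examples=[["Does money buy happiness?", "", "", "Mistral", "SHAP"]],
        inputs=[
            user_prompt,
            knowledge_input,
            system_prompt,
            model_selection,
            xai_selection,
        ],
    )
```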
pyproject.toml CHANGED

@@ -21,6 +21,7 @@ exclude = '''
 
 [tool.pylint.messages_control]
 disable = [
+    "unidiomatic-typecheck",
     "not-a-mapping",
     "arguments-differ",
     "attribute-defined-outside-init",
utils/formatting.py CHANGED

@@ -2,12 +2,31 @@
 
 # external imports
 import re
+import torch
 import numpy as np
 from numpy import ndarray
 
 
-# function to format the model reponse nicely
-# takes a list of strings and returnings a combined string
+# globally defined tokens that are removed from the output
+SPECIAL_TOKENS = [
+    "[CLS]",
+    "[SEP]",
+    "[PAD]",
+    "[UNK]",
+    "[MASK]",
+    "▁",
+    "Ġ",
+    "</w>",
+    "<0x0A>",
+    "<0x0D>",
+    "<0x09>",
+    "<s>",
+    "</s>",
+]
+
+
+# function to format the model response nicely
+# takes a list of strings and returns a combined string
 def format_output_text(output: list):
 
     # remove special tokens from list using other function
@@ -36,8 +55,6 @@ def format_output_text(output: list):
 
 # format the tokens by removing special tokens and special characters
 def format_tokens(tokens: list):
-    # define special tokens to remove
-    special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]
 
     # initialize empty list
     updated_tokens = []
@@ -49,7 +66,7 @@ def format_tokens(tokens: list):
         t = t.lstrip("▁")
 
         # loop through special tokens list and remove from current token if matched
-        for s in special_tokens:
+        for s in SPECIAL_TOKENS:
             t = t.replace(s, "")
 
         # add token to list
@@ -70,6 +87,12 @@ def flatten_attention(values: ndarray, axis: int = 0):
 
 
 # function to get averaged decoder attention from attention values
-def avg_attention(attention_values):
-    attention = attention_values.decoder_attentions[0][0].detach().numpy()
-    return np.mean(attention, axis=0)
+def avg_attention(attention_values, model: str):
+    # check if model is godel
+    if model == "godel":
+        # get attention values for the input and output vectors
+        attention = attention_values.decoder_attentions[0][0].detach().numpy()
+        return np.mean(attention, axis=0)
+    # extracting attention values for mistral
+    attention_np = attention_values.to(torch.device("cpu")).detach().numpy()
+    return np.mean(attention_np, axis=(0, 1, 2))
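
A quick shape walk-through of the two avg_attention branches, with toy tensors (all dimensions invented):

```python
# Toy shape check for both branches of avg_attention.
import numpy as np
import torch

# mistral branch: stacked self-attention, (layers, heads, seq, seq)
mistral_att = torch.rand(32, 8, 6, 6)
print(np.mean(mistral_att.numpy(), axis=(0, 1, 2)).shape)  # (6,) one score per token

# godel branch: first decoder attention layer, batch 0 -> (heads, tgt, src)
godel_att = torch.rand(12, 5, 6)
print(np.mean(godel_att.numpy(), axis=0).shape)  # (5, 6) averaged over heads
```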
utils/modelling.py CHANGED

@@ -97,3 +97,14 @@ def gpu_loading_config(max_memory: str = "15000MB"):
     )
 
     return n_gpus, max_memory, bnb_config
+
+
+# formatting mistral attention values
+# CREDIT: copied and adapted from BERTViz
+# see https://github.com/jessevig/bertviz
+def format_mistral_attention(attention_values):
+    squeezed = []
+    for layer_attention in attention_values:
+        layer_attention = layer_attention.squeeze(0)
+        squeezed.append(layer_attention)
+    return torch.stack(squeezed)
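
A quick usage check for the new helper (toy tensors, dimensions invented):

```python
# Per-layer tuples of (1, heads, seq, seq) become one stacked tensor.
import torch

attentions = tuple(torch.rand(1, 8, 6, 6) for _ in range(32))
stacked = format_mistral_attention(attentions)
print(stacked.shape)  # torch.Size([32, 8, 6, 6]) -> (layers, heads, seq, seq)
```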