Martijn van Beers commited on
Commit
7dff594
1 Parent(s): dc60986

Add baseline selection

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. lib/integrated_gradients.py +12 -5
app.py CHANGED
@@ -20,9 +20,9 @@ import torch
20
  ig_explainer = IntegratedGradientsExplainer()
21
  gr_explainer = GradientRolloutExplainer()
22
 
23
- def run(sent, rollout, ig):
24
  a = gr_explainer(sent, rollout)
25
- b = ig_explainer(sent, ig)
26
  return a, b
27
 
28
  examples = pandas.read_csv("examples.csv").to_numpy().tolist()
@@ -40,14 +40,15 @@ with gradio.Blocks(title="Explanations with attention rollout") as iface:
40
  rollout_result = gradio.HTML()
41
  with gradio.Column():
42
  ig_layer = gradio.Slider(minimum=0, maximum=12, value=0, step=1, label="Select IG layer")
 
43
  ig_result = gradio.HTML()
44
  gradio.Examples(examples, [sent])
45
  with gradio.Accordion("Some more details"):
46
  util.Markdown(pathlib.Path("notice.md"))
47
 
48
  rollout_layer.change(gr_explainer, [sent, rollout_layer], rollout_result)
49
- ig_layer.change(ig_explainer, [sent, ig_layer], ig_result)
50
- but.click(run, [sent, rollout_layer, ig_layer], [rollout_result, ig_result])
51
 
52
 
53
  iface.launch()
 
20
  ig_explainer = IntegratedGradientsExplainer()
21
  gr_explainer = GradientRolloutExplainer()
22
 
23
+ def run(sent, rollout, ig, ig_baseline):
24
  a = gr_explainer(sent, rollout)
25
+ b = ig_explainer(sent, ig, ig_baseline)
26
  return a, b
27
 
28
  examples = pandas.read_csv("examples.csv").to_numpy().tolist()
 
40
  rollout_result = gradio.HTML()
41
  with gradio.Column():
42
  ig_layer = gradio.Slider(minimum=0, maximum=12, value=0, step=1, label="Select IG layer")
43
+ ig_baseline = gradio.Dropdown(label="Baseline token", choices=['Unknown', 'Padding'], value="Unknown")
44
  ig_result = gradio.HTML()
45
  gradio.Examples(examples, [sent])
46
  with gradio.Accordion("Some more details"):
47
  util.Markdown(pathlib.Path("notice.md"))
48
 
49
  rollout_layer.change(gr_explainer, [sent, rollout_layer], rollout_result)
50
+ ig_layer.change(ig_explainer, [sent, ig_layer, ig_baseline], ig_result)
51
+ but.click(run, [sent, rollout_layer, ig_layer, ig_baseline], [rollout_result, ig_result])
52
 
53
 
54
  iface.launch()
lib/integrated_gradients.py CHANGED
@@ -15,7 +15,10 @@ class IntegratedGradientsExplainer:
15
  self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
16
  self.model = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2").to(self.device)
17
  self.tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
18
- self.ref_token_id = self.tokenizer.unk_token_id
 
 
 
19
 
20
  def tokens_from_ids(self, ids):
21
  return list(map(lambda s: s[1:] if s[0] == "Ġ" else s, self.tokenizer.convert_ids_to_tokens(ids)))
@@ -31,8 +34,12 @@ class IntegratedGradientsExplainer:
31
  attributions = attributions / torch.norm(attributions)
32
  return attributions
33
 
 
 
 
 
 
34
 
35
- def run_attribution_model(self, input_ids, attention_mask, index=None, layer=None, steps=20):
36
  try:
37
  output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
38
  # if index is None:
@@ -43,7 +50,7 @@ class IntegratedGradientsExplainer:
43
  attention_mask = attention_mask
44
  attributions = ablator.attribute(
45
  inputs=input_ids,
46
- baselines=self.ref_token_id,
47
  additional_forward_args=(attention_mask),
48
  target=1,
49
  n_steps=steps,
@@ -76,7 +83,7 @@ class IntegratedGradientsExplainer:
76
  )
77
  return visualize_text(vis_data_records)
78
 
79
- def __call__(self, input_text, layer):
80
  text_batch = [input_text]
81
  encoding = self.tokenizer(text_batch, return_tensors="pt")
82
  input_ids = encoding["input_ids"].to(self.device)
@@ -87,4 +94,4 @@ class IntegratedGradientsExplainer:
87
  else:
88
  layer = getattr(self.model.roberta.encoder.layer, str(layer-1))
89
 
90
- return self.build_visualization(input_ids, attention_mask, layer=layer)
 
15
  self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
16
  self.model = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2").to(self.device)
17
  self.tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
18
+ self.baseline_map = {
19
+ 'Unknown': self.tokenizer.unk_token_id,
20
+ 'Padding': self.tokenizer.pad_token_id,
21
+ }
22
 
23
  def tokens_from_ids(self, ids):
24
  return list(map(lambda s: s[1:] if s[0] == "Ġ" else s, self.tokenizer.convert_ids_to_tokens(ids)))
 
34
  attributions = attributions / torch.norm(attributions)
35
  return attributions
36
 
37
+ def run_attribution_model(self, input_ids, attention_mask, baseline=None, index=None, layer=None, steps=20):
38
+ if baseline is None:
39
+ baseline = self.tokenizer.unk_token_id
40
+ else:
41
+ baseline = self.baseline_map[baseline]
42
 
 
43
  try:
44
  output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
45
  # if index is None:
 
50
  attention_mask = attention_mask
51
  attributions = ablator.attribute(
52
  inputs=input_ids,
53
+ baselines=baseline,
54
  additional_forward_args=(attention_mask),
55
  target=1,
56
  n_steps=steps,
 
83
  )
84
  return visualize_text(vis_data_records)
85
 
86
+ def __call__(self, input_text, layer, baseline):
87
  text_batch = [input_text]
88
  encoding = self.tokenizer(text_batch, return_tensors="pt")
89
  input_ids = encoding["input_ids"].to(self.device)
 
94
  else:
95
  layer = getattr(self.model.roberta.encoder.layer, str(layer-1))
96
 
97
+ return self.build_visualization(input_ids, attention_mask, layer=layer, baseline=baseline)