taka-yamakoshi committed
Commit
e87e116
1 Parent(s): 2092dd1
Files changed (2)
  1. app.py +12 -7
  2. skeleton_modeling_albert.py +74 -0
app.py CHANGED
@@ -8,14 +8,19 @@ import seaborn as sns
 import jax
 import jax.numpy as jnp
 
-from transformers import AlbertTokenizer
+import torch
+import torch.nn.functional as F
 
-from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
+from transformers import AlbertTokenizer, AlbertForMaskedLM
+
+#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
+from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
 
 @st.cache(show_spinner=True,allow_output_mutation=True)
 def load_model():
     tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
-    model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
+    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
+    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
     return tokenizer,model
 
 def clear_data():
@@ -58,9 +63,9 @@ if __name__=='__main__':
     sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
     input_ids_1 = tokenizer(sent_1).input_ids
     input_ids_2 = tokenizer(sent_2).input_ids
-    input_ids = np.array([input_ids_1,input_ids_2])
+    input_ids = torch.tensor([input_ids_1,input_ids_2])
 
-    outputs = model(input_ids, interv_type='swap', interv_dict = {0:{'lay':[(8,1,[0,1])]}})
-    logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
-    preds = [np.random.choice(np.arange(len(probs)),p=np.exp(probs)/np.sum(np.exp(probs))) for probs in logprobs[0]]
+    outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
+    logprobs = F.log_softmax(outputs['logits'], dim = -1)
+    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
     st.write([tokenizer.decode([token]) for token in preds])
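
For reference, the interventions dict passed above is unpacked by SkeletonAlbertLayer in the new module below as layer index, then representation type, then a list of swap tuples. The annotated sketch restates the exact spec used in this commit; the comments are interpretive, based on the loop "for head_id, pos, swap_ids in interv_rep":

# Annotated form of the intervention spec used above (interpretation, not new code):
interventions = {
    0: {                     # layer_id: intervene at encoder layer 0
        'lay': [             # representation to edit: 'lay', 'qry', 'key', or 'val'
            (8, 1, [0, 1]),  # (head_id, pos, swap_ids): swap index 8 of the last hidden
                             # dimension at token position 1 between batch items 0 and 1
        ],
    },
}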
skeleton_modeling_albert.py ADDED
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import math
+
+from transformers.modeling_utils import apply_chunking_to_forward
+
+@torch.no_grad()
+def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
+    attention_layer = layer.attention
+    num_heads = attention_layer.num_attention_heads
+    head_dim = attention_layer.attention_head_size
+
+    qry = attention_layer.query(hidden)
+    key = attention_layer.key(hidden)
+    val = attention_layer.value(hidden)
+
+    # swap representations
+    interv_layer = interventions.pop(layer_id,None)
+    if interv_layer is not None:
+        reps = {
+            'lay': hidden,
+            'qry': qry,
+            'key': key,
+            'val': val,
+        }
+        for rep_type in ['lay','qry','key','val']:
+            interv_rep = interv_layer.pop(rep_type,None)
+            if interv_rep is not None:
+                new_state = reps[rep_type].clone()
+                for head_id, pos, swap_ids in interv_rep:
+                    new_state[swap_ids[0],pos,head_id] = reps[rep_type][swap_ids[1],pos,head_id]
+                    new_state[swap_ids[1],pos,head_id] = reps[rep_type][swap_ids[0],pos,head_id]
+                reps[rep_type] = new_state.clone()
+
+        hidden = reps['lay'].clone()
+        qry = reps['qry'].clone()
+        key = reps['key'].clone()
+        val = reps['val'].clone()
+
+
+    #split into multiple heads
+    split_qry = qry.view(*(qry.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_key = key.view(*(key.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_val = val.view(*(val.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+
+    #calculate the attention matrix
+    attn_mat = F.softmax(split_qry@split_key.permute(0,1,3,2)/math.sqrt(head_dim),dim=-1)
+
+    z_rep_indiv = attn_mat@split_val
+    z_rep = z_rep_indiv.permute(0,2,1,3).reshape(*hidden.size())
+
+    hidden_post_attn_res = layer.attention.dense(z_rep)+hidden
+    hidden_post_attn = layer.attention.LayerNorm(hidden_post_attn_res)
+
+    ffn_output = apply_chunking_to_forward(layer.ff_chunk,layer.chunk_size_feed_forward,
+                                           layer.seq_len_dim,hidden_post_attn)
+    new_hidden = layer.full_layer_layer_norm(ffn_output+hidden_post_attn)
+    return new_hidden
+
+def SkeletonAlbertForMaskedLM(model,input_ids,interventions):
+    core_model = model.albert
+    lm_head = model.predictions
+    output_hidden = []
+    with torch.no_grad():
+        hidden = core_model.embeddings(input_ids)
+        hidden = core_model.encoder.embedding_hidden_mapping_in(hidden)
+        output_hidden.append(hidden)
+        for layer_id in range(model.config.num_hidden_layers):
+            layer = core_model.encoder.albert_layer_groups[0].albert_layers[0]
+            hidden = SkeletonAlbertLayer(layer_id,layer,hidden,interventions)
+            output_hidden.append(hidden)
+        logits = lm_head(hidden)
+        return {'logits':logits,'hidden_states':output_hidden}
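
A minimal standalone usage sketch of the new module, assuming the same albert-xxlarge-v2 checkpoint as app.py; the sentence pair is illustrative (not from the commit), and both inputs must tokenize to the same length since they are stacked into a single tensor:

import torch
from transformers import AlbertTokenizer, AlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM

tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')

# Illustrative pair differing in a single word (assumed to give equal-length token ids).
sents = ['The doctor met the nurse.', 'The lawyer met the nurse.']
input_ids = torch.tensor([tokenizer(s).input_ids for s in sents])

# Swap the layer-input ('lay') representation at hidden index 8, token position 1,
# between the two sentences before layer 0 runs.
outputs = SkeletonAlbertForMaskedLM(model, input_ids,
                                    interventions={0: {'lay': [(8, 1, [0, 1])]}})
print(outputs['logits'].shape)        # (2, sequence_length, vocab_size)
print(len(outputs['hidden_states']))  # embedding projection + one entry per layer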