taka-yamakoshi committed
Commit: e87e116
1 Parent(s): 2092dd1
skeleton
Browse files:
- app.py (+12 -7)
- skeleton_modeling_albert.py (+74 -0)
app.py CHANGED
@@ -8,14 +8,19 @@ import seaborn as sns
 import jax
 import jax.numpy as jnp
 
-from transformers import AlbertTokenizer
+import torch
+import torch.nn.functional as F
 
-from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
+from transformers import AlbertTokenizer, AlbertForMaskedLM
+
+#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
+from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
 
 @st.cache(show_spinner=True,allow_output_mutation=True)
 def load_model():
     tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
-    model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
+    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
+    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
     return tokenizer,model
 
 def clear_data():
@@ -58,9 +63,9 @@ if __name__=='__main__':
     sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
     input_ids_1 = tokenizer(sent_1).input_ids
     input_ids_2 = tokenizer(sent_2).input_ids
-    input_ids =
+    input_ids = torch.tensor([input_ids_1,input_ids_2])
 
-    outputs = model
-    logprobs =
-    preds = [
+    outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
+    logprobs = F.log_softmax(outputs['logits'], dim = -1)
+    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
     st.write([tokenizer.decode([token]) for token in preds])
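Note: the interventions argument introduced above is not documented in this commit. Reading skeleton_modeling_albert.py below, it appears to be a nested dict keyed first by layer index, then by representation type, listing coordinate swaps between the two items in the batch. A minimal sketch of that reading (the labels in the comments are descriptive only, not names used in the code):

# interventions = {layer_id: {rep_type: [(unit_id, position, [batch_i, batch_j]), ...]}}
#   rep_type  - one of 'lay', 'qry', 'key', 'val' (layer input, query, key, value)
#   unit_id   - index into the unsplit hidden dimension of that representation
#               (the skeleton code names it head_id, but it indexes before the head split)
#   position  - token position whose activation is exchanged
#   [batch_i, batch_j] - the two batch items whose activations are swapped
interventions = {0: {'lay': [(8, 1, [0, 1])]}}  # swap unit 8 at token 1 between items 0 and 1, entering layer 0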
skeleton_modeling_albert.py ADDED
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import math
+
+from transformers.modeling_utils import apply_chunking_to_forward
+
+@torch.no_grad()
+def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
+    attention_layer = layer.attention
+    num_heads = attention_layer.num_attention_heads
+    head_dim = attention_layer.attention_head_size
+
+    qry = attention_layer.query(hidden)
+    key = attention_layer.key(hidden)
+    val = attention_layer.value(hidden)
+
+    # swap representations
+    interv_layer = interventions.pop(layer_id,None)
+    if interv_layer is not None:
+        reps = {
+            'lay': hidden,
+            'qry': qry,
+            'key': key,
+            'val': val,
+        }
+        for rep_type in ['lay','qry','key','val']:
+            interv_rep = interv_layer.pop(rep_type,None)
+            if interv_rep is not None:
+                new_state = reps[rep_type].clone()
+                for head_id, pos, swap_ids in interv_rep:
+                    new_state[swap_ids[0],pos,head_id] = reps[rep_type][swap_ids[1],pos,head_id]
+                    new_state[swap_ids[1],pos,head_id] = reps[rep_type][swap_ids[0],pos,head_id]
+                reps[rep_type] = new_state.clone()
+
+        hidden = reps['lay'].clone()
+        qry = reps['qry'].clone()
+        key = reps['key'].clone()
+        val = reps['val'].clone()
+
+
+    #split into multiple heads
+    split_qry = qry.view(*(qry.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_key = key.view(*(key.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_val = val.view(*(val.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+
+    #calculate the attention matrix
+    attn_mat = F.softmax(split_qry@split_key.permute(0,1,3,2)/math.sqrt(head_dim),dim=-1)
+
+    z_rep_indiv = attn_mat@split_val
+    z_rep = z_rep_indiv.permute(0,2,1,3).reshape(*hidden.size())
+
+    hidden_post_attn_res = layer.attention.dense(z_rep)+hidden
+    hidden_post_attn = layer.attention.LayerNorm(hidden_post_attn_res)
+
+    ffn_output = apply_chunking_to_forward(layer.ff_chunk,layer.chunk_size_feed_forward,
+                                           layer.seq_len_dim,hidden_post_attn)
+    new_hidden = layer.full_layer_layer_norm(ffn_output+hidden_post_attn)
+    return new_hidden
+
+def SkeletonAlbertForMaskedLM(model,input_ids,interventions):
+    core_model = model.albert
+    lm_head = model.predictions
+    output_hidden = []
+    with torch.no_grad():
+        hidden = core_model.embeddings(input_ids)
+        hidden = core_model.encoder.embedding_hidden_mapping_in(hidden)
+        output_hidden.append(hidden)
+        for layer_id in range(model.config.num_hidden_layers):
+            layer = core_model.encoder.albert_layer_groups[0].albert_layers[0]
+            hidden = SkeletonAlbertLayer(layer_id,layer,hidden,interventions)
+            output_hidden.append(hidden)
+        logits = lm_head(hidden)
+    return {'logits':logits,'hidden_states':output_hidden}
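For reference, a minimal sketch of how the new module is meant to be called (it mirrors the call site in app.py above; the greedy decoding at the end is illustrative and not part of the commit):

import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM

tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')

# Two sentences of equal token length so they stack into a single (2, seq_len) tensor.
sent = 'It is better to play a prank on Samuel than Craig because he gets angry more often.'
input_ids = torch.tensor([tokenizer(sent).input_ids, tokenizer(sent).input_ids])

# Swap hidden unit 8 at token position 1 between the two batch items, entering layer 0.
outputs = SkeletonAlbertForMaskedLM(model, input_ids, interventions={0: {'lay': [(8, 1, [0, 1])]}})

logprobs = F.log_softmax(outputs['logits'], dim=-1)  # (2, seq_len, vocab_size)
top_tokens = logprobs[0].argmax(dim=-1)              # greedy choice per position, first sentence
print([tokenizer.decode([int(t)]) for t in top_tokens])

Since SkeletonAlbertForMaskedLM returns a plain dict rather than a transformers output object, the logits are accessed with outputs['logits'].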