Spaces:
Runtime error
Runtime error
taka-yamakoshi
commited on
Commit
·
b218eb4
1
Parent(s):
e87e116
fix
Browse files
skeleton_modeling_albert.py
CHANGED
|
@@ -10,11 +10,16 @@ def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
|
|
| 10 |
attention_layer = layer.attention
|
| 11 |
num_heads = attention_layer.num_attention_heads
|
| 12 |
head_dim = attention_layer.attention_head_size
|
|
|
|
| 13 |
|
| 14 |
qry = attention_layer.query(hidden)
|
| 15 |
key = attention_layer.key(hidden)
|
| 16 |
val = attention_layer.value(hidden)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# swap representations
|
| 19 |
interv_layer = interventions.pop(layer_id,None)
|
| 20 |
if interv_layer is not None:
|
|
@@ -29,8 +34,8 @@ def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
|
|
| 29 |
if interv_rep is not None:
|
| 30 |
new_state = reps[rep_type].clone()
|
| 31 |
for head_id, pos, swap_ids in interv_rep:
|
| 32 |
-
new_state[swap_ids[0],pos,head_id] = reps[
|
| 33 |
-
new_state[swap_ids[1],pos,head_id] = reps[
|
| 34 |
reps[rep_type] = new_state.clone()
|
| 35 |
|
| 36 |
hidden = reps['lay'].clone()
|
|
|
|
| 10 |
attention_layer = layer.attention
|
| 11 |
num_heads = attention_layer.num_attention_heads
|
| 12 |
head_dim = attention_layer.attention_head_size
|
| 13 |
+
assert num_heads*head_dim == hidden.shape[2]
|
| 14 |
|
| 15 |
qry = attention_layer.query(hidden)
|
| 16 |
key = attention_layer.key(hidden)
|
| 17 |
val = attention_layer.value(hidden)
|
| 18 |
|
| 19 |
+
assert qry.shape == hidden.shape
|
| 20 |
+
assert key.shape == hidden.shape
|
| 21 |
+
assert val.shape == hidden.shape
|
| 22 |
+
|
| 23 |
# swap representations
|
| 24 |
interv_layer = interventions.pop(layer_id,None)
|
| 25 |
if interv_layer is not None:
|
|
|
|
| 34 |
if interv_rep is not None:
|
| 35 |
new_state = reps[rep_type].clone()
|
| 36 |
for head_id, pos, swap_ids in interv_rep:
|
| 37 |
+
new_state[swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)]
|
| 38 |
+
new_state[swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)]
|
| 39 |
reps[rep_type] = new_state.clone()
|
| 40 |
|
| 41 |
hidden = reps['lay'].clone()
|