Abdullah-Nazhat committed on
Update linearizer.py

linearizer.py +6 -6
linearizer.py CHANGED

@@ -7,7 +7,7 @@ import math
 
 
 
-
+
 
 def default(val, default_val):
     return val if val is not None else default_val
@@ -18,7 +18,7 @@ def init_(tensor):
     tensor.uniform_(-std, std)
     return tensor
 
-
+
 
 class Residual(nn.Module):
     def __init__(self, fn):
@@ -110,25 +110,25 @@ class LinformerSelfAttention(nn.Module):
 
         kv_projs = (self.proj_k, self.proj_v if not self.share_kv else self.proj_k)
 
-
+
 
         keys, values = map(proj_seq_len, zip((keys, values), kv_projs))
 
-
+
 
         queries = queries.reshape(b, n, h, -1).transpose(1, 2)
 
         merge_key_values = lambda t: t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)
         keys, values = map(merge_key_values, (keys, values))
 
-
+
 
         dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (d_h ** -0.5)
         attn = dots.softmax(dim=-1)
         attn = self.dropout(attn)
         out = torch.einsum('bhnk,bhkd->bhnd', attn, values)
 
-
+
 
         out = out.transpose(1, 2).reshape(b, n, -1)
         return self.to_out(out)
 
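For context on the lines the last hunk touches: keys and values are first compressed along the sequence dimension from length n down to a fixed k via the learned proj_k/proj_v, so each head's attention matrix is n x k rather than n x n. Below is a minimal, self-contained sketch of that computation, not the committed module: the function name and standalone signature are invented for illustration, the projection einsum ('bnd,nk->bkd') is an assumption implied but not shown in the hunk, dropout and the output projection self.to_out are omitted, and unlike the hunk's merge_key_values (which can expand a single shared key/value head across all heads) this sketch keeps per-head keys and values for simplicity.

import torch

# Sketch only: names and signature are illustrative, not from linearizer.py.
def linformer_attention(queries, keys, values, proj_k, proj_v, h):
    # queries, keys, values: (b, n, h * d_h)
    # proj_k, proj_v: (n, k) learned sequence-length projections
    b, n, d = queries.shape
    d_h = d // h
    k = proj_k.shape[-1]

    # compress keys/values along the sequence dimension: (b, n, d) -> (b, k, d)
    keys = torch.einsum('bnd,nk->bkd', keys, proj_k)
    values = torch.einsum('bnd,nk->bkd', values, proj_v)

    # split heads: queries -> (b, h, n, d_h); keys/values -> (b, h, k, d_h)
    queries = queries.reshape(b, n, h, -1).transpose(1, 2)
    keys = keys.reshape(b, k, h, -1).transpose(1, 2)
    values = values.reshape(b, k, h, -1).transpose(1, 2)

    # scaled dot-product attention over the compressed length k
    dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (d_h ** -0.5)
    attn = dots.softmax(dim=-1)
    out = torch.einsum('bhnk,bhkd->bhnd', attn, values)

    # merge heads back to (b, n, h * d_h)
    return out.transpose(1, 2).reshape(b, n, -1)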
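A quick shape check of the sketch above, with arbitrarily chosen sizes:

b, n, h, d_h, k = 2, 1024, 8, 64, 256
q = torch.randn(b, n, h * d_h)
kv = torch.randn(b, n, h * d_h)
out = linformer_attention(q, kv, kv, torch.randn(n, k), torch.randn(n, k), h)
print(out.shape)  # torch.Size([2, 1024, 512])

Here each head attends over a 1024 x 256 matrix instead of 1024 x 1024, which is the point of the Linformer projection.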