dayyass committed on
Commit
ef7bae2
1 Parent(s): 891847a

Upload 3 files

Files changed (3)
  1. architecture.py +159 -0
  2. requirements.txt +7 -0
  3. tokenizer.py +51 -0
architecture.py ADDED
@@ -0,0 +1,159 @@
+ import math
+
+ import torch
+
+
+ class PositionalEncoding(torch.nn.Module):
+     """
+     https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+     """
+
+     def __init__(self, d_model: int, max_len: int = 512):
+         super().__init__()
+
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+         )
+
+         pe = torch.zeros(max_len, d_model)
+         pe[:, : d_model // 2] = torch.sin(position * div_term)
+         pe[:, d_model // 2 :] = torch.cos(position * div_term)
+
+         self.register_buffer("pe", pe)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.pe[: x.size(0)]
+         return x
+
+
+ class MultiheadSelfAttention(torch.nn.Module):
+     def __init__(self, embed_dim: int, num_heads: int = 8):
+         super().__init__()
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+
+         self.query = torch.nn.Linear(
+             in_features=embed_dim,
+             out_features=embed_dim,
+         )
+         self.key = torch.nn.Linear(
+             in_features=embed_dim,
+             out_features=embed_dim,
+         )
+         self.value = torch.nn.Linear(
+             in_features=embed_dim,
+             out_features=embed_dim,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         q = self.query(x).view(x.shape[0], self.num_heads, -1).transpose(0, 1)
+         k = self.key(x).view(x.shape[0], self.num_heads, -1).permute(1, 2, 0)
+         v = self.value(x).view(x.shape[0], self.num_heads, -1).transpose(0, 1)
+         qk = torch.softmax(
+             torch.matmul(q, k) / (self.embed_dim / self.num_heads) ** 0.5,
+             dim=-1,
+         )
+         qkv = torch.matmul(qk, v).transpose(0, 1).reshape(x.shape[0], -1)
+         return qkv
+
+
+ class Block(torch.nn.Module):
+     def __init__(self, d_model: int, num_heads: int = 8, eps: float = 1e-6):
+         super().__init__()
+
+         self.ln1 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
+         self.attn = MultiheadSelfAttention(embed_dim=d_model, num_heads=num_heads)
+         self.ln2 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
+         self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_model * 4)
+         self.linear2 = torch.nn.Linear(in_features=d_model * 4, out_features=d_model)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         ln1 = self.ln1(x)
+         attn = self.attn(ln1)
+         ln2 = self.ln2(x + attn)
+         mlp = self.linear2(torch.relu(self.linear1(ln2)))
+         return mlp + x + attn
+
+
+ class Head(torch.nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+
+         self.d_model = d_model
+         self.eps = eps
+
+         self.ln = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
+         self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_model)
+         self.linear2 = torch.nn.Linear(in_features=d_model, out_features=d_model)
+         self.tanh_layer = torch.nn.Linear(in_features=d_model * 2, out_features=d_model)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         ln = self.ln(x)
+         mlp = torch.exp(self.linear2(torch.nn.functional.elu(self.linear1(ln))))
+         res = torch.cat(
+             [
+                 ln.sum(dim=0) / ln.shape[0],
+                 (mlp * ln).sum(dim=0) / mlp.sum(dim=0),
+             ]
+         )
+         res = torch.tanh(self.tanh_layer(res))
+         res /= (res**2).sum() ** 0.5
+         res /= (res**2).sum() ** 0.5
+         return res
+
+
+ class MUSE(torch.nn.Module):
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_dim: int,
+         d_model: int,
+         num_heads: int,
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+
+         self.num_embeddings = num_embeddings
+         self.embedding_dim = embedding_dim
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.eps = eps
+
+         self.embedding = torch.nn.Embedding(
+             num_embeddings=num_embeddings,
+             embedding_dim=embedding_dim,
+         )
+         self.linear = torch.nn.Linear(
+             in_features=embedding_dim,
+             out_features=d_model,
+         )
+         self.pe = PositionalEncoding(
+             d_model=d_model,
+             max_len=512,  # TODO: remove hardcode
+         )
+         self.block0 = Block(d_model=d_model)
+         self.block1 = Block(d_model=d_model)
+         self.block2 = Block(d_model=d_model)
+         self.block3 = Block(d_model=d_model)
+         self.block4 = Block(d_model=d_model)
+         self.block5 = Block(d_model=d_model)
+         self.head = Head(d_model=d_model)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.embedding(x)
+         x = self.linear(x)
+         x = self.pe(x)
+         x = self.block0(x)
+         x = self.block1(x)
+         x = self.block2(x)
+         x = self.block3(x)
+         x = self.block4(x)
+         x = self.block5(x)
+         x = self.head(x)
+         return x
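
A minimal usage sketch for architecture.py (not part of the commit). It assumes the file is importable as architecture; the hyperparameter values are placeholders chosen only to make the example self-contained. MUSE.forward takes a 1-D LongTensor of token ids (sequence first, no batch dimension) and returns a single L2-normalized sentence embedding of size d_model.

import torch

from architecture import MUSE

model = MUSE(
    num_embeddings=32000,  # placeholder: must be at least the tokenizer's vocabulary size
    embedding_dim=128,     # placeholder
    d_model=512,           # placeholder
    num_heads=8,
)
model.eval()

token_ids = torch.LongTensor([1, 17, 2775, 63, 2])  # dummy ids; 1 and 2 stand for BOS/EOS
with torch.no_grad():
    sentence_embedding = model(token_ids)

print(sentence_embedding.shape)  # torch.Size([512]): a unit-norm vector of size d_model
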
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ onnx==1.16.0
+ onnxruntime==1.18.0
+ onnxruntime_extensions==0.10.1
+ tensorflow==2.16.1
+ tensorflow-hub==0.16.1
+ tensorflow-text==2.16.1
+ torch==2.3.0
tokenizer.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
+ from tensorflow.python.saved_model.loader_impl import parse_saved_model
+ from tensorflow_text.python.ops.sentencepiece_tokenizer import SentencepieceTokenizer
+
+
+ def _get_tokenizer_from_saved_model(saved_model: SavedModel) -> SentencepieceTokenizer:
+     """
+     Get tokenizer from tf SavedModel.
+     :param SavedModel saved_model: tf SavedModel.
+     :return: tokenizer.
+     :rtype: SentencepieceTokenizer
+     """
+
+     # extract functions that contain SentencePiece somewhere in there
+     functions_with_sp = [
+         f
+         for f in saved_model.meta_graphs[0].graph_def.library.function
+         if "tokenizer" in str(f).lower()
+     ]
+
+     assert (
+         len(functions_with_sp) == 1
+     ), f"len(functions_with_sp) = {len(functions_with_sp)}"
+
+     # find SentencePieceOp (contains the model) in the found function
+     nodes_with_sp = [
+         n for n in functions_with_sp[0].node_def if n.op == "SentencepieceOp"
+     ]
+
+     assert len(nodes_with_sp) == 1, f"len(nodes_with_sp) = {len(nodes_with_sp)}"
+
+     # we can pretty much save the model into a file since it does not change
+     model = nodes_with_sp[0].attr["model"].s
+
+     # instantiate the model
+     tokenizer = SentencepieceTokenizer(model)
+
+     return tokenizer
+
+
+ def get_tokenizer(model_path: str) -> SentencepieceTokenizer:
+     tokenizer = _get_tokenizer_from_saved_model(parse_saved_model(model_path))
+     return tokenizer
+
+
+ def tokenize(
+     sentence: str,  # TODO: add batch processing
+     tokenizer: SentencepieceTokenizer,
+ ) -> torch.Tensor:
+     return torch.LongTensor([1] + tokenizer.tokenize([sentence]).to_list()[0] + [2])
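
A minimal usage sketch for tokenizer.py (not part of the commit). The SavedModel directory name below is a hypothetical local path to an unpacked TF-Hub SavedModel containing a SentencePiece tokenizer; point it at wherever the model is actually extracted.

from tokenizer import get_tokenizer, tokenize

saved_model_dir = "universal-sentence-encoder-multilingual_3"  # hypothetical local path

tokenizer = get_tokenizer(saved_model_dir)
token_ids = tokenize("Hello, world!", tokenizer)
# token_ids is a 1-D torch.LongTensor: hardcoded BOS id 1, the SentencePiece ids, hardcoded EOS id 2.
# This is exactly the input that MUSE.forward in architecture.py consumes (no batch dimension).
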