calbors committed
Commit 7526745 (verified) · 1 parent: 05755e4

Upload model
Files changed (3)
  1. config.json +4 -0
  2. configuration_phylogpn.py +12 -0
  3. modeling_phylogpn.py +251 -0
config.json CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "PhyloGPNModel"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_phylogpn.PhyloGPNConfig",
+    "AutoModel": "modeling_phylogpn.PhyloGPNModel"
+  },
   "inner_dim": 480,
   "kernel_size": 5,
   "model_type": "phylogpn",
configuration_phylogpn.py ADDED
@@ -0,0 +1,12 @@
+from transformers import PretrainedConfig
+
+class PhyloGPNConfig(PretrainedConfig):
+    model_type = "phylogpn"
+
+    def __init__(self, outer_dim: int = 960, inner_dim: int = 480, kernel_size: int = 5, stack_size: int = 2, num_stacks: int = 20, **kwargs):
+        self.outer_dim = outer_dim
+        self.inner_dim = inner_dim
+        self.kernel_size = kernel_size
+        self.stack_size = stack_size
+        self.num_stacks = num_stacks
+        super().__init__(**kwargs)
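
`PhyloGPNConfig` only stores hyperparameters; `modeling_phylogpn.py` expands them into the dilation schedule. A quick sketch of what the defaults work out to, assuming the committed files are importable locally:

    from configuration_phylogpn import PhyloGPNConfig

    config = PhyloGPNConfig()  # outer_dim=960, inner_dim=480, kernel_size=5, stack_size=2, num_stacks=20

    # Same expression as in PhyloGPNModel.__init__: 20 repeats of [5**0, 5**1].
    dilation_rates = config.num_stacks * [config.kernel_size**i for i in range(config.stack_size)]
    assert dilation_rates == 20 * [1, 5]

    # Each block widens the receptive field by (kernel_size - 1) * rate, so the
    # full stack sees 1 + 20 * ((5 - 1) * 1 + (5 - 1) * 5) = 481 positions.
    assert 1 + sum((config.kernel_size - 1) * r for r in dilation_rates) == 481
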
modeling_phylogpn.py ADDED
@@ -0,0 +1,251 @@
+from typing import List, Optional
+
+import torch
+from torch import nn
+from torch.nn.utils import parametrize
+
+from transformers import PreTrainedModel
+
+from .configuration_phylogpn import PhyloGPNConfig
+
+
+def check_if_involution(indices: List[int]) -> bool:
+    # An involution is a permutation that is its own inverse.
+    return all(indices[indices[idx]] == idx for idx in range(len(indices)))
+
+
+def get_conv1d_output_length(
+    input_length: int, kernel_size: int, stride_size: int = 1, pad_size: int = 0, dilation_rate: int = 1
+) -> int:
+    # Standard PyTorch output-length formula for a 1D convolution.
+    return (input_length + 2 * pad_size - dilation_rate * (kernel_size - 1) - 1) // stride_size + 1
+
+
+def get_involution_indices(size: int) -> List[int]:
+    # Index reversal: position i maps to size - 1 - i.
+    return list(reversed(range(size)))
+
+
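
These helpers are the backbone of the model's symmetry handling: every equivariance constraint in this file is expressed through involutions, i.e. permutations that undo themselves. A small illustration, assuming the module is importable locally:

    from modeling_phylogpn import check_if_involution, get_conv1d_output_length

    # The complement map used later (A<->T, C<->G under the A, C, G, T order
    # that `forward` uses, plus two fixed special tokens) is self-inverse.
    assert check_if_involution([3, 2, 1, 0, 4, 5])

    # Unpadded dilated convs trim (kernel_size - 1) * dilation_rate positions.
    assert get_conv1d_output_length(100, kernel_size=5, dilation_rate=5) == 80
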
+class RCEWeight(nn.Module):
+    # Parametrization that constrains a Conv1d weight to be reverse-complement
+    # equivariant by averaging it with a "complemented" copy of itself.
+    def __init__(
+        self, input_involution_indices: List[int], output_involution_indices: List[int]
+    ):
+        if not check_if_involution(input_involution_indices) or not check_if_involution(
+            output_involution_indices
+        ):
+            raise ValueError(
+                "`input_involution_indices` and `output_involution_indices` must be involutions"
+            )
+
+        super().__init__()
+        self.input_involution_indices = input_involution_indices
+        self.output_involution_indices = output_involution_indices
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output_involution_indices = torch.tensor(self.output_involution_indices, device=x.device)
+        input_involution_indices = torch.tensor(self.input_involution_indices, device=x.device)
+        # Permute output channels, permute input channels, flip the kernel axis.
+        return (x + x[output_involution_indices][:, input_involution_indices].flip(2)) / 2
+
+
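
`RCEWeight` averages a Conv1d kernel with a transformed copy: output channels permuted, input channels permuted, kernel axis flipped. Because each of those operations is an involution, the averaged weight is a fixed point of the transform, which is what makes the convolution commute with reverse-complementation of its input. A quick numeric check, under the same local-import assumption:

    import torch
    from modeling_phylogpn import RCEWeight, get_involution_indices

    out_inv = get_involution_indices(8)  # channel reversal on 8 output channels
    in_inv = [3, 2, 1, 0, 4, 5]          # complement map on a 6-token input
    w = RCEWeight(in_inv, out_inv)(torch.randn(8, 6, 5))

    # The symmetrized kernel is invariant under the transform it averages over.
    assert torch.allclose(w, w[torch.tensor(out_inv)][:, torch.tensor(in_inv)].flip(2))
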
+class IEBias(nn.Module):
+    # Parametrization that makes a bias vector invariant under an involution.
+    def __init__(self, involution_indices: List[int]):
+        if not check_if_involution(involution_indices):
+            raise ValueError("`involution_indices` must be an involution")
+
+        super().__init__()
+        self.involution_indices = involution_indices
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        involution_indices = torch.tensor(self.involution_indices, device=x.device)
+        return (x + x[involution_indices]) / 2
+
+
+class IEWeight(nn.Module):
+    # Parametrization that makes a 2D weight (here, the embedding matrix)
+    # invariant under a pair of row/column involutions.
+    def __init__(
+        self, input_involution_indices: List[int], output_involution_indices: List[int]
+    ):
+        if not check_if_involution(input_involution_indices) or not check_if_involution(
+            output_involution_indices
+        ):
+            raise ValueError(
+                "`input_involution_indices` and `output_involution_indices` must be involutions"
+            )
+
+        super().__init__()
+        self.input_involution_indices = input_involution_indices
+        self.output_involution_indices = output_involution_indices
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        input_involution_indices = torch.tensor(self.input_involution_indices, device=x.device)
+        output_involution_indices = torch.tensor(self.output_involution_indices, device=x.device)
+        return (x + x[input_involution_indices][:, output_involution_indices]) / 2
+
+
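
`IEBias` and `IEWeight` apply the same averaging trick to biases and to the embedding matrix, making them invariant under the given involutions (so complementary tokens end up with channel-reversed copies of each other's embeddings). For example:

    import torch
    from modeling_phylogpn import IEBias

    bias = IEBias([3, 2, 1, 0])(torch.randn(4))
    assert torch.allclose(bias, bias[torch.tensor([3, 2, 1, 0])])
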
+class RCEByteNetBlock(nn.Module):
+    # Pre-activation ByteNet-style residual block: 1x1 down-projection, dilated
+    # conv, 1x1 up-projection, all constrained to the equivariant subspace.
+    def __init__(
+        self,
+        outer_involution_indices: List[int],
+        inner_dim: int,
+        kernel_size: int,
+        dilation_rate: int = 1
+    ):
+        outer_dim = len(outer_involution_indices)
+
+        if outer_dim % 2 != 0:
+            raise ValueError("`outer_involution_indices` must have an even length")
+
+        if inner_dim % 2 != 0:
+            raise ValueError("`inner_dim` must be even")
+
+        if kernel_size % 2 == 0:
+            raise ValueError("`kernel_size` must be odd")
+
+        super().__init__()
+        inner_involution_indices = get_involution_indices(inner_dim)
+
+        layers = [
+            nn.GroupNorm(1, outer_dim),
+            nn.GELU(),
+            nn.Conv1d(outer_dim, inner_dim, kernel_size=1),
+            nn.GroupNorm(1, inner_dim),
+            nn.GELU(),
+            nn.Conv1d(inner_dim, inner_dim, kernel_size, dilation=dilation_rate),
+            nn.GroupNorm(1, inner_dim),
+            nn.GELU(),
+            nn.Conv1d(inner_dim, outer_dim, kernel_size=1)
+        ]
+        # Constrain every conv weight and bias to the equivariant subspace.
+        parametrize.register_parametrization(
+            layers[2], "weight",
+            RCEWeight(outer_involution_indices, inner_involution_indices)
+        )
+        parametrize.register_parametrization(
+            layers[2], "bias",
+            IEBias(inner_involution_indices)
+        )
+        parametrize.register_parametrization(
+            layers[5], "weight",
+            RCEWeight(inner_involution_indices, inner_involution_indices)
+        )
+        parametrize.register_parametrization(
+            layers[5], "bias",
+            IEBias(inner_involution_indices)
+        )
+        parametrize.register_parametrization(
+            layers[8], "weight",
+            RCEWeight(inner_involution_indices, outer_involution_indices)
+        )
+        parametrize.register_parametrization(
+            layers[8], "bias",
+            IEBias(outer_involution_indices)
+        )
+
+        self.layers = nn.Sequential(*layers)
+        self._kernel_size = kernel_size
+        self._dilation_rate = dilation_rate
+
+    @property
+    def kernel_size(self):
+        return self._kernel_size
+
+    @property
+    def dilation_rate(self):
+        return self._dilation_rate
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        input_length = x.shape[2]
+        output_length = get_conv1d_output_length(input_length, self.kernel_size, dilation_rate=self.dilation_rate)
+        # The dilated conv is unpadded, so center-crop the residual to match.
+        a = (input_length - output_length) // 2
+
+        if a == 0:
+            return self.layers(x) + x
+
+        return self.layers(x) + x[:, :, a:-a]
+
+
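
Because the dilated conv in each block is unpadded, the output is shorter than the input and the residual branch center-crops to match. A shape check with hypothetical sizes, assuming local imports:

    import torch
    from modeling_phylogpn import RCEByteNetBlock, get_involution_indices

    block = RCEByteNetBlock(get_involution_indices(960), inner_dim=480, kernel_size=5, dilation_rate=5)
    y = block(torch.randn(1, 960, 100))
    assert y.shape == (1, 960, 80)  # (5 - 1) * 5 = 20 positions trimmed
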
+class RCEByteNet(nn.Module):
+    def __init__(
+        self,
+        input_involution_indices: List[int],
+        output_involution_indices: List[int],
+        dilation_rates: List[int],
+        outer_dim: int,
+        inner_dim: int,
+        kernel_size: int,
+        pad_token_idx: Optional[int] = None,
+    ):
+        if pad_token_idx is not None and input_involution_indices[pad_token_idx] != pad_token_idx:
+            raise ValueError("`input_involution_indices[pad_token_idx]` must be equal to `pad_token_idx`")
+
+        super().__init__()
+        vocab_size = len(input_involution_indices)
+        outer_involution_indices = get_involution_indices(outer_dim)
+
+        self.embedding = nn.Embedding(vocab_size, outer_dim, padding_idx=pad_token_idx)
+        parametrize.register_parametrization(
+            self.embedding, "weight",
+            IEWeight(input_involution_indices, outer_involution_indices)
+        )
+        # Initialize the underlying (pre-parametrization) tensor: the parametrized
+        # `.weight` is recomputed on every access, so in-place edits to it are lost.
+        original_weight = self.embedding.parametrizations.weight.original
+        nn.init.normal_(original_weight, std=2**0.5)
+        original_weight.data[self.embedding.padding_idx].zero_()
+        self.embedding.requires_grad_(False)  # the embedding is frozen
+
+        blocks = []
+        receptive_field_size = 1  # tracked for reference; not used below
+
+        for r in dilation_rates:
+            blocks.append(RCEByteNetBlock(outer_involution_indices, inner_dim, kernel_size, dilation_rate=r))
+            receptive_field_size += (kernel_size - 1) * r
+
+        self.blocks = nn.Sequential(*blocks)
+
+        output_dim = len(output_involution_indices)
+        self.output_layers = nn.Sequential(
+            nn.GroupNorm(1, outer_dim), nn.GELU(), nn.Conv1d(outer_dim, output_dim, kernel_size=1)
+        )
+        parametrize.register_parametrization(
+            self.output_layers[-1], "weight",
+            RCEWeight(outer_involution_indices, output_involution_indices)
+        )
+        parametrize.register_parametrization(
+            self.output_layers[-1], "bias", IEBias(output_involution_indices)
+        )
+
+        self._embedding_involution_indices = outer_involution_indices
+
+    @property
+    def embedding_involution_indices(self):
+        return self._embedding_involution_indices
+
+    def get_embeddings(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        # (batch, length) token ids -> (batch, length', outer_dim) features,
+        # normalized by the output head's GroupNorm.
+        x = self.embedding(input_tensor).swapaxes(1, 2)
+        return self.output_layers[0](self.blocks(x)).swapaxes(1, 2)
+
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        x = self.get_embeddings(input_tensor).swapaxes(1, 2)
+        return self.output_layers[1:](x).swapaxes(1, 2)
+
+
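
`RCEByteNet` chains a frozen equivariant embedding, the residual blocks, and a 1x1 output head. `get_embeddings` stops after the head's GroupNorm, so it yields per-position features with `outer_dim` channels, while `forward` continues through the GELU and final conv down to 4 channels; either way the sequence shrinks by the receptive field minus one. A shape sketch via the wrapper class defined next, assuming local imports:

    import torch
    from modeling_phylogpn import PhyloGPNModel
    from configuration_phylogpn import PhyloGPNConfig

    model = PhyloGPNModel(PhyloGPNConfig())  # randomly initialized, default sizes
    features = model.get_embeddings(torch.randint(0, 4, (1, 481)))
    assert features.shape == (1, 1, 960)     # 481 -> 1 valid position, outer_dim channels
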
+class PhyloGPNModel(PreTrainedModel):
+    config_class = PhyloGPNConfig
+
+    def __init__(self, config, **kwargs):
+        super().__init__(config, **kwargs)
+
+        # `num_stacks` repeats of [kernel_size**0, ..., kernel_size**(stack_size - 1)];
+        # [1, 5] repeated 20 times with the default config.
+        dilation_rates = config.num_stacks * [config.kernel_size**i for i in range(config.stack_size)]
+
+        self._model = RCEByteNet(
+            # The token involution pairs complementary bases (0<->3, 1<->2) and
+            # fixes the two extra tokens (index 5 is the padding token).
+            input_involution_indices=[3, 2, 1, 0, 4, 5],
+            output_involution_indices=[3, 2, 1, 0],
+            dilation_rates=dilation_rates,
+            outer_dim=config.outer_dim,
+            inner_dim=config.inner_dim,
+            kernel_size=config.kernel_size,
+            pad_token_idx=5
+        )
+
+    def get_embeddings(self, input_ids: torch.Tensor):
+        return self._model.get_embeddings(input_ids)
+
+    def forward(self, input_ids: torch.Tensor):
+        output_tensor = self._model(input_ids)
+        # `force=True` copies to CPU and detaches, so the result is not differentiable.
+        output_array = output_tensor.numpy(force=True)
+
+        results = {}
+
+        for idx, key in enumerate("ACGT"):
+            results[key] = output_array[:, :, idx]
+
+        return results
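
Note that `forward` does not return a `ModelOutput`: it detaches to numpy and returns a plain dict with one per-position array for each of A, C, G, and T, so it cannot be backpropagated through. The commit ships no tokenizer, so the id mapping below (A=0, C=1, G=2, T=3, consistent with the complement involution `[3, 2, 1, 0, 4, 5]` and the `"ACGT"` output order, with 5 as padding) is an assumption for illustration:

    import torch
    from transformers import AutoModel

    VOCAB = {nt: i for i, nt in enumerate("ACGT")}  # assumed id mapping

    # Placeholder repo id, as above.
    model = AutoModel.from_pretrained("calbors/PhyloGPN", trust_remote_code=True)

    seq = "ACGT" * 200  # any sequence longer than the 481 bp receptive field
    input_ids = torch.tensor([[VOCAB[nt] for nt in seq]])

    outputs = model(input_ids)
    print(outputs["A"].shape)  # (1, 800 - 480) = (1, 320)
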