taka-yamakoshi committed
Commit c92f27c
1 Parent(s): 5d85885

add custom model

Files changed (1)
  1. custom_modeling_albert_flax.py +471 -0
custom_modeling_albert_flax.py ADDED
@@ -0,0 +1,471 @@
+ from typing import Callable, Optional, Tuple
+
+ import numpy as np
+
+ import flax
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+ from flax.linen.attention import dot_product_attention_weights
+ from flax.traverse_util import flatten_dict, unflatten_dict
+ from jax import lax
+
+ from transformers import AlbertConfig
+ from transformers.models.albert.modeling_flax_albert import (
+     FlaxAlbertEmbeddings,
+     FlaxAlbertOnlyMLMHead,
+     FlaxAlbertPreTrainedModel,
+ )
+ from transformers.modeling_flax_outputs import (
+     FlaxBaseModelOutput,
+     FlaxBaseModelOutputWithPooling,
+     FlaxMaskedLMOutput,
+     FlaxMultipleChoiceModelOutput,
+     FlaxQuestionAnsweringModelOutput,
+     FlaxSequenceClassifierOutput,
+     FlaxTokenClassifierOutput,
+ )
+ from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+
+ from transformers.modeling_flax_utils import (
+     ACT2FN,
+     FlaxPreTrainedModel,
+     append_call_sample_docstring,
+     append_replace_return_docstrings,
+     overwrite_call_docstring,
+ )
+
+
+ class CustomFlaxAlbertForMaskedLMModule(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+         self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
+
+     def __call__(
+         self,
+         input_ids,
+         attention_mask,
+         token_type_ids,
+         position_ids,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         # Run the ALBERT encoder; the intervention arguments are threaded down to the self-attention modules.
+         outputs = self.albert(
+             input_ids,
+             attention_mask,
+             token_type_ids,
+             position_ids,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             interv_type=interv_type,
+             interv_dict=interv_dict,
+         )
+
+         hidden_states = outputs[0]
+         if self.config.tie_word_embeddings:
+             shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+         else:
+             shared_embedding = None
+
+         # Compute the prediction scores
+         logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
+
+         if not return_dict:
+             return (logits,) + outputs[1:]
+
+         return FlaxMaskedLMOutput(
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
+     module_class = CustomFlaxAlbertForMaskedLMModule
+
+
+ class CustomFlaxAlbertModule(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+     add_pooling_layer: bool = True
+
+     def setup(self):
+         self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
+         self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
+         if self.add_pooling_layer:
+             self.pooler = nn.Dense(
+                 self.config.hidden_size,
+                 kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+                 dtype=self.dtype,
+                 name="pooler",
+             )
+             self.pooler_activation = nn.tanh
+         else:
+             self.pooler = None
+             self.pooler_activation = None
+
+     def __call__(
+         self,
+         input_ids,
+         attention_mask,
+         token_type_ids: Optional[np.ndarray] = None,
+         position_ids: Optional[np.ndarray] = None,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         # make sure `token_type_ids` is correctly initialized when not passed
+         if token_type_ids is None:
+             token_type_ids = jnp.zeros_like(input_ids)
+
+         # make sure `position_ids` is correctly initialized when not passed
+         if position_ids is None:
+             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+         hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
+
+         outputs = self.encoder(
+             hidden_states,
+             attention_mask,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             interv_type=interv_type,
+             interv_dict=interv_dict,
+         )
+         hidden_states = outputs[0]
+         if self.add_pooling_layer:
+             pooled = self.pooler(hidden_states[:, 0])
+             pooled = self.pooler_activation(pooled)
+         else:
+             pooled = None
+
+         if not return_dict:
+             # if pooled is None, don't return it
+             if pooled is None:
+                 return (hidden_states,) + outputs[1:]
+             return (hidden_states, pooled) + outputs[1:]
+
+         return FlaxBaseModelOutputWithPooling(
+             last_hidden_state=hidden_states,
+             pooler_output=pooled,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class CustomFlaxAlbertEncoder(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+     def setup(self):
+         self.embedding_hidden_mapping_in = nn.Dense(
+             self.config.hidden_size,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+             dtype=self.dtype,
+         )
+         self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)
+
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
+         return self.albert_layer_groups(
+             hidden_states,
+             attention_mask,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             interv_type=interv_type,
+             interv_dict=interv_dict,
+         )
+
+
+ class CustomFlaxAlbertLayerGroups(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+     def setup(self):
+         self.layers = [
+             CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
+             for i in range(self.config.num_hidden_groups)
+         ]
+
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         all_attentions = () if output_attentions else None
+         all_hidden_states = (hidden_states,) if output_hidden_states else None
+
+         for i in range(self.config.num_hidden_layers):
+             # Index of the hidden group
+             group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+             layer_group_output = self.layers[group_idx](
+                 hidden_states,
+                 attention_mask,
+                 deterministic=deterministic,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 layer_id=i,
+                 interv_type=interv_type,
+                 interv_dict=interv_dict,
+             )
+             hidden_states = layer_group_output[0]
+
+             if output_attentions:
+                 all_attentions = all_attentions + layer_group_output[-1]
+
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+         return FlaxBaseModelOutput(
+             last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+         )
+
+
+ class CustomFlaxAlbertLayerCollections(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+     layer_index: Optional[str] = None
+
+     def setup(self):
+         self.albert_layers = CustomFlaxAlbertLayerCollection(self.config, dtype=self.dtype)
+
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         layer_id: Optional[int] = None,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         outputs = self.albert_layers(
+             hidden_states,
+             attention_mask,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             layer_id=layer_id,
+             interv_type=interv_type,
+             interv_dict=interv_dict,
+         )
+         return outputs
+
+
+ class CustomFlaxAlbertLayerCollection(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+     def setup(self):
+         self.layers = [
+             CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype)
+             for i in range(self.config.inner_group_num)
+         ]
+
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         layer_id: Optional[int] = None,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         layer_hidden_states = ()
+         layer_attentions = ()
+
+         for layer_index, albert_layer in enumerate(self.layers):
+             layer_output = albert_layer(
+                 hidden_states,
+                 attention_mask,
+                 deterministic=deterministic,
+                 output_attentions=output_attentions,
+                 layer_id=layer_id,
+                 interv_type=interv_type,
+                 interv_dict=interv_dict,
+             )
+             hidden_states = layer_output[0]
+
+             if output_attentions:
+                 layer_attentions = layer_attentions + (layer_output[1],)
+
+             if output_hidden_states:
+                 layer_hidden_states = layer_hidden_states + (hidden_states,)
+
+         outputs = (hidden_states,)
+         if output_hidden_states:
+             outputs = outputs + (layer_hidden_states,)
+         if output_attentions:
+             outputs = outputs + (layer_attentions,)
+         return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)
+
+
+ class CustomFlaxAlbertLayer(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+     def setup(self):
+         self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype)
+         self.ffn = nn.Dense(
+             self.config.intermediate_size,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+             dtype=self.dtype,
+         )
+         self.activation = ACT2FN[self.config.hidden_act]
+         self.ffn_output = nn.Dense(
+             self.config.hidden_size,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+             dtype=self.dtype,
+         )
+         self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         layer_id: Optional[int] = None,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         attention_outputs = self.attention(
+             hidden_states,
+             attention_mask,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             layer_id=layer_id,
+             interv_type=interv_type,
+             interv_dict=interv_dict,
+         )
+         attention_output = attention_outputs[0]
+         ffn_output = self.ffn(attention_output)
+         ffn_output = self.activation(ffn_output)
+         ffn_output = self.ffn_output(ffn_output)
+         ffn_output = self.dropout(ffn_output, deterministic=deterministic)
+         hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attention_outputs[1],)
+         return outputs
+
+
+ class CustomFlaxAlbertSelfAttention(nn.Module):
+     config: AlbertConfig
+     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+     def setup(self):
+         if self.config.hidden_size % self.config.num_attention_heads != 0:
+             raise ValueError(
+                 f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
+                 f"`config.num_attention_heads`: {self.config.num_attention_heads}"
+             )
+
+         self.query = nn.Dense(
+             self.config.hidden_size,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+         )
+         self.key = nn.Dense(
+             self.config.hidden_size,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+         )
+         self.value = nn.Dense(
+             self.config.hidden_size,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+         )
+         self.dense = nn.Dense(
+             self.config.hidden_size,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+             dtype=self.dtype,
+         )
+         self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+     def __call__(
+         self,
+         hidden_states,
+         attention_mask,
+         deterministic: bool = True,
+         output_attentions: bool = False,
+         layer_id: Optional[int] = None,
+         interv_type: str = "swap",
+         interv_dict: dict = {},
+     ):
+         head_dim = self.config.hidden_size // self.config.num_attention_heads
+
+         # Project to (batch, seq_len, num_heads, head_dim)
+         query_states = self.query(hidden_states).reshape(
+             hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+         )
+         value_states = self.value(hidden_states).reshape(
+             hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+         )
+         key_states = self.key(hidden_states).reshape(
+             hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+         )
+
+         # Convert the boolean attention mask to an attention bias.
+         if attention_mask is not None:
+             # attention mask in the form of attention bias
+             attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+             attention_bias = lax.select(
+                 attention_mask > 0,
+                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                 jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+             )
+         else:
+             attention_bias = None
+
+         dropout_rng = None
+         if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
+             dropout_rng = self.make_rng("dropout")
+
+         attn_weights = dot_product_attention_weights(
+             query_states,
+             key_states,
+             bias=attention_bias,
+             dropout_rng=dropout_rng,
+             dropout_rate=self.config.attention_probs_dropout_prob,
+             broadcast_dropout=True,
+             deterministic=deterministic,
+             dtype=self.dtype,
+             precision=None,
+         )
+
+         attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+         attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+
+         projected_attn_output = self.dense(attn_output)
+         projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
+         layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
+         outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
+         return outputs
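
A minimal usage sketch of the committed module (not part of the commit): it builds a deliberately small `AlbertConfig`, initializes random parameters, and runs one deterministic forward pass. The config values and tensor shapes below are illustrative assumptions; the inner `nn.Module` is exercised directly with `init`/`apply` because the stock `FlaxAlbertPreTrainedModel.__call__` does not forward the extra `interv_type`/`interv_dict` keyword arguments.

import jax
import jax.numpy as jnp
from transformers import AlbertConfig

from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLMModule

# Small, illustrative config so the random initialization stays cheap.
config = AlbertConfig(
    embedding_size=64, hidden_size=64, num_attention_heads=4, intermediate_size=128, num_hidden_layers=2
)
module = CustomFlaxAlbertForMaskedLMModule(config=config)

batch_size, seq_len = 1, 8  # illustrative shapes
input_ids = jnp.ones((batch_size, seq_len), dtype="i4")
attention_mask = jnp.ones_like(input_ids)
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(seq_len), (batch_size, seq_len))

# Initialize random parameters and run one forward pass with the default (empty) intervention dict.
variables = module.init(jax.random.PRNGKey(0), input_ids, attention_mask, token_type_ids, position_ids)
outputs = module.apply(variables, input_ids, attention_mask, token_type_ids, position_ids)
print(outputs.logits.shape)  # (batch_size, seq_len, config.vocab_size)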