oweller2 committed e44547d (parent: 6aca308): add modeling

modeling_flexbert.py CHANGED (+12 -12)
@@ -68,10 +68,10 @@ from transformers.modeling_outputs import (
 )
 from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
-from bert_padding import index_put_first_axis
+from .bert_padding import index_put_first_axis
 
-from src.bert_layers.activation import get_act_fn
-from src.bert_layers.attention import (
+from .bert_layers.activation import get_act_fn
+from .bert_layers.attention import (
     FlexBertPaddedAttention,
     FlexBertPaddedParallelAttention,
     FlexBertPaddedRopeAttention,
@@ -81,15 +81,15 @@ from src.bert_layers.attention import (
     FlexBertUnpadRopeAttention,
     FlexBertUnpadRopeParallelAttention,
 )
-from src.bert_layers.configuration_bert import FlexBertConfig
-from src.bert_layers.embeddings import (
+from .bert_layers.configuration_bert import FlexBertConfig
+from .bert_layers.embeddings import (
     BertAlibiEmbeddings,
     FlexBertAbsoluteEmbeddings,
     FlexBertCompiledSansPositionEmbeddings,
     FlexBertSansPositionEmbeddings,
     get_embedding_layer,
 )
-from src.bert_layers.initialization import (
+from .bert_layers.initialization import (
     ModuleType,
     TileLinear,
     TileMode,
@@ -98,7 +98,7 @@ from src.bert_layers.initialization import (
     tile_linear,
     tile_norm,
 )
-from src.bert_layers.layers import (
+from .bert_layers.layers import (
     BertAlibiEncoder,
     BertPooler,
     BertPredictionHeadTransform,
@@ -113,10 +113,10 @@ from src.bert_layers.layers import (
     FlexBertUnpadPreNormLayer,
     get_encoder_layer,
 )
-from src.bert_layers.loss import get_loss_fn
-from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
-from src.bert_layers.normalization import get_norm_layer
-from src.bert_layers.padding import pad_input, unpad_input
+from .bert_layers.loss import get_loss_fn
+from .bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
+from .bert_layers.normalization import get_norm_layer
+from .bert_layers.padding import pad_input, unpad_input
 
 logger = logging.getLogger(__name__)
 
@@ -868,7 +868,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
 
     def _init_module_weights(self, module: nn.Module):
         """
-        Custom weight init of modules using src.bert_layers.initialization.init_weights
+        Custom weight init of modules using .bert_layers.initialization.init_weights
         Currently only supports init of embedding modules
         """
         assert isinstance(module, nn.Module)
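The change rewrites the absolute `bert_padding` / `src.bert_layers.*` imports as package-relative imports, which is what a modeling file needs when it ships inside a model repository rather than inside the original training codebase's `src` package. Below is a minimal sketch of loading such a repo as custom code with the standard transformers API; the repo id is a placeholder (not taken from this commit), and it assumes the supporting files (`bert_padding.py` and the `bert_layers` package) live alongside `modeling_flexbert.py` in the same repo.

from transformers import AutoConfig, AutoModel, AutoTokenizer

# Hypothetical repo id containing modeling_flexbert.py plus its relative-imported files.
repo_id = "user/flexbert-checkpoint"

# trust_remote_code=True lets transformers download and import the repo's modeling code;
# the relative imports added in this commit are resolved against files in the same repo
# instead of requiring the training repo's `src` package on the Python path.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)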