oweller2 committed
Commit: e44547d
Parent: 6aca308

add modeling

Files changed (1):
  1. modeling_flexbert.py +12 -12
modeling_flexbert.py CHANGED
@@ -68,10 +68,10 @@ from transformers.modeling_outputs import (
 )
 from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
-from bert_padding import index_put_first_axis
+from .bert_padding import index_put_first_axis
 
-from src.bert_layers.activation import get_act_fn
-from src.bert_layers.attention import (
+from .bert_layers.activation import get_act_fn
+from .bert_layers.attention import (
     FlexBertPaddedAttention,
     FlexBertPaddedParallelAttention,
     FlexBertPaddedRopeAttention,
@@ -81,15 +81,15 @@ from src.bert_layers.attention import (
     FlexBertUnpadRopeAttention,
     FlexBertUnpadRopeParallelAttention,
 )
-from src.bert_layers.configuration_bert import FlexBertConfig
-from src.bert_layers.embeddings import (
+from .bert_layers.configuration_bert import FlexBertConfig
+from .bert_layers.embeddings import (
     BertAlibiEmbeddings,
     FlexBertAbsoluteEmbeddings,
     FlexBertCompiledSansPositionEmbeddings,
     FlexBertSansPositionEmbeddings,
     get_embedding_layer,
 )
-from src.bert_layers.initialization import (
+from .bert_layers.initialization import (
     ModuleType,
     TileLinear,
     TileMode,
@@ -98,7 +98,7 @@ from src.bert_layers.initialization import (
     tile_linear,
     tile_norm,
 )
-from src.bert_layers.layers import (
+from .bert_layers.layers import (
     BertAlibiEncoder,
     BertPooler,
     BertPredictionHeadTransform,
@@ -113,10 +113,10 @@ from src.bert_layers.layers import (
     FlexBertUnpadPreNormLayer,
     get_encoder_layer,
 )
-from src.bert_layers.loss import get_loss_fn
-from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
-from src.bert_layers.normalization import get_norm_layer
-from src.bert_layers.padding import pad_input, unpad_input
+from .bert_layers.loss import get_loss_fn
+from .bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
+from .bert_layers.normalization import get_norm_layer
+from .bert_layers.padding import pad_input, unpad_input
 
 logger = logging.getLogger(__name__)
 
@@ -868,7 +868,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
 
     def _init_module_weights(self, module: nn.Module):
         """
-        Custom weight init of modules using src.bert_layers.initialization.init_weights
+        Custom weight init of modules using .bert_layers.initialization.init_weights
         Currently only supports init of embedding modules
         """
         assert isinstance(module, nn.Module)
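
Note: the switch from src.* absolute imports to package-relative imports is what lets modeling_flexbert.py be imported as remote code directly from the model repository, rather than requiring a locally installed src package. A minimal sketch of that loading path follows, assuming the repository's config.json registers the FlexBERT classes via auto_map; the repository id below is a placeholder for illustration, not one named in this commit.

    from transformers import AutoConfig, AutoModel

    # Placeholder repository id; substitute the Hub repo that actually ships
    # modeling_flexbert.py next to its config.json.
    repo_id = "<namespace>/<flexbert-checkpoint>"

    # trust_remote_code=True downloads and imports the repo's modeling file;
    # its relative imports (e.g. ".bert_layers.attention") then resolve against
    # the other files bundled in the same repository instead of a local "src" package.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)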