oweller2 committed on
Commit
0953ea5
1 Parent(s): 64c9f71
Files changed (1) hide show
  1. modeling_flexbert.py +7 -9
modeling_flexbert.py CHANGED
@@ -64,14 +64,13 @@ from transformers.modeling_outputs import (
64
  ModelOutput,
65
  MultipleChoiceModelOutput,
66
  SequenceClassifierOutput,
67
- CausalLMOutput,
68
  )
69
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
70
-
71
  from .bert_padding import index_put_first_axis
72
 
73
- from .bert_layers.activation import get_act_fn
74
- from .bert_layers.attention import (
75
  FlexBertPaddedAttention,
76
  FlexBertPaddedParallelAttention,
77
  FlexBertPaddedRopeAttention,
@@ -81,15 +80,15 @@ from .bert_layers.attention import (
81
  FlexBertUnpadRopeAttention,
82
  FlexBertUnpadRopeParallelAttention,
83
  )
84
- from .bert_layers.configuration_bert import FlexBertConfig
85
- from .bert_layers.embeddings import (
86
  BertAlibiEmbeddings,
87
  FlexBertAbsoluteEmbeddings,
88
  FlexBertCompiledSansPositionEmbeddings,
89
  FlexBertSansPositionEmbeddings,
90
  get_embedding_layer,
91
  )
92
- from .bert_layers.initialization import (
93
  ModuleType,
94
  TileLinear,
95
  TileMode,
@@ -98,7 +97,7 @@ from .bert_layers.initialization import (
98
  tile_linear,
99
  tile_norm,
100
  )
101
- from .bert_layers.layers import (
102
  BertAlibiEncoder,
103
  BertPooler,
104
  BertPredictionHeadTransform,
@@ -113,7 +112,6 @@ from .bert_layers.layers import (
113
  FlexBertUnpadPreNormLayer,
114
  get_encoder_layer,
115
  )
116
- from .bert_layers.loss import get_loss_fn
117
  from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
118
  from .normalization import get_norm_layer
119
  from .padding import pad_input, unpad_input
 
64
  ModelOutput,
65
  MultipleChoiceModelOutput,
66
  SequenceClassifierOutput,
 
67
  )
68
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
69
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
70
  from .bert_padding import index_put_first_axis
71
 
72
+ from .activation import get_act_fn
73
+ from .attention import (
74
  FlexBertPaddedAttention,
75
  FlexBertPaddedParallelAttention,
76
  FlexBertPaddedRopeAttention,
 
80
  FlexBertUnpadRopeAttention,
81
  FlexBertUnpadRopeParallelAttention,
82
  )
83
+ from .configuration_bert import FlexBertConfig
84
+ from .embeddings import (
85
  BertAlibiEmbeddings,
86
  FlexBertAbsoluteEmbeddings,
87
  FlexBertCompiledSansPositionEmbeddings,
88
  FlexBertSansPositionEmbeddings,
89
  get_embedding_layer,
90
  )
91
+ from .initialization import (
92
  ModuleType,
93
  TileLinear,
94
  TileMode,
 
97
  tile_linear,
98
  tile_norm,
99
  )
100
+ from .layers import (
101
  BertAlibiEncoder,
102
  BertPooler,
103
  BertPredictionHeadTransform,
 
112
  FlexBertUnpadPreNormLayer,
113
  get_encoder_layer,
114
  )
 
115
  from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
116
  from .normalization import get_norm_layer
117
  from .padding import pad_input, unpad_input