michael-guenther committed on
Commit
30e6a10
1 Parent(s): 2e3ebcb

change config name

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. configuration_bert.py +1 -1
  3. modeling_bert.py +3 -2
config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "auto_map": {
3
- "AutoConfig": "configuration_bert.BertConfig",
4
  "AutoModel": "modeling_bert.BertModel",
5
  "AutoModelForPreTraining": "modeling_bert.BertForPreTraining",
6
  "AutoModelForMaskedLM": "modeling_bert.BertForPreTraining"
 
1
  {
2
  "auto_map": {
3
+ "AutoConfig": "configuration_bert.XLMFlashConfig",
4
  "AutoModel": "modeling_bert.BertModel",
5
  "AutoModelForPreTraining": "modeling_bert.BertForPreTraining",
6
  "AutoModelForMaskedLM": "modeling_bert.BertForPreTraining"
configuration_bert.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import PretrainedConfig
2
 
3
- class BertConfig(PretrainedConfig):
4
  def __init__(
5
  self,
6
  vocab_size=30522,
 
1
  from transformers import PretrainedConfig
2
 
3
+ class XLMFlashConfig(PretrainedConfig):
4
  def __init__(
5
  self,
6
  vocab_size=30522,
modeling_bert.py CHANGED
@@ -19,7 +19,7 @@ import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  from einops import rearrange
22
- from transformers import BertConfig, PretrainedConfig, XLMRobertaConfig # TODO check whether to use XLMRobertaConfig
23
  from transformers.modeling_utils import PreTrainedModel
24
  from transformers.models.bert.modeling_bert import (
25
  BaseModelOutputWithPoolingAndCrossAttentions,
@@ -32,6 +32,7 @@ from .bert_padding import (
32
  pad_input,
33
  unpad_input,
34
  )
 
35
  from .block import Block
36
  from .embedding import BertEmbeddings
37
  from .mha import MHA
@@ -345,7 +346,7 @@ class BertPreTrainedModel(PreTrainedModel):
345
  """An abstract class to handle weights initialization and
346
  a simple interface for dowloading and loading pretrained models.
347
  """
348
- config_class = XLMRobertaConfig
349
  base_model_prefix = "bert"
350
  supports_gradient_checkpointing = True
351
 
 
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  from einops import rearrange
22
+ from transformers import BertConfig, PretrainedConfig
23
  from transformers.modeling_utils import PreTrainedModel
24
  from transformers.models.bert.modeling_bert import (
25
  BaseModelOutputWithPoolingAndCrossAttentions,
 
32
  pad_input,
33
  unpad_input,
34
  )
35
+ from .configuration_bert import XLMFlashConfig
36
  from .block import Block
37
  from .embedding import BertEmbeddings
38
  from .mha import MHA
 
346
  """An abstract class to handle weights initialization and
347
  a simple interface for dowloading and loading pretrained models.
348
  """
349
+ config_class = XLMFlashConfig
350
  base_model_prefix = "bert"
351
  supports_gradient_checkpointing = True
352