NohTow committed
Commit ce9aa51
1 Parent(s): e533a59

Using dict as input

modeling_flexbert.py CHANGED
@@ -50,7 +50,7 @@ import os
 import sys
 import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict
 
 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
@@ -944,18 +944,33 @@ class FlexBertModel(FlexBertPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        indices: Optional[torch.Tensor] = None,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
+        features: Dict[str, torch.Tensor],
         **kwargs,
     ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
-        if attention_mask is None:
+        # Unpack the former keyword arguments from the feature dict.
+        # input_ids must be read first, before it is used to build the default mask.
+        input_ids = features["input_ids"]
+        if features.get("attention_mask") is None:
             attention_mask = torch.ones_like(input_ids)
-
+        else:
+            attention_mask = features["attention_mask"]
+        if "position_ids" not in features:
+            position_ids = None
+        else:
+            position_ids = features["position_ids"]
         embedding_output = self.embeddings(input_ids, position_ids)
+        if "indices" not in features:
+            indices = None
+        else:
+            indices = features["indices"]
+        if "cu_seqlens" not in features:
+            cu_seqlens = None
+        else:
+            cu_seqlens = features["cu_seqlens"]
+        if "max_seqlen" not in features:
+            max_seqlen = None
+        else:
+            max_seqlen = features["max_seqlen"]
 
         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
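
For reference, a minimal sketch of how the reworked forward would be called after this commit. The toy tensors and the commented-out model instantiation are illustrative assumptions, not part of the diff; per the new signature, only the "input_ids" key is strictly required:

import torch

# Build the single dict argument that forward() now expects.
# "attention_mask" defaults to all ones when absent or None; position_ids,
# indices, cu_seqlens and max_seqlen fall back to None when missing.
features = {
    "input_ids": torch.randint(0, 30000, (2, 16)),          # toy token ids
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
}

# model = FlexBertModel(config)   # assumed: a configured FlexBertModel instance
# outputs = model(features)       # a tuple, per the annotated return type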