Crystalcareai committed
Commit 91c3532
1 Parent(s): ef739f9

Upload modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +64 -64
modeling_gemmoe.py CHANGED
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Gemma model."""
+""" PyTorch Gemmoe model."""

 import math
 import warnings
@@ -43,7 +43,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
-from .configuration_gemma import GemmaConfig
+from .configuration_gemmoe import GemmoeConfig


 if is_flash_attn_2_available():
@@ -62,7 +62,7 @@ if is_torch_fx_available():

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "GemmaConfig"
+_CONFIG_FOR_DOC = "GemmoeConfig"

 def approx_gelu(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
@@ -79,7 +79,7 @@ def _get_unpad_data(attention_mask):
     )


-class GemmaRMSNorm(nn.Module):
+class GemmoeRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
@@ -97,10 +97,10 @@ class GemmaRMSNorm(nn.Module):
         return normed_x * (self.weight + 1)


-ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
+ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)


-class GemmaRotaryEmbedding(nn.Module):
+class GemmoeRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         self.dim = dim
@@ -164,8 +164,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemma
-class GemmaMLP(nn.Module):
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemmoe
+class GemmoeMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -193,11 +193,11 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-class GemmaAttention(nn.Module):
+class GemmoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     # Ignore copy
-    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -228,7 +228,7 @@ class GemmaAttention(nn.Module):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = GemmoeRotaryEmbedding(
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
@@ -298,10 +298,10 @@ class GemmaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma
-class GemmaFlashAttention2(GemmaAttention):
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemmoe
+class GemmoeFlashAttention2(GemmoeAttention):
     """
-    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
+    Gemmoe flash attention module. This module inherits from `GemmoeAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
     flash attention and deal with padding tokens in case the input contains any of them.
     """
@@ -363,7 +363,7 @@ class GemmaFlashAttention2(GemmaAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in the correct dtype just to be sure everything works as expected.
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (GemmaRMSNorm handles it correctly)
+        # in fp32. (GemmoeRMSNorm handles it correctly)

         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
@@ -422,7 +422,7 @@ class GemmaFlashAttention2(GemmaAttention):
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
         else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmaFlashAttention2 __init__.
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmoeFlashAttention2 __init__.
             causal = self.is_causal and query_length != 1

         # Contains at least one padding token in the sequence
@@ -495,11 +495,11 @@ class GemmaFlashAttention2(GemmaAttention):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma
-class GemmaSdpaAttention(GemmaAttention):
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemmoe
+class GemmoeSdpaAttention(GemmoeAttention):
     """
-    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
     SDPA API.
     """

@@ -517,7 +517,7 @@ class GemmaSdpaAttention(GemmaAttention):
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
-                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
             return super().forward(
@@ -580,24 +580,24 @@ class GemmaSdpaAttention(GemmaAttention):
         return attn_output, None, past_key_value


-GEMMA_ATTENTION_CLASSES = {
-    "eager": GemmaAttention,
-    "flash_attention_2": GemmaFlashAttention2,
-    "sdpa": GemmaSdpaAttention,
+GEMMOE_ATTENTION_CLASSES = {
+    "eager": GemmoeAttention,
+    "flash_attention_2": GemmoeFlashAttention2,
+    "sdpa": GemmoeSdpaAttention,
 }


-# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
-class GemmaDecoderLayer(nn.Module):
-    def __init__(self, config: GemmaConfig, layer_idx: int):
+# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
+class GemmoeDecoderLayer(nn.Module):
+    def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+        self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

-        self.mlp = GemmaMLP(config)
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = GemmoeMLP(config)
+        self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
@@ -663,7 +663,7 @@ class GemmaDecoderLayer(nn.Module):
         return outputs


-GEMMA_START_DOCSTRING = r"""
+GEMMOE_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)
@@ -673,7 +673,7 @@ GEMMA_START_DOCSTRING = r"""
    and behavior.

    Parameters:
-        config ([`GemmaConfig`]):
+        config ([`GemmoeConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -681,15 +681,15 @@ GEMMA_START_DOCSTRING = r"""


 @add_start_docstrings(
-    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
-    GEMMA_START_DOCSTRING,
+    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
+    GEMMOE_START_DOCSTRING,
 )
-class GemmaPreTrainedModel(PreTrainedModel):
-    config_class = GemmaConfig
+class GemmoePreTrainedModel(PreTrainedModel):
+    config_class = GemmoeConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
-    _no_split_modules = ["GemmaDecoderLayer"]
+    _no_split_modules = ["GemmoeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -728,7 +728,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
                 layer.self_attn.past_key_value = None


-GEMMA_INPUTS_DOCSTRING = r"""
+GEMMOE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -803,28 +803,28 @@ GEMMA_INPUTS_DOCSTRING = r"""


 @add_start_docstrings(
-    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
-    GEMMA_START_DOCSTRING,
+    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
+    GEMMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma
-class GemmaModel(GemmaPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMOE,Llama->Gemmoe
+class GemmoeModel(GemmoePreTrainedModel):
     """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]

     Args:
-        config: GemmaConfig
+        config: GemmoeConfig
     """

-    def __init__(self, config: GemmaConfig):
+    def __init__(self, config: GemmoeConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
-            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

         # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
@@ -842,7 +842,7 @@ class GemmaModel(GemmaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
     # Ignore copy
     def forward(
         self,
@@ -1021,13 +1021,13 @@ class GemmaModel(GemmaPreTrainedModel):
         return causal_mask


-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma
-class GemmaForCausalLM(GemmaPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMOE,Llama->Gemmoe,llama->gemmoe
+class GemmoeForCausalLM(GemmoePreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
-        self.model = GemmaModel(config)
+        self.model = GemmoeModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

@@ -1053,7 +1053,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         return self.model

     # Ignore copy
-    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
@@ -1081,10 +1081,10 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         Example:

         ```python
-        >>> from transformers import AutoTokenizer, GemmaForCausalLM
+        >>> from transformers import AutoTokenizer, GemmoeForCausalLM

-        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
-        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+        >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")

         >>> prompt = "What is your favorite condiment?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -1236,9 +1236,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):

 @add_start_docstrings(
     """
-    The Gemma Model transformer with a sequence classification head on top (linear layer).
+    The Gemmoe Model transformer with a sequence classification head on top (linear layer).

-    [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    [`GemmoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.

     Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -1247,14 +1247,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    GEMMA_START_DOCSTRING,
+    GEMMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma
-class GemmaForSequenceClassification(GemmaPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMOE,Llama->Gemmoe
+class GemmoeForSequenceClassification(GemmoePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = GemmaModel(config)
+        self.model = GemmoeModel(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

         # Initialize weights and apply final processing
@@ -1266,7 +1266,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
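
For context, a minimal usage sketch of the renamed classes in this commit. Because `modeling_gemmoe.py` ships inside the Hub repository rather than in `transformers` itself, loading goes through `trust_remote_code=True`; the repo id below is a placeholder, and the sketch assumes the repository's `config.json` points the auto classes at `GemmoeForCausalLM` via `auto_map`.

```python
# Minimal sketch, assuming a Hub repo that bundles this modeling_gemmoe.py
# and maps AutoModelForCausalLM to GemmoeForCausalLM in its config.json.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Crystalcareai/Gemmoe"  # hypothetical repo id, for illustration only

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets transformers import the custom modeling code
# (GemmoeConfig / GemmoeForCausalLM) from the repository.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

prompt = "What is your favorite condiment?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

The generation call mirrors the doctest inside `GemmoeForCausalLM.forward` shown in the diff above; only the loading path differs, since the classes are not registered in the `transformers` package itself.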