Crystalcareai committed
Commit 91c3532 • 1 Parent(s): ef739f9
Upload modeling_gemmoe.py

modeling_gemmoe.py  CHANGED  (+64 -64)
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Gemma model."""
+""" PyTorch Gemmoe model."""

 import math
 import warnings
@@ -43,7 +43,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
-from .configuration_gemma import GemmaConfig
+from .configuration_gemmoe import GemmoeConfig


 if is_flash_attn_2_available():
@@ -62,7 +62,7 @@ if is_torch_fx_available():

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "GemmaConfig"
+_CONFIG_FOR_DOC = "GemmoeConfig"

 def approx_gelu(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
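For context (not part of the committed diff): `approx_gelu` in the hunk above is the standard tanh approximation of GELU, the same curve PyTorch exposes as `F.gelu(x, approximate="tanh")`. A minimal check, assuming PyTorch 1.12 or newer:

```python
import math
import torch
import torch.nn.functional as F

def approx_gelu(x):
    # Tanh approximation of GELU, exactly as written in the hunk above.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

x = torch.randn(4, 8)
# The built-in tanh-approximated GELU should agree to floating-point tolerance.
assert torch.allclose(approx_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6)
```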
@@ -79,7 +79,7 @@ def _get_unpad_data(attention_mask):
     )


-class GemmaRMSNorm(nn.Module):
+class GemmoeRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
@@ -97,10 +97,10 @@ class GemmaRMSNorm(nn.Module):
         return normed_x * (self.weight + 1)


-ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
+ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)


-class GemmaRotaryEmbedding(nn.Module):
+class GemmoeRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         self.dim = dim
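Only fragments of `GemmoeRMSNorm` appear in the diff (the `__init__` signature and the final `normed_x * (self.weight + 1)`), since the lines in between are unchanged. As a reading aid, a minimal sketch of the Gemma-style RMSNorm this follows, assuming the unshown `_norm` body matches the upstream implementation:

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Gemma-style RMSNorm: the weight is zero-initialized and applied as (1 + weight)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Scale each vector by the reciprocal of its root mean square.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        normed_x = self._norm(x.float()).type_as(x)
        # The line visible in the hunk above: the scale is stored as an offset from 1.
        return normed_x * (self.weight + 1)
```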
@@ -164,8 +164,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemma
-class GemmaMLP(nn.Module):
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemmoe
+class GemmoeMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -193,11 +193,11 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-class GemmaAttention(nn.Module):
+class GemmoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     # Ignore copy
-    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -228,7 +228,7 @@ class GemmaAttention(nn.Module):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = GemmoeRotaryEmbedding(
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
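`GemmoeRotaryEmbedding` appears here only through its constructor call; its body lies outside the changed lines. For orientation, the standard RoPE setup that `dim`, `max_position_embeddings`, and `base` (`rope_theta`) parameterize looks roughly like this (a sketch, not the committed implementation):

```python
import torch

def rope_cos_sin(dim: int, seq_len: int, base: float = 10000.0, device=None):
    # One inverse frequency per channel pair: base ** (-2i / dim).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(positions, inv_freq)   # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)    # (seq_len, dim)
    return emb.cos(), emb.sin()

cos, sin = rope_cos_sin(dim=256, seq_len=8)
print(cos.shape, sin.shape)  # torch.Size([8, 256]) torch.Size([8, 256])
```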
@@ -298,10 +298,10 @@ class GemmaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma
-class GemmaFlashAttention2(GemmaAttention):
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemmoe
+class GemmoeFlashAttention2(GemmoeAttention):
     """
-    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
+    Gemmoe flash attention module. This module inherits from `GemmoeAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
     flash attention and deal with padding tokens in case the input contains any of them.
     """
@@ -363,7 +363,7 @@ class GemmaFlashAttention2(GemmaAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in the correct dtype just to be sure everything works as expected.
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (GemmaRMSNorm handles it correctly)
+        # in fp32. (GemmoeRMSNorm handles it correctly)

         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
@@ -422,7 +422,7 @@ class GemmaFlashAttention2(GemmaAttention):
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
         else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmaFlashAttention2 __init__.
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmoeFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
@@ -495,11 +495,11 @@ class GemmaFlashAttention2(GemmaAttention):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma
-class GemmaSdpaAttention(GemmaAttention):
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemmoe
+class GemmoeSdpaAttention(GemmoeAttention):
     """
-    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
     SDPA API.
     """

@@ -517,7 +517,7 @@ class GemmaSdpaAttention(GemmaAttention):
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
-                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
         return super().forward(
@@ -580,24 +580,24 @@ class GemmaSdpaAttention(GemmaAttention):
         return attn_output, None, past_key_value


-GEMMA_ATTENTION_CLASSES = {
-    "eager": GemmaAttention,
-    "flash_attention_2": GemmaFlashAttention2,
-    "sdpa": GemmaSdpaAttention,
+GEMMOE_ATTENTION_CLASSES = {
+    "eager": GemmoeAttention,
+    "flash_attention_2": GemmoeFlashAttention2,
+    "sdpa": GemmoeSdpaAttention,
 }


-# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
-class GemmaDecoderLayer(nn.Module):
-    def __init__(self, config: GemmaConfig, layer_idx: int):
+# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
+class GemmoeDecoderLayer(nn.Module):
+    def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+        self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

-        self.mlp = GemmaMLP(config)
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = GemmoeMLP(config)
+        self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
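The decoder layer picks its attention class out of `GEMMOE_ATTENTION_CLASSES` via `config._attn_implementation`, which Transformers sets from the `attn_implementation` argument at load time (or auto-detects). A hedged usage sketch; the repository id is a placeholder, not taken from this commit:

```python
from transformers import AutoModelForCausalLM

# attn_implementation selects the key used in GEMMOE_ATTENTION_CLASSES:
# "eager", "sdpa", or "flash_attention_2" (the last one requires flash-attn).
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-gemmoe-checkpoint",  # placeholder repo id
    trust_remote_code=True,             # needed to run the custom modeling_gemmoe.py
    attn_implementation="sdpa",
)
```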
@@ -663,7 +663,7 @@ class GemmaDecoderLayer(nn.Module):
         return outputs


-GEMMA_START_DOCSTRING = r"""
+GEMMOE_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
@@ -673,7 +673,7 @@ GEMMA_START_DOCSTRING = r"""
     and behavior.

     Parameters:
-        config ([`GemmaConfig`]):
+        config ([`GemmoeConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
             load the weights associated with the model, only the configuration. Check out the
             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -681,15 +681,15 @@ GEMMA_START_DOCSTRING = r"""


 @add_start_docstrings(
-    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
-    GEMMA_START_DOCSTRING,
+    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
+    GEMMOE_START_DOCSTRING,
 )
-class GemmaPreTrainedModel(PreTrainedModel):
-    config_class = GemmaConfig
+class GemmoePreTrainedModel(PreTrainedModel):
+    config_class = GemmoeConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
-    _no_split_modules = ["GemmaDecoderLayer"]
+    _no_split_modules = ["GemmoeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -728,7 +728,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
             layer.self_attn.past_key_value = None


-GEMMA_INPUTS_DOCSTRING = r"""
+GEMMOE_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -803,28 +803,28 @@ GEMMA_INPUTS_DOCSTRING = r"""


 @add_start_docstrings(
-    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
-    GEMMA_START_DOCSTRING,
+    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
+    GEMMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma
-class GemmaModel(GemmaPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMOE,Llama->Gemmoe
+class GemmoeModel(GemmoePreTrainedModel):
     """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]

     Args:
-        config: GemmaConfig
+        config: GemmoeConfig
     """

-    def __init__(self, config: GemmaConfig):
+    def __init__(self, config: GemmoeConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
-            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

         # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
@@ -842,7 +842,7 @@ class GemmaModel(GemmaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
     # Ignore copy
     def forward(
         self,
@@ -1021,13 +1021,13 @@ class GemmaModel(GemmaPreTrainedModel):
         return causal_mask


-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma
-class GemmaForCausalLM(GemmaPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMOE,Llama->Gemmoe,llama->gemmoe
+class GemmoeForCausalLM(GemmoePreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
-        self.model = GemmaModel(config)
+        self.model = GemmoeModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

@@ -1053,7 +1053,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         return self.model

     # Ignore copy
-    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
@@ -1081,10 +1081,10 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         Example:

         ```python
-        >>> from transformers import AutoTokenizer, GemmaForCausalLM
+        >>> from transformers import AutoTokenizer, GemmoeForCausalLM

-        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
-        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+        >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")

         >>> prompt = "What is your favorite condiment?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
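# Not part of the diff: the doctest above stops at tokenization only because the
# following lines are unchanged by this commit. A typical continuation of such an
# example (a sketch, not text from the hunk) would be:
>>> outputs = model.generate(**inputs, max_new_tokens=20)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]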
@@ -1236,9 +1236,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):

 @add_start_docstrings(
     """
-    The Gemma Model transformer with a sequence classification head on top (linear layer).
+    The Gemmoe Model transformer with a sequence classification head on top (linear layer).

-    [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    [`GemmoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.

     Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -1247,14 +1247,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    GEMMA_START_DOCSTRING,
+    GEMMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma
-class GemmaForSequenceClassification(GemmaPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMOE,Llama->Gemmoe
+class GemmoeForSequenceClassification(GemmoePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = GemmaModel(config)
+        self.model = GemmoeModel(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

         # Initialize weights and apply final processing
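The docstring in the previous hunk explains that classification is done on the last non-padding token of each sequence. A minimal sketch of that pooling step, following the Llama-style implementation this class is copied from (an illustration that assumes right-padding, not the exact committed code):

```python
import torch

def pool_last_token(logits, input_ids, pad_token_id=None):
    """Pick, for each row, the logits at the last non-padding position."""
    batch_size, seq_len = input_ids.shape
    if pad_token_id is None:
        # Without a pad token, every row is assumed to use its full length.
        last_positions = torch.full((batch_size,), seq_len - 1, dtype=torch.long)
    else:
        # Number of non-pad tokens minus one == index of the last real token.
        last_positions = input_ids.ne(pad_token_id).sum(dim=-1) - 1
    return logits[torch.arange(batch_size), last_positions]

logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)
ids = torch.tensor([[11, 12, 13, 0, 0], [11, 12, 13, 14, 15]])
print(pool_last_token(logits, ids, pad_token_id=0).shape)  # torch.Size([2, 3])
```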
@@ -1266,7 +1266,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,