Crystalcareai committed
Commit 81c9359
Parent(s): 8f32857
Update modeling_gemmoe.py

modeling_gemmoe.py  (+5 -10)
CHANGED
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "GemmoeConfig"
 
 def approx_gelu(x):
-    return
+    return x * torch.sigmoid(1.702 * x)
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
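The filled-in body is the sigmoid-based GELU approximation, x * sigmoid(1.702 * x), where 1.702 is the standard constant from Hendrycks & Gimpel (2016). A minimal standalone sketch (test values are illustrative, not from the repo) checking it against PyTorch's exact F.gelu:

import torch
import torch.nn.functional as F

def approx_gelu(x):
    # Sigmoid-based GELU approximation (Hendrycks & Gimpel, 2016).
    return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-4.0, 4.0, steps=81)
max_gap = torch.max(torch.abs(approx_gelu(x) - F.gelu(x)))
print(max_gap)  # small, on the order of 1e-2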
@@ -348,9 +348,9 @@ class GemmoeFlashAttention2(GemmoeAttention):
                 f" {target_dtype}."
             )
 
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
+            query_states = query_states.to(target_dtype, non_blocking=True)
+            key_states = key_states.to(target_dtype, non_blocking=True)
+            value_states = value_states.to(target_dtype, non_blocking=True)
 
         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
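For context on the added flag: Tensor.to() accepts non_blocking, which lets a copy overlap with host work when an asynchronous transfer is actually possible (e.g. from pinned memory); for a same-device dtype cast like this one it is a harmless hint, since CUDA kernels are queued asynchronously anyway. A minimal sketch of the pattern, with hypothetical shapes:

import torch

# Flash-attention kernels require fp16/bf16 inputs, so fp32 activations
# are cast down first; non_blocking=True requests an asynchronous copy
# where one is possible.
target_dtype = torch.bfloat16
query_states = torch.randn(1, 8, 128, 64)  # (batch, heads, seq, head_dim), hypothetical
query_states = query_states.to(target_dtype, non_blocking=True)
assert query_states.dtype == target_dtype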
@@ -845,12 +845,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # Scale embeddings
-
-        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
-        if inputs_embeds.dtype == torch.bfloat16:
-            pass
-
-        hidden_states = inputs_embeds * hidden_size_sqrt
+        hidden_states = inputs_embeds * (self.config.hidden_size ** 0.5)
 
         past_seen_tokens = 0
         if use_cache:  # kept for BC (cache positions)
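This hunk collapses dead code (an empty bfloat16 branch and a math.sqrt temporary) into a single scaling line: Gemma-style models multiply token embeddings by sqrt(hidden_size) before the first decoder layer, as in the original Transformer. Note that upstream Gemma casts the sqrt(hidden_size) normalizer to the embedding dtype, which may be what the deleted bfloat16 branch was sketching toward. A minimal sketch of the scaling step, with hypothetical sizes:

import torch

hidden_size = 2048   # hypothetical config value
vocab_size = 256000  # hypothetical config value
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([[2, 651, 9456]])
inputs_embeds = embed_tokens(input_ids)
# Scale embeddings by sqrt(hidden_size) before the first decoder layer.
hidden_states = inputs_embeds * (hidden_size ** 0.5)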