Crystalcareai committed on
Commit 81c9359
1 Parent(s): 8f32857

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +5 -10
modeling_gemmoe.py CHANGED
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "GemmoeConfig"
 
 def approx_gelu(x):
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
+    return x * torch.sigmoid(1.702 * x)
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
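Note on this hunk: the tanh-based GELU approximation is swapped for the sigmoid approximation x * sigmoid(1.702 * x). The two curves agree to within roughly 2e-2 in absolute terms over typical activation ranges, so this is an approximation swap, not an exact refactor. A minimal sketch comparing the two (illustrative only, not part of the commit):

import math
import torch

def gelu_tanh(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

def gelu_sigmoid(x):
    return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-6, 6, steps=1001)
print((gelu_tanh(x) - gelu_sigmoid(x)).abs().max())  # on the order of 1e-2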
@@ -348,9 +348,9 @@ class GemmoeFlashAttention2(GemmoeAttention):
                 f" {target_dtype}."
             )
 
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
+            query_states = query_states.to(target_dtype, non_blocking=True)
+            key_states = key_states.to(target_dtype, non_blocking=True)
+            value_states = value_states.to(target_dtype, non_blocking=True)
 
         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
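Note on this hunk: non_blocking=True is a scheduling hint; per the PyTorch docs it only makes a copy asynchronous for host-to-device transfers from pinned memory, so for an on-device dtype cast like this it does not change the result and may have no effect at all. A minimal sketch (assumes a CUDA device and a float16 target, purely illustrative):

import torch

if torch.cuda.is_available():
    target_dtype = torch.float16  # assumed target, mirroring the hunk above
    query_states = torch.randn(2, 8, 128, 64, device="cuda")
    # Same-device dtype cast: non_blocking=True is a hint; the output is identical.
    query_states = query_states.to(target_dtype, non_blocking=True)
    assert query_states.dtype == target_dtype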
@@ -845,12 +845,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # Scale embeddings
-        # Fix for precision issue when casting to bfloat16
-        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
-        if inputs_embeds.dtype == torch.bfloat16:
-            pass
-
-        hidden_states = inputs_embeds * hidden_size_sqrt
+        hidden_states = inputs_embeds * (self.config.hidden_size ** 0.5)
 
         past_seen_tokens = 0
         if use_cache: # kept for BC (cache positions)
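Note on this hunk: the removed block was dead code (the bfloat16 branch only executed pass), so the change collapses it to the single standard scaling line, inputs_embeds * sqrt(hidden_size). One behavioral detail worth noting: multiplying by a Python float scalar does not promote the tensor's dtype, so bfloat16 embeddings stay bfloat16 after scaling. A sketch under an assumed hidden size, not taken from the commit:

import torch

hidden_size = 2048  # illustrative value, not from the commit
inputs_embeds = torch.randn(1, 4, hidden_size, dtype=torch.bfloat16)

# Matches the added line: scale by sqrt(hidden_size); the Python float
# scalar keeps the tensor in bfloat16.
hidden_states = inputs_embeds * (hidden_size ** 0.5)
assert hidden_states.dtype == torch.bfloat16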
 