Kaguya-19 committed (verified)
Commit 3cc1148 · 1 parent: 414a512

Update modeling_minicpm.py

Files changed (1): modeling_minicpm.py (+42 -21)
modeling_minicpm.py CHANGED
@@ -21,12 +21,16 @@
import math
import warnings
from typing import List, Optional, Tuple, Union, Dict
-
+ import os
+ from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ import numpy as np
+ from copy import deepcopy
+ from transformers import AutoTokenizer

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
@@ -35,6 +39,7 @@ from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
+     _prepare_4d_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
@@ -320,9 +325,6 @@ class MiniCPMAttention(nn.Module):
        self.rope_theta = config.rope_theta

        self.is_causal = config.is_causal
-
-         logger.info(f"self.is_causal = {self.is_causal}")
-

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
@@ -979,6 +981,8 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
        self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
+         self.is_causal = config.is_causal
+         self.adapt_mean_pooling = config.adapt_mean_pooling
        # Initialize weights and apply final processing
        self.post_init()

@@ -1000,6 +1004,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
+         adapt_mean_pooling: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -1044,24 +1049,35 @@
            inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb

        _attention_mask = attention_mask
-
        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
-             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                 attention_mask,
-                 (batch_size, seq_length),
-                 inputs_embeds,
-                 past_key_values_length,
-             )
+             if self.is_causal:
+                 attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                     attention_mask,
+                     (batch_size, seq_length),
+                     inputs_embeds,
+                     past_key_values_length,
+                 )
+             else:
+                 attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                     attention_mask,
+                     inputs_embeds.dtype,
+                 )
        else:
            # 4d mask is passed through the layers
-             attention_mask = _prepare_4d_causal_attention_mask(
-                 attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-             )
+             if self.is_causal:
+                 attention_mask = _prepare_4d_causal_attention_mask(
+                     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                 )
+             else:
+                 attention_mask = _prepare_4d_attention_mask(
+                     attention_mask,
+                     inputs_embeds.dtype,
+                 )

        # embed positions
        hidden_states = inputs_embeds
@@ -1109,14 +1125,18 @@
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

-         # gen weight before mean pooling
-         attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
-         s = hidden_states * attention_mask_.unsqueeze(-1).float()
-         d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() / _attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
+         next_cache = None

-         hidden_states = s / d
+         # gen weight before mean pooling
+         if adapt_mean_pooling is None:
+             adapt_mean_pooling = self.adapt_mean_pooling
+         if adapt_mean_pooling:
+             attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
+             s = hidden_states * attention_mask_.unsqueeze(-1).float()
+             d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() / _attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
+
+             hidden_states = s / d

-         next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
@@ -1127,7 +1147,8 @@
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
-
+
+

class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
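
Note on the new `adapt_mean_pooling` branch: it scales each token's hidden state by a position-dependent weight before the states leave `MiniCPMModel.forward`. The main behavioral change in this commit is that the weighting, previously applied unconditionally, is now gated by `config.adapt_mean_pooling` and the new per-call `adapt_mean_pooling` argument. Below is a minimal standalone sketch of the same computation; the final masked-mean step is an assumption about how the scaled states are pooled downstream and is not part of this commit.

import torch

def adaptive_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] with 1 for real tokens.
    # Weight token i by mask_i * (number of valid tokens up to and including i),
    # mirroring `_attention_mask * _attention_mask.cumsum(dim=1)` in the diff.
    weights = attention_mask * attention_mask.cumsum(dim=1)              # [batch, seq_len]
    weighted = hidden_states * weights.unsqueeze(-1).float()             # [batch, seq_len, hidden]
    # Normalizer sum(weights) / sum(mask), shaped [batch, 1, 1], so that a plain masked mean
    # over the sequence afterwards reduces to a weighted mean.
    denom = weights.sum(dim=1, keepdim=True).unsqueeze(1).float() / attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
    scaled = weighted / denom                                            # what the model now returns per token
    # Assumed downstream pooling (not in this commit): masked mean over the sequence,
    # which here equals sum_i(h_i * w_i) / sum_i(w_i).
    return scaled.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).float()   # [batch, hidden]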
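
The commit also reads two config attributes, `config.is_causal` (selects the causal vs. bidirectional mask path) and `config.adapt_mean_pooling`, and exposes `adapt_mean_pooling` as a per-call argument to `MiniCPMModel.forward`. A hypothetical loading sketch follows; the checkpoint path is a placeholder, and it assumes the repository's `auto_map` exposes this modeling code through the `Auto*` classes with `trust_remote_code=True`.

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_path = "path/to/minicpm-embedding-checkpoint"   # placeholder, not from this commit

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.is_causal = False          # take the new non-causal (_prepare_4d_attention_mask*) branch
config.adapt_mean_pooling = True  # apply the position-weighted pooling scaling by default

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True).eval()

batch = tokenizer(["a short example sentence"], return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**batch, adapt_mean_pooling=True)  # per-call override added in this commit
last_hidden = outputs.last_hidden_state                # already scaled when adapt_mean_pooling is on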