Update modeling_minicpm.py
modeling_minicpm.py  CHANGED  (+42 -21)
@@ -21,12 +21,16 @@
 import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict
-
+import os
+from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+import numpy as np
+from copy import deepcopy
+from transformers import AutoTokenizer
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -35,6 +39,7 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
+    _prepare_4d_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -320,9 +325,6 @@ class MiniCPMAttention(nn.Module):
         self.rope_theta = config.rope_theta
 
         self.is_causal = config.is_causal
-
-        logger.info(f"self.is_causal = {self.is_causal}")
-
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -979,6 +981,8 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
+        self.is_causal = config.is_causal
+        self.adapt_mean_pooling = config.adapt_mean_pooling
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1000,6 +1004,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        adapt_mean_pooling: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1044,24 +1049,35 @@
             inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb
 
         _attention_mask = attention_mask
-
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         elif self._use_sdpa and not output_attentions:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
+            if self.is_causal:
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            else:
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask,
+                    inputs_embeds.dtype,
+                )
         else:
             # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+            if self.is_causal:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )
+            else:
+                attention_mask = _prepare_4d_attention_mask(
+                    attention_mask,
+                    inputs_embeds.dtype,
+                )
 
         # embed positions
         hidden_states = inputs_embeds
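The hunk above is the core of the change: instead of always building a causal mask, the model now checks self.is_causal and, when it is False, prepares a plain padding mask so every token can attend to every other token (encoder/embedding behaviour). Below is a minimal, self-contained sketch of that difference using plain PyTorch rather than the private _prepare_4d_* helpers from transformers; the function name and shapes are illustrative only and not part of this file.

import torch

def to_4d_additive_mask(padding_mask: torch.Tensor, dtype=torch.float32, causal: bool = True):
    """padding_mask: (batch, seq_len), 1 = real token, 0 = padding."""
    bsz, seq_len = padding_mask.shape
    # allowed[b, q, k] is True if query position q may attend to key position k.
    allowed = padding_mask[:, None, :].bool().expand(bsz, seq_len, seq_len)
    if causal:
        # Decoder-style: additionally block attention to future positions.
        allowed = allowed & torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Convert to the additive form attention layers expect: 0 where allowed, a large negative where blocked.
    additive = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype)
    additive.masked_fill_(~allowed[:, None, :, :], torch.finfo(dtype).min)
    return additive

padding = torch.tensor([[1, 1, 1, 0]])
print(to_4d_additive_mask(padding, causal=True)[0, 0])   # causal: lower triangle, padded key blocked
print(to_4d_additive_mask(padding, causal=False)[0, 0])  # bidirectional: only the padded key column blocked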
@@ -1109,14 +1125,18 @@
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-
-        attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
-        s = hidden_states * attention_mask_.unsqueeze(-1).float()
-        d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() / _attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
+        next_cache = None
 
-        hidden_states = s / d
+        # gen weight before mean pooling
+        if adapt_mean_pooling is None:
+            adapt_mean_pooling = self.adapt_mean_pooling
+        if adapt_mean_pooling:
+            attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
+            s = hidden_states * attention_mask_.unsqueeze(-1).float()
+            d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() / _attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
+
+            hidden_states = s / d
 
-        next_cache = None
         if use_cache:
             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
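The pooling hunk above (previously unconditional, now gated by adapt_mean_pooling) weights each real token by its 1-based position via _attention_mask * _attention_mask.cumsum(dim=1) and rescales the hidden states by that weight divided by the average weight per token. The effect is that a plain masked mean taken afterwards reproduces a position-weighted average in which later tokens count more. A small sketch of that identity, with toy tensors standing in for the model's hidden states:

import torch

torch.manual_seed(0)
batch, seq_len, hidden = 2, 5, 4
hidden_states = torch.randn(batch, seq_len, hidden)
# 1 = real token, 0 = padding (the second sequence has two padded positions).
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Position-proportional weights: the i-th real token gets weight i, padding gets 0.
weights = attention_mask * attention_mask.cumsum(dim=1)               # (B, L)
s = hidden_states * weights.unsqueeze(-1).float()                     # (B, L, H)
d = weights.sum(dim=1, keepdim=True).unsqueeze(1).float() / \
    attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()      # (B, 1, 1): average weight
rescaled = s / d                                                       # same shape as hidden_states

# A plain masked mean over the rescaled states equals the position-weighted mean
# of the original states: later (more context-aware) tokens contribute more.
masked_mean = (rescaled * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(1, keepdim=True)
weighted_mean = (hidden_states * weights.unsqueeze(-1)).sum(1) / weights.sum(1, keepdim=True)
print(torch.allclose(masked_mean, weighted_mean, atol=1e-6))  # True

Dividing by d keeps the rescaled states on the same overall scale as a uniform mean, which is why the caller can pool them afterwards with an ordinary masked mean.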
@@ -1127,7 +1147,8 @@
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
+
+
 
 class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
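Taken together, the commit adds config.is_causal, config.adapt_mean_pooling, and an adapt_mean_pooling forward argument, so one checkpoint can serve either as a causal decoder or as a bidirectional encoder with position-weighted pooling. A hedged usage sketch follows; the checkpoint path is a placeholder, and it assumes the repository's auto_map routes AutoModel to MiniCPMModel and that the extra forward keyword is passed through, neither of which is shown in this diff.

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder path: substitute the repository that ships this modeling_minicpm.py.
model_name = "path/to/minicpm-embedding-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

inputs = tokenizer(["An example sentence to embed."], return_tensors="pt", padding=True)
with torch.no_grad():
    # adapt_mean_pooling=True applies the cumsum-based rescaling added in this commit;
    # leaving it as None falls back to config.adapt_mean_pooling.
    outputs = model(**inputs, adapt_mean_pooling=True)

# A masked mean over the (already rescaled) token states yields a sentence embedding.
mask = inputs["attention_mask"].unsqueeze(-1)
embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
print(embedding.shape)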