duzx16 committed
Commit 74d61a6 · 1 Parent(s): 4d01789
Add gradient checkpointing

Browse files:
- config.json +1 -1
- configuration_chatglm.py +5 -0
- modeling_chatglm.py +89 -15
- tokenization_chatglm.py +20 -3
config.json CHANGED
@@ -36,5 +36,5 @@
   "transformers_version": "4.27.1",
   "tie_word_embeddings": false,
   "eos_token_id": 2,
-  "pad_token_id":
+  "pad_token_id": 0
 }
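For reference, a minimal sketch of verifying the new padding id once the checkpoint is loaded; the repo id "THUDM/chatglm2-6b" and the use of trust_remote_code=True are assumptions for illustration, not part of this commit:

# Hypothetical check of the updated config value; substitute the actual repo id.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
print(config.pad_token_id)  # expected: 0 with this commit applied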
configuration_chatglm.py CHANGED
@@ -28,9 +28,12 @@ class ChatGLMConfig(PretrainedConfig):
         attention_softmax_in_fp32=True,
         fp32_residual_connection=False,
         quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
         **kwargs
     ):
         self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
         self.padded_vocab_size = padded_vocab_size
         self.hidden_size = hidden_size
         self.ffn_hidden_size = ffn_hidden_size
@@ -52,4 +55,6 @@ class ChatGLMConfig(PretrainedConfig):
         self.attention_softmax_in_fp32 = attention_softmax_in_fp32
         self.fp32_residual_connection = fp32_residual_connection
         self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
         super().__init__(**kwargs)
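The two new constructor arguments drive prefix tuning (P-tuning v2 style): pre_seq_len sets the number of trainable prefix tokens and prefix_projection chooses between a plain embedding and a two-layer MLP encoder. A minimal sketch with illustrative values (not defaults from this commit; the file is assumed to be importable, e.g. from the checkpoint directory):

# Sketch: constructing the patched config with a trainable prefix enabled.
from configuration_chatglm import ChatGLMConfig

config = ChatGLMConfig(
    pre_seq_len=128,          # length of the trainable prefix; None leaves prefix tuning off
    prefix_projection=False,  # True routes the prefix embedding through a two-layer MLP
)
print(config.pre_seq_len, config.prefix_projection, config.vocab_size)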
modeling_chatglm.py CHANGED
@@ -56,6 +56,36 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores


+class PrefixEncoder(torch.nn.Module):
+    """
+    The torch.nn model to encode the prefix
+    Input shape: (batch-size, prefix-length)
+    Output shape: (batch-size, prefix-length, 2*layers*hidden)
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.prefix_projection = config.prefix_projection
+        if self.prefix_projection:
+            # Use a two-layer MLP to encode the prefix
+            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+            self.trans = torch.nn.Sequential(
+                torch.nn.Linear(config.hidden_size, config.hidden_size),
+                torch.nn.Tanh(),
+                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+            )
+        else:
+            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
+
+    def forward(self, prefix: torch.Tensor):
+        if self.prefix_projection:
+            prefix_tokens = self.embedding(prefix)
+            past_key_values = self.trans(prefix_tokens)
+        else:
+            past_key_values = self.embedding(prefix)
+        return past_key_values
+
+
 def split_tensor_along_last_dim(
         tensor: torch.Tensor,
         num_partitions: int,
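A short sketch of the tensor the new PrefixEncoder emits, using a toy config object that only carries the attributes the module reads; all sizes are illustrative and modeling_chatglm.py is assumed to be importable:

# Toy shape check for PrefixEncoder; not the real model configuration.
import torch
from types import SimpleNamespace
from modeling_chatglm import PrefixEncoder

cfg = SimpleNamespace(pre_seq_len=4, hidden_size=8, num_layers=2, prefix_projection=False)
encoder = PrefixEncoder(cfg)
prefix = torch.arange(cfg.pre_seq_len).unsqueeze(0)  # (batch=1, prefix_len=4)
out = encoder(prefix)
print(out.shape)  # torch.Size([1, 4, 32]) == (batch, prefix_len, 2 * num_layers * hidden)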
@@ -566,6 +596,8 @@ class GLMTransformer(torch.nn.Module):
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)

+        self.gradient_checkpointing = False
+
     def _get_layer(self, layer_number):
         return self.layers[layer_number]

@@ -577,6 +609,13 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
         for index in range(self.num_layers):
@@ -584,14 +623,24 @@ class GLMTransformer(torch.nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer = self._get_layer(index)
-            layer_ret = layer(
-                hidden_states,
-                attention_mask,
-                rotary_pos_emb,
-                kv_cache=kv_caches[index],
-                use_cache=use_cache
-            )
-            hidden_states, kv_cache = layer_ret
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=kv_caches[index],
+                    use_cache=use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=kv_caches[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
             if use_cache:
                 presents = presents + (kv_cache,)

@@ -645,7 +694,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         return position_ids

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module,
+        if isinstance(module, GLMTransformer):
             module.gradient_checkpointing = value


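A hedged sketch of how this flag is normally toggled through the standard PreTrainedModel API; the repo id is a placeholder, and the snippet assumes ChatGLMPreTrainedModel declares supports_gradient_checkpointing = True (not shown in this diff):

# Hypothetical training setup exercising the new gradient-checkpointing path.
from transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model.gradient_checkpointing_enable()  # applies _set_gradient_checkpointing(value=True) to submodules
model.train()
# During training, GLMTransformer.forward now re-runs each block in the backward
# pass via torch.utils.checkpoint and forces use_cache=False, trading compute for memory.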
@@ -700,11 +749,33 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                         dtype=config.torch_dtype, **init_kwargs)
-        self.
+        self.pre_seq_len = config.pre_seq_len
+        self.prefix_projection = config.prefix_projection
+        if self.pre_seq_len is not None:
+            for param in self.parameters():
+                param.requires_grad = False
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+            self.prefix_encoder = PrefixEncoder(config)
+            self.dropout = torch.nn.Dropout(0.1)

     def get_input_embeddings(self):
         return self.embedding.word_embeddings

+    def get_prompt(self, batch_size, device, dtype=torch.half):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+        past_key_values = past_key_values.view(
+            batch_size,
+            self.pre_seq_len,
+            self.num_layers * 2,
+            self.num_attention_heads,
+            self.hidden_size // self.num_attention_heads
+        )
+        # seq_len, b, nh, hidden_size
+        past_key_values = self.dropout(past_key_values)
+        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+        return past_key_values
+
     def forward(
             self,
             input_ids,
@@ -740,6 +811,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             rotary_pos_emb = rotary_pos_emb[None, :seq_length]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

+        if past_key_values is None:
+            if self.pre_seq_len is not None:
+                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
+                                                  dtype=inputs_embeds.dtype)
+
         # Run encoder.
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
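A standalone sketch of the bookkeeping get_prompt performs, mirroring only the view/permute/split sequence to show the resulting past_key_values layout; all sizes are toy values:

# Illustration of the prefix cache layout produced by get_prompt (toy sizes).
import torch

batch_size, pre_seq_len, num_layers, num_heads, head_dim = 2, 4, 3, 2, 5
flat = torch.randn(batch_size, pre_seq_len, num_layers * 2 * num_heads * head_dim)

pkv = flat.view(batch_size, pre_seq_len, num_layers * 2, num_heads, head_dim)
pkv = pkv.permute([2, 1, 0, 3, 4]).split(2)  # one (key, value) pair per layer

print(len(pkv))      # 3 -> one entry per transformer layer
print(pkv[0].shape)  # torch.Size([2, 4, 2, 2, 5]) -> (2, pre_seq_len, batch, heads, head_dim)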
@@ -913,10 +989,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response

     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        prompt = ""
-        for i, (old_query, response) in enumerate(history):
-            prompt += "[Round {}]\n\n问：{}\n\n答：{}\n\n".format(i + 1, old_query, response)
-        prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
+        prompt = tokenizer.build_prompt(query, history=history)
         inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         return inputs
@@ -933,7 +1006,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         return inputs

-
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
              do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
@@ -969,6 +1041,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = self.build_stream_inputs(tokenizer, query, history=history)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
+            if self.transformer.pre_seq_len is not None:
+                past_length -= self.transformer.pre_seq_len
             inputs.position_ids += past_length
             attention_mask = inputs.attention_mask
             attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
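For context, a hedged sketch of the call path these hunks touch: chat() builds its prompt through build_inputs(), which now delegates to tokenizer.build_prompt(), and when a prefix is configured, stream_chat subtracts pre_seq_len from past_length so position ids and the attention mask only count real tokens. The repo id below is a placeholder:

# Hypothetical interactive use; requires the checkpoint weights to be available.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda().eval()

response, history = model.chat(tokenizer, "Hello", history=[])
response, history = model.chat(tokenizer, "What can you do?", history=history)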
tokenization_chatglm.py CHANGED
@@ -17,7 +17,7 @@ class SPTokenizer:
         self.n_words: int = self.sp_model.vocab_size()
         self.bos_id: int = self.sp_model.bos_id()
         self.eos_id: int = self.sp_model.eos_id()
-        self.pad_id: int = self.sp_model.
+        self.pad_id: int = self.sp_model.unk_id()
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

         special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
@@ -55,7 +55,7 @@ class SPTokenizer:

     def convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index in self.index_special_tokens:
+        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
             return ""
         return self.sp_model.IdToPiece(index)

@@ -85,12 +85,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):

     @property
     def pad_token(self) -> str:
-        return "
+        return "<unk>"

     @property
     def pad_token_id(self):
         return self.get_command("<pad>")

+    @property
+    def eos_token(self) -> str:
+        return "</s>"
+
+    @property
+    def eos_token_id(self):
+        return self.get_command("<eos>")
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_words
@@ -147,6 +155,15 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         return prefix_tokens

+    def build_prompt(self, query, history=None):
+        if history is None:
+            history = []
+        prompt = ""
+        for i, (old_query, response) in enumerate(history):
+            prompt += "[Round {}]\n\n问：{}\n\n答：{}\n\n".format(i + 1, old_query, response)
+        prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
+        return prompt
+
     def build_inputs_with_special_tokens(
             self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]: