shwu
commited on
Commit
•
769e287
1
Parent(s):
d1844e4
feat: better modeling_blip2chatglm
Browse files- config.json +5 -5
- configuration_blip2chatglm.py +1 -1
- generation_config.json +4 -0
- pytorch_model.bin → ice_text.model +2 -2
- modeling_blip2chatglm.py +244 -41
- modeling_chatglm.py +82 -54
- preprocessor_config.json +24 -0
- pytorch_model-00001-of-00009.bin +3 -0
- pytorch_model-00002-of-00009.bin +3 -0
- pytorch_model-00003-of-00009.bin +3 -0
- pytorch_model-00004-of-00009.bin +3 -0
- pytorch_model-00005-of-00009.bin +3 -0
- pytorch_model-00006-of-00009.bin +3 -0
- pytorch_model-00007-of-00009.bin +3 -0
- pytorch_model-00008-of-00009.bin +3 -0
- pytorch_model-00009-of-00009.bin +3 -0
- pytorch_model.bin.index.json +0 -0
- special_tokens_map.json +7 -0
- tokenization_chatglm.py +433 -0
- tokenizer_config.json +23 -0
config.json
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
{
|
2 |
"_commit_hash": null,
|
3 |
"architectures": [
|
4 |
-
"
|
5 |
],
|
6 |
"initializer_factor": 1.0,
|
7 |
"initializer_range": 0.02,
|
@@ -174,7 +174,7 @@
|
|
174 |
"tie_word_embeddings": false,
|
175 |
"torch_dtype": "float32",
|
176 |
"transformers_version": null,
|
177 |
-
"use_decoder_only_language_model":
|
178 |
"vision_config": {
|
179 |
"_name_or_path": "",
|
180 |
"add_cross_attention": false,
|
@@ -248,7 +248,7 @@
|
|
248 |
"tokenizer_class": null,
|
249 |
"top_k": 50,
|
250 |
"top_p": 1.0,
|
251 |
-
"torch_dtype":
|
252 |
"torchscript": false,
|
253 |
"transformers_version": "4.27.3",
|
254 |
"typical_p": 1.0,
|
@@ -256,7 +256,7 @@
|
|
256 |
},
|
257 |
"auto_map": {
|
258 |
"AutoConfig": "configuration_blip2chatglm.Blip2ChatGLMConfig",
|
259 |
-
"AutoModel": "modeling_blip2chatglm.
|
260 |
-
"AutoModelForCausalLM": "modeling_blip2chatglm.
|
261 |
}
|
262 |
}
|
|
|
1 |
{
|
2 |
"_commit_hash": null,
|
3 |
"architectures": [
|
4 |
+
"Blip2ChatGLMForConditionalGeneration"
|
5 |
],
|
6 |
"initializer_factor": 1.0,
|
7 |
"initializer_range": 0.02,
|
|
|
174 |
"tie_word_embeddings": false,
|
175 |
"torch_dtype": "float32",
|
176 |
"transformers_version": null,
|
177 |
+
"use_decoder_only_language_model": true,
|
178 |
"vision_config": {
|
179 |
"_name_or_path": "",
|
180 |
"add_cross_attention": false,
|
|
|
248 |
"tokenizer_class": null,
|
249 |
"top_k": 50,
|
250 |
"top_p": 1.0,
|
251 |
+
"torch_dtype": "float16",
|
252 |
"torchscript": false,
|
253 |
"transformers_version": "4.27.3",
|
254 |
"typical_p": 1.0,
|
|
|
256 |
},
|
257 |
"auto_map": {
|
258 |
"AutoConfig": "configuration_blip2chatglm.Blip2ChatGLMConfig",
|
259 |
+
"AutoModel": "modeling_blip2chatglm.Blip2ChatGLMForConditionalGeneration",
|
260 |
+
"AutoModelForCausalLM": "modeling_blip2chatglm.Blip2ChatGLMForConditionalGeneration"
|
261 |
}
|
262 |
}
|
configuration_blip2chatglm.py
CHANGED
@@ -49,7 +49,7 @@ class Blip2ChatGLMConfig(PretrainedConfig):
|
|
49 |
self.num_query_tokens = num_query_tokens
|
50 |
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
51 |
# self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
52 |
-
self.use_decoder_only_language_model =
|
53 |
self.initializer_factor = 1.0
|
54 |
self.initializer_range = 0.02
|
55 |
|
|
|
49 |
self.num_query_tokens = num_query_tokens
|
50 |
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
51 |
# self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
52 |
+
self.use_decoder_only_language_model = True # chatglm has no encoder
|
53 |
self.initializer_factor = 1.0
|
54 |
self.initializer_range = 0.02
|
55 |
|
generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.27.3"
|
4 |
+
}
|
pytorch_model.bin → ice_text.model
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5e974d9a69c242ce014c88c2b26089270f6198f3c0b700a887666cd3e816f17e
|
3 |
+
size 2706249
|
modeling_blip2chatglm.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import copy
|
|
|
2 |
from typing import Callable, List, Optional, Tuple, Union
|
3 |
import torch
|
|
|
4 |
import warnings
|
5 |
from torch import Tensor, nn
|
6 |
|
@@ -8,8 +10,14 @@ from transformers import (
|
|
8 |
PreTrainedModel,
|
9 |
Blip2VisionModel,
|
10 |
Blip2QFormerModel,
|
|
|
|
|
|
|
11 |
GenerationConfig,
|
12 |
)
|
|
|
|
|
|
|
13 |
from transformers.utils import logging
|
14 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
15 |
|
@@ -23,9 +31,13 @@ from .configuration_blip2chatglm import Blip2ChatGLMConfig
|
|
23 |
logger = logging.get_logger(__name__)
|
24 |
|
25 |
|
26 |
-
class
|
|
|
|
|
27 |
def __init__(self, config: Blip2ChatGLMConfig):
|
28 |
-
|
|
|
|
|
29 |
|
30 |
self.vision_model = Blip2VisionModel(config.vision_config)
|
31 |
|
@@ -37,21 +49,65 @@ class Blip2ForChatGLM(PreTrainedModel):
|
|
37 |
self.language_projection = nn.Linear(
|
38 |
config.qformer_config.hidden_size, config.text_config.hidden_size
|
39 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
def forward(
|
42 |
self,
|
43 |
pixel_values: torch.FloatTensor,
|
|
|
|
|
|
|
44 |
output_attentions: Optional[bool] = None,
|
45 |
output_hidden_states: Optional[bool] = None,
|
|
|
46 |
return_dict: Optional[bool] = None,
|
47 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
return_dict = (
|
49 |
return_dict if return_dict is not None else self.config.use_return_dict
|
50 |
)
|
51 |
|
52 |
# step 1: forward the images through the vision encoder,
|
53 |
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
54 |
-
vision_outputs = self.vision_model
|
55 |
pixel_values=pixel_values,
|
56 |
output_attentions=output_attentions,
|
57 |
output_hidden_states=output_hidden_states,
|
@@ -65,7 +121,7 @@ class Blip2ForChatGLM(PreTrainedModel):
|
|
65 |
)
|
66 |
|
67 |
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
68 |
-
query_outputs = self.qformer
|
69 |
query_embeds=query_tokens,
|
70 |
encoder_hidden_states=image_embeds,
|
71 |
encoder_attention_mask=image_attention_mask,
|
@@ -76,23 +132,54 @@ class Blip2ForChatGLM(PreTrainedModel):
|
|
76 |
query_output = query_outputs[0]
|
77 |
|
78 |
# step 3: use the language model, conditioned on the query outputs and the prompt
|
79 |
-
language_model_inputs = self.language_projection
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
|
|
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
self.blip2 = blip2
|
95 |
-
self.language = lm
|
96 |
|
97 |
@torch.no_grad()
|
98 |
def stream_chat(
|
@@ -106,12 +193,12 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
106 |
do_sample=True,
|
107 |
temperature=1,
|
108 |
):
|
109 |
-
device = self.
|
110 |
# 1. Prepare token ids
|
111 |
images = []
|
112 |
image_slots = []
|
113 |
|
114 |
-
nvtokens = self.
|
115 |
if history:
|
116 |
input_ids = tokenizer(
|
117 |
f"[Round {len(history)}]\n问:", add_special_tokens=False
|
@@ -181,27 +268,27 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
181 |
# 2. Prepare image embeddings
|
182 |
if len(images) != 0:
|
183 |
image = torch.cat(list(images), dim=0)
|
184 |
-
vision_outputs = self.
|
185 |
image_embeds = vision_outputs[0]
|
186 |
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
|
187 |
device
|
188 |
)
|
189 |
|
190 |
-
query_tokens = self.
|
191 |
-
query_outputs = self.
|
192 |
query_embeds=query_tokens,
|
193 |
encoder_hidden_states=image_embeds,
|
194 |
encoder_attention_mask=image_atts,
|
195 |
)
|
196 |
query_output = query_outputs[0]
|
197 |
|
198 |
-
vtokens = self.
|
199 |
else:
|
200 |
vtokens = []
|
201 |
|
202 |
# 3. Place image embeddings into slots
|
203 |
input_ids = torch.as_tensor(input_ids, dtype=torch.long).to(device).unsqueeze(0)
|
204 |
-
inputs_embeds = self.
|
205 |
for slot, vimg in zip(image_slots, vtokens):
|
206 |
inputs_embeds[0][-slot : -slot + nvtokens, :] = vimg
|
207 |
|
@@ -216,17 +303,16 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
216 |
"logits_processor": logits_processor,
|
217 |
}
|
218 |
|
219 |
-
for outputs in self.
|
220 |
input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
|
221 |
):
|
222 |
outputs = outputs.tolist()[0][len(input_ids[0]) :]
|
223 |
response = tokenizer.decode(outputs)
|
224 |
-
response = self.
|
225 |
-
|
226 |
-
yield response, new_history
|
227 |
|
228 |
@torch.no_grad()
|
229 |
-
def
|
230 |
self,
|
231 |
input_ids,
|
232 |
inputs_embeds,
|
@@ -238,10 +324,23 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
238 |
] = None,
|
239 |
**kwargs,
|
240 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
242 |
|
243 |
if generation_config is None:
|
244 |
-
generation_config = self.
|
245 |
generation_config = copy.deepcopy(generation_config)
|
246 |
model_kwargs = generation_config.update(**kwargs)
|
247 |
bos_token_id, eos_token_id = (
|
@@ -279,7 +378,7 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
279 |
if input_ids_seq_length >= generation_config.max_length:
|
280 |
input_ids_string = (
|
281 |
"decoder_input_ids"
|
282 |
-
if self.
|
283 |
else "input_ids"
|
284 |
)
|
285 |
logger.warning(
|
@@ -298,7 +397,7 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
298 |
else StoppingCriteriaList()
|
299 |
)
|
300 |
|
301 |
-
logits_processor = self.
|
302 |
generation_config=generation_config,
|
303 |
input_ids_seq_length=input_ids_seq_length,
|
304 |
encoder_input_ids=input_ids,
|
@@ -306,19 +405,19 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
306 |
logits_processor=logits_processor,
|
307 |
)
|
308 |
|
309 |
-
stopping_criteria = self.
|
310 |
generation_config=generation_config, stopping_criteria=stopping_criteria
|
311 |
)
|
312 |
-
logits_warper = self.
|
313 |
|
314 |
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
315 |
scores = None
|
316 |
while True:
|
317 |
-
model_inputs = self.
|
318 |
input_ids, inputs_embeds=inputs_embeds, **model_kwargs
|
319 |
)
|
320 |
# forward pass to get next token
|
321 |
-
outputs = self.
|
322 |
**model_inputs,
|
323 |
return_dict=True,
|
324 |
output_attentions=False,
|
@@ -343,14 +442,14 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
343 |
inputs_embeds = torch.cat(
|
344 |
[
|
345 |
inputs_embeds,
|
346 |
-
self.
|
347 |
],
|
348 |
dim=1,
|
349 |
)
|
350 |
-
model_kwargs = self.
|
351 |
outputs,
|
352 |
model_kwargs,
|
353 |
-
is_encoder_decoder=self.
|
354 |
)
|
355 |
unfinished_sequences = unfinished_sequences.mul(
|
356 |
(sum(next_tokens != i for i in eos_token_id)).long()
|
@@ -360,3 +459,107 @@ class Blip2ChatGLM(PreTrainedModel):
|
|
360 |
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
361 |
break
|
362 |
yield input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import copy
|
2 |
+
import os
|
3 |
from typing import Callable, List, Optional, Tuple, Union
|
4 |
import torch
|
5 |
+
from torch.nn import CrossEntropyLoss
|
6 |
import warnings
|
7 |
from torch import Tensor, nn
|
8 |
|
|
|
10 |
PreTrainedModel,
|
11 |
Blip2VisionModel,
|
12 |
Blip2QFormerModel,
|
13 |
+
Blip2Model,
|
14 |
+
Blip2PreTrainedModel,
|
15 |
+
Blip2ForConditionalGeneration,
|
16 |
GenerationConfig,
|
17 |
)
|
18 |
+
from transformers.models.blip_2.modeling_blip_2 import (
|
19 |
+
Blip2ForConditionalGenerationModelOutput,
|
20 |
+
)
|
21 |
from transformers.utils import logging
|
22 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
23 |
|
|
|
31 |
logger = logging.get_logger(__name__)
|
32 |
|
33 |
|
34 |
+
class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
|
35 |
+
config_class = Blip2ChatGLMConfig
|
36 |
+
|
37 |
def __init__(self, config: Blip2ChatGLMConfig):
|
38 |
+
Blip2PreTrainedModel.__init__(self, config)
|
39 |
+
# NOTE: we only initialize Blip2PreTrainedModel
|
40 |
+
# directly call super().__init__() will cause error since ChatGLM cannot be found by AutoModel
|
41 |
|
42 |
self.vision_model = Blip2VisionModel(config.vision_config)
|
43 |
|
|
|
49 |
self.language_projection = nn.Linear(
|
50 |
config.qformer_config.hidden_size, config.text_config.hidden_size
|
51 |
)
|
52 |
+
self.language_model = ChatGLMForConditionalGeneration(config.text_config)
|
53 |
+
|
54 |
+
# Initialize weights and apply final processing
|
55 |
+
# self.post_init()
|
56 |
+
|
57 |
+
def setup_dtype(self, vision_encoder_dtype: str = "fp32", lm_dtype: str = "fp16"):
|
58 |
+
if vision_encoder_dtype == "fp32":
|
59 |
+
self.vision_model = self.vision_model.float()
|
60 |
+
elif vision_encoder_dtype == "fp16":
|
61 |
+
self.vision_model = self.vision_model.half()
|
62 |
+
else:
|
63 |
+
raise NotImplementedError(
|
64 |
+
f"Unsupported vision_encoder_dtype: {vision_encoder_dtype}"
|
65 |
+
)
|
66 |
+
|
67 |
+
if lm_dtype == "fp32":
|
68 |
+
self.language_model = self.language_model.float()
|
69 |
+
elif lm_dtype == "fp16":
|
70 |
+
self.language_model = self.language_model.half()
|
71 |
+
elif lm_dtype == "int4":
|
72 |
+
self.language_model = self.language_model.half().quantize(4)
|
73 |
+
elif lm_dtype == "int8":
|
74 |
+
self.language_model = self.language_model.half().quantize(8)
|
75 |
+
else:
|
76 |
+
raise NotImplementedError(f"Unsupported lm_dtype: {lm_dtype}")
|
77 |
|
78 |
def forward(
|
79 |
self,
|
80 |
pixel_values: torch.FloatTensor,
|
81 |
+
input_ids: torch.FloatTensor,
|
82 |
+
image_slot_offset: Optional[torch.LongTensor] = None,
|
83 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
84 |
output_attentions: Optional[bool] = None,
|
85 |
output_hidden_states: Optional[bool] = None,
|
86 |
+
labels: Optional[torch.LongTensor] = None,
|
87 |
return_dict: Optional[bool] = None,
|
88 |
+
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
|
89 |
+
"""_summary_
|
90 |
+
|
91 |
+
Args:
|
92 |
+
pixel_values (torch.FloatTensor): _description_
|
93 |
+
input_ids (torch.FloatTensor): input_ids[:, :num_query_tokens] should be filled with tokenizer.unk_token_id
|
94 |
+
image_slot_offset (Optional[torch.LongTensor], optional): if not set, all vtokens are placed as prefix (image_slot_offset = torch.zeros(bsz)). Defaults to None.
|
95 |
+
attention_mask (Optional[torch.LongTensor], optional): _description_. Defaults to None.
|
96 |
+
output_attentions (Optional[bool], optional): _description_. Defaults to None.
|
97 |
+
output_hidden_states (Optional[bool], optional): _description_. Defaults to None.
|
98 |
+
labels (Optional[torch.LongTensor], optional): _description_. Defaults to None.
|
99 |
+
return_dict (Optional[bool], optional): _description_. Defaults to None.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Union[Tuple, Blip2ForConditionalGenerationModelOutput]: _description_
|
103 |
+
"""
|
104 |
return_dict = (
|
105 |
return_dict if return_dict is not None else self.config.use_return_dict
|
106 |
)
|
107 |
|
108 |
# step 1: forward the images through the vision encoder,
|
109 |
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
110 |
+
vision_outputs = self.vision_model(
|
111 |
pixel_values=pixel_values,
|
112 |
output_attentions=output_attentions,
|
113 |
output_hidden_states=output_hidden_states,
|
|
|
121 |
)
|
122 |
|
123 |
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
124 |
+
query_outputs = self.qformer(
|
125 |
query_embeds=query_tokens,
|
126 |
encoder_hidden_states=image_embeds,
|
127 |
encoder_attention_mask=image_attention_mask,
|
|
|
132 |
query_output = query_outputs[0]
|
133 |
|
134 |
# step 3: use the language model, conditioned on the query outputs and the prompt
|
135 |
+
language_model_inputs = self.language_projection(query_output)
|
136 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
137 |
+
if image_slot_offset is None:
|
138 |
+
# image as prefix
|
139 |
+
# update data to avoid inplace operation of leaf Variable
|
140 |
+
inputs_embeds.data[:, : self.config.num_query_tokens, :] = language_model_inputs
|
141 |
+
else:
|
142 |
+
for i, offset in enumerate(image_slot_offset):
|
143 |
+
inputs_embeds.data[i, offset : offset + self.config.num_query_tokens, :] = (
|
144 |
+
language_model_inputs[i]
|
145 |
+
)
|
146 |
|
147 |
+
outputs = self.language_model(
|
148 |
+
input_ids=input_ids,
|
149 |
+
inputs_embeds=inputs_embeds,
|
150 |
+
attention_mask=attention_mask,
|
151 |
+
output_attentions=output_attentions,
|
152 |
+
output_hidden_states=output_hidden_states,
|
153 |
+
return_dict=return_dict,
|
154 |
+
)
|
155 |
+
logits = outputs.logits if return_dict else outputs[0]
|
156 |
+
loss = None
|
157 |
+
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
158 |
+
if labels is not None:
|
159 |
+
logits = logits[:, -labels.size(1) :, :]
|
160 |
+
# Shift so that tokens < n predict n
|
161 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
162 |
+
shift_labels = labels[..., 1:].contiguous().to(logits.device)
|
163 |
+
|
164 |
+
# Flatten the tokens
|
165 |
+
loss_fct = CrossEntropyLoss(reduction="mean")
|
166 |
+
|
167 |
+
loss = loss_fct(
|
168 |
+
shift_logits.view(-1, self.config.text_config.vocab_size),
|
169 |
+
shift_labels.view(-1),
|
170 |
+
)
|
171 |
|
172 |
+
if not return_dict:
|
173 |
+
output = (logits, vision_outputs, query_outputs, outputs)
|
174 |
+
return ((loss,) + output) if loss is not None else output
|
175 |
|
176 |
+
return Blip2ForConditionalGenerationModelOutput(
|
177 |
+
loss=loss,
|
178 |
+
logits=logits,
|
179 |
+
vision_outputs=vision_outputs,
|
180 |
+
qformer_outputs=query_outputs,
|
181 |
+
language_model_outputs=outputs,
|
182 |
+
)
|
|
|
|
|
183 |
|
184 |
@torch.no_grad()
|
185 |
def stream_chat(
|
|
|
193 |
do_sample=True,
|
194 |
temperature=1,
|
195 |
):
|
196 |
+
device = self.device
|
197 |
# 1. Prepare token ids
|
198 |
images = []
|
199 |
image_slots = []
|
200 |
|
201 |
+
nvtokens = self.config.num_query_tokens
|
202 |
if history:
|
203 |
input_ids = tokenizer(
|
204 |
f"[Round {len(history)}]\n问:", add_special_tokens=False
|
|
|
268 |
# 2. Prepare image embeddings
|
269 |
if len(images) != 0:
|
270 |
image = torch.cat(list(images), dim=0)
|
271 |
+
vision_outputs = self.vision_model.forward(image)
|
272 |
image_embeds = vision_outputs[0]
|
273 |
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
|
274 |
device
|
275 |
)
|
276 |
|
277 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
278 |
+
query_outputs = self.qformer.forward(
|
279 |
query_embeds=query_tokens,
|
280 |
encoder_hidden_states=image_embeds,
|
281 |
encoder_attention_mask=image_atts,
|
282 |
)
|
283 |
query_output = query_outputs[0]
|
284 |
|
285 |
+
vtokens = self.language_projection(query_output)
|
286 |
else:
|
287 |
vtokens = []
|
288 |
|
289 |
# 3. Place image embeddings into slots
|
290 |
input_ids = torch.as_tensor(input_ids, dtype=torch.long).to(device).unsqueeze(0)
|
291 |
+
inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
|
292 |
for slot, vimg in zip(image_slots, vtokens):
|
293 |
inputs_embeds[0][-slot : -slot + nvtokens, :] = vimg
|
294 |
|
|
|
303 |
"logits_processor": logits_processor,
|
304 |
}
|
305 |
|
306 |
+
for outputs in self.stream_generate(
|
307 |
input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
|
308 |
):
|
309 |
outputs = outputs.tolist()[0][len(input_ids[0]) :]
|
310 |
response = tokenizer.decode(outputs)
|
311 |
+
response = self.language_model.process_response(response)
|
312 |
+
yield response
|
|
|
313 |
|
314 |
@torch.no_grad()
|
315 |
+
def stream_generate(
|
316 |
self,
|
317 |
input_ids,
|
318 |
inputs_embeds,
|
|
|
324 |
] = None,
|
325 |
**kwargs,
|
326 |
):
|
327 |
+
"""slightly modified from chatglm implementation to support inputs_embeds
|
328 |
+
|
329 |
+
Args:
|
330 |
+
input_ids (_type_): _description_
|
331 |
+
inputs_embeds (_type_): _description_
|
332 |
+
generation_config (Optional[GenerationConfig], optional): _description_. Defaults to None.
|
333 |
+
logits_processor (Optional[LogitsProcessorList], optional): _description_. Defaults to None.
|
334 |
+
stopping_criteria (Optional[StoppingCriteriaList], optional): _description_. Defaults to None.
|
335 |
+
prefix_allowed_tokens_fn (Optional[ Callable[[int, torch.Tensor], List[int]] ], optional): _description_. Defaults to None.
|
336 |
+
|
337 |
+
Yields:
|
338 |
+
_type_: _description_
|
339 |
+
"""
|
340 |
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
341 |
|
342 |
if generation_config is None:
|
343 |
+
generation_config = self.language_model.generation_config
|
344 |
generation_config = copy.deepcopy(generation_config)
|
345 |
model_kwargs = generation_config.update(**kwargs)
|
346 |
bos_token_id, eos_token_id = (
|
|
|
378 |
if input_ids_seq_length >= generation_config.max_length:
|
379 |
input_ids_string = (
|
380 |
"decoder_input_ids"
|
381 |
+
if self.language_model.config.is_encoder_decoder
|
382 |
else "input_ids"
|
383 |
)
|
384 |
logger.warning(
|
|
|
397 |
else StoppingCriteriaList()
|
398 |
)
|
399 |
|
400 |
+
logits_processor = self.language_model._get_logits_processor(
|
401 |
generation_config=generation_config,
|
402 |
input_ids_seq_length=input_ids_seq_length,
|
403 |
encoder_input_ids=input_ids,
|
|
|
405 |
logits_processor=logits_processor,
|
406 |
)
|
407 |
|
408 |
+
stopping_criteria = self.language_model._get_stopping_criteria(
|
409 |
generation_config=generation_config, stopping_criteria=stopping_criteria
|
410 |
)
|
411 |
+
logits_warper = self.language_model._get_logits_warper(generation_config)
|
412 |
|
413 |
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
414 |
scores = None
|
415 |
while True:
|
416 |
+
model_inputs = self.prepare_inputs_for_generation(
|
417 |
input_ids, inputs_embeds=inputs_embeds, **model_kwargs
|
418 |
)
|
419 |
# forward pass to get next token
|
420 |
+
outputs = self.language_model(
|
421 |
**model_inputs,
|
422 |
return_dict=True,
|
423 |
output_attentions=False,
|
|
|
442 |
inputs_embeds = torch.cat(
|
443 |
[
|
444 |
inputs_embeds,
|
445 |
+
self.language_model.get_input_embeddings()(next_tokens)[:, None, :],
|
446 |
],
|
447 |
dim=1,
|
448 |
)
|
449 |
+
model_kwargs = self.language_model._update_model_kwargs_for_generation(
|
450 |
outputs,
|
451 |
model_kwargs,
|
452 |
+
is_encoder_decoder=self.language_model.config.is_encoder_decoder,
|
453 |
)
|
454 |
unfinished_sequences = unfinished_sequences.mul(
|
455 |
(sum(next_tokens != i for i in eos_token_id)).long()
|
|
|
459 |
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
460 |
break
|
461 |
yield input_ids
|
462 |
+
|
463 |
+
def prepare_inputs_for_generation(
|
464 |
+
self,
|
465 |
+
input_ids: torch.LongTensor,
|
466 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
467 |
+
past: Optional[torch.Tensor] = None,
|
468 |
+
past_key_values: Optional[torch.Tensor] = None,
|
469 |
+
attention_mask: Optional[torch.Tensor] = None,
|
470 |
+
position_ids: Optional[torch.Tensor] = None,
|
471 |
+
**kwargs,
|
472 |
+
) -> dict:
|
473 |
+
"""slightly modified from chatglm implementation to support inputs_embeds
|
474 |
+
|
475 |
+
Args:
|
476 |
+
input_ids (torch.LongTensor): _description_
|
477 |
+
inputs_embeds (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
478 |
+
past (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
479 |
+
past_key_values (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
480 |
+
attention_mask (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
481 |
+
position_ids (Optional[torch.Tensor], optional): _description_. Defaults to None.
|
482 |
+
|
483 |
+
Returns:
|
484 |
+
dict: _description_
|
485 |
+
"""
|
486 |
+
batch_size, seq_length = input_ids.shape
|
487 |
+
MASK, gMASK = self.language_model.config.mask_token_id, self.language_model.config.gmask_token_id
|
488 |
+
seqs = input_ids.tolist()
|
489 |
+
mask_positions, use_gmasks = [], []
|
490 |
+
for seq in seqs:
|
491 |
+
mask_token = gMASK if gMASK in seq else MASK
|
492 |
+
use_gmask = mask_token == gMASK
|
493 |
+
mask_positions.append(seq.index(mask_token))
|
494 |
+
use_gmasks.append(use_gmask)
|
495 |
+
|
496 |
+
# only last token for input_ids if past is not None
|
497 |
+
if past is not None or past_key_values is not None:
|
498 |
+
last_token = input_ids[:, -1].unsqueeze(-1)
|
499 |
+
if attention_mask is not None and attention_mask.dtype == torch.bool:
|
500 |
+
attention_mask = attention_mask[:, :, -1:]
|
501 |
+
else:
|
502 |
+
attention_mask = None
|
503 |
+
if position_ids is not None:
|
504 |
+
position_ids = position_ids[..., -1:]
|
505 |
+
else:
|
506 |
+
context_lengths = [seq.index(self.language_model.config.bos_token_id) for seq in seqs]
|
507 |
+
if self.language_model.position_encoding_2d:
|
508 |
+
position_ids = torch.tensor(
|
509 |
+
[
|
510 |
+
[mask_position, seq_length - context_length]
|
511 |
+
for mask_position, context_length in zip(
|
512 |
+
mask_positions, context_lengths
|
513 |
+
)
|
514 |
+
],
|
515 |
+
dtype=torch.long,
|
516 |
+
device=input_ids.device,
|
517 |
+
).unsqueeze(-1)
|
518 |
+
else:
|
519 |
+
position_ids = torch.tensor(
|
520 |
+
[mask_position for mask_position in mask_positions],
|
521 |
+
dtype=torch.long,
|
522 |
+
device=input_ids.device,
|
523 |
+
).unsqueeze(-1)
|
524 |
+
|
525 |
+
if past is None:
|
526 |
+
past = past_key_values
|
527 |
+
return {
|
528 |
+
"input_ids": last_token,
|
529 |
+
"past_key_values": past,
|
530 |
+
"position_ids": position_ids,
|
531 |
+
"attention_mask": attention_mask,
|
532 |
+
}
|
533 |
+
else:
|
534 |
+
if attention_mask is not None and attention_mask.dtype != torch.bool:
|
535 |
+
logger.warning_once(
|
536 |
+
f"The dtype of attention mask ({attention_mask.dtype}) is not bool"
|
537 |
+
)
|
538 |
+
attention_mask = None
|
539 |
+
if attention_mask is None:
|
540 |
+
attention_mask = self.language_model.get_masks(input_ids, device=input_ids.device)
|
541 |
+
if position_ids is None:
|
542 |
+
position_ids = self.language_model.get_position_ids(
|
543 |
+
input_ids,
|
544 |
+
device=input_ids.device,
|
545 |
+
mask_positions=mask_positions,
|
546 |
+
use_gmasks=use_gmasks,
|
547 |
+
)
|
548 |
+
|
549 |
+
if inputs_embeds is not None:
|
550 |
+
assert input_ids.size(1) == inputs_embeds.size(
|
551 |
+
1
|
552 |
+
), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
|
553 |
+
return {
|
554 |
+
"inputs_embeds": inputs_embeds,
|
555 |
+
"past_key_values": past,
|
556 |
+
"position_ids": position_ids,
|
557 |
+
"attention_mask": attention_mask,
|
558 |
+
}
|
559 |
+
else:
|
560 |
+
return {
|
561 |
+
"input_ids": input_ids,
|
562 |
+
"past_key_values": past,
|
563 |
+
"position_ids": position_ids,
|
564 |
+
"attention_mask": attention_mask,
|
565 |
+
}
|
modeling_chatglm.py
CHANGED
@@ -55,7 +55,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
|
|
55 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
56 |
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
57 |
scores.zero_()
|
58 |
-
scores[...,
|
59 |
return scores
|
60 |
|
61 |
|
@@ -280,10 +280,8 @@ def attention_fn(
|
|
280 |
# [sk, b, np, hn] -> [sk, b * np, hn]
|
281 |
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
282 |
|
283 |
-
matmul_result = torch.
|
284 |
-
|
285 |
-
output_size[2],
|
286 |
-
output_size[3],
|
287 |
dtype=query_layer.dtype,
|
288 |
device=query_layer.device,
|
289 |
)
|
@@ -348,10 +346,18 @@ def attention_fn(
|
|
348 |
return outputs
|
349 |
|
350 |
|
|
|
|
|
|
|
|
|
351 |
class SelfAttention(torch.nn.Module):
|
352 |
def __init__(self, hidden_size, num_attention_heads,
|
353 |
layer_id, hidden_size_per_attention_head=None, bias=True,
|
354 |
-
params_dtype=torch.float, position_encoding_2d=True):
|
|
|
|
|
|
|
|
|
355 |
super(SelfAttention, self).__init__()
|
356 |
|
357 |
self.layer_id = layer_id
|
@@ -379,7 +385,7 @@ class SelfAttention(torch.nn.Module):
|
|
379 |
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
|
380 |
|
381 |
# Strided linear layer.
|
382 |
-
self.query_key_value =
|
383 |
torch.nn.Linear,
|
384 |
hidden_size,
|
385 |
3 * self.inner_hidden_size,
|
@@ -387,7 +393,7 @@ class SelfAttention(torch.nn.Module):
|
|
387 |
dtype=params_dtype,
|
388 |
)
|
389 |
|
390 |
-
self.dense =
|
391 |
torch.nn.Linear,
|
392 |
self.inner_hidden_size,
|
393 |
hidden_size,
|
@@ -500,8 +506,12 @@ class GEGLU(torch.nn.Module):
|
|
500 |
|
501 |
class GLU(torch.nn.Module):
|
502 |
def __init__(self, hidden_size, inner_hidden_size=None,
|
503 |
-
layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
|
504 |
super(GLU, self).__init__()
|
|
|
|
|
|
|
|
|
505 |
self.layer_id = layer_id
|
506 |
self.activation_func = activation_func
|
507 |
|
@@ -510,7 +520,7 @@ class GLU(torch.nn.Module):
|
|
510 |
if inner_hidden_size is None:
|
511 |
inner_hidden_size = 4 * hidden_size
|
512 |
self.inner_hidden_size = inner_hidden_size
|
513 |
-
self.dense_h_to_4h =
|
514 |
torch.nn.Linear,
|
515 |
self.hidden_size,
|
516 |
self.inner_hidden_size,
|
@@ -518,7 +528,7 @@ class GLU(torch.nn.Module):
|
|
518 |
dtype=params_dtype,
|
519 |
)
|
520 |
# Project back to h.
|
521 |
-
self.dense_4h_to_h =
|
522 |
torch.nn.Linear,
|
523 |
self.inner_hidden_size,
|
524 |
self.hidden_size,
|
@@ -554,7 +564,8 @@ class GLMBlock(torch.nn.Module):
|
|
554 |
use_bias=True,
|
555 |
params_dtype=torch.float,
|
556 |
num_layers=28,
|
557 |
-
position_encoding_2d=True
|
|
|
558 |
):
|
559 |
super(GLMBlock, self).__init__()
|
560 |
# Set output layer initialization if not provided.
|
@@ -574,7 +585,8 @@ class GLMBlock(torch.nn.Module):
|
|
574 |
hidden_size_per_attention_head=hidden_size_per_attention_head,
|
575 |
bias=use_bias,
|
576 |
params_dtype=params_dtype,
|
577 |
-
position_encoding_2d=self.position_encoding_2d
|
|
|
578 |
)
|
579 |
|
580 |
# Layernorm on the input data.
|
@@ -589,6 +601,7 @@ class GLMBlock(torch.nn.Module):
|
|
589 |
bias=use_bias,
|
590 |
layer_id=layer_id,
|
591 |
params_dtype=params_dtype,
|
|
|
592 |
)
|
593 |
|
594 |
def forward(
|
@@ -676,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
676 |
|
677 |
return attention_mask
|
678 |
|
679 |
-
def get_position_ids(self, input_ids, mask_positions, device,
|
680 |
batch_size, seq_length = input_ids.shape
|
|
|
|
|
681 |
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
682 |
if self.position_encoding_2d:
|
683 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
@@ -691,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
691 |
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
692 |
else:
|
693 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
694 |
-
|
695 |
-
|
696 |
position_ids[context_length:] = mask_positions[i]
|
697 |
|
698 |
return position_ids
|
@@ -783,9 +798,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
783 |
`encoder_hidden_states` is then expected as an input to the forward pass.
|
784 |
"""
|
785 |
|
786 |
-
def __init__(self, config: ChatGLMConfig):
|
787 |
super().__init__(config)
|
788 |
-
|
|
|
|
|
|
|
789 |
# recording parameters
|
790 |
self.max_sequence_length = config.max_sequence_length
|
791 |
self.hidden_size = config.hidden_size
|
@@ -800,7 +818,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
800 |
self.pre_seq_len = config.pre_seq_len
|
801 |
self.prefix_projection = config.prefix_projection
|
802 |
|
803 |
-
self.word_embeddings =
|
804 |
torch.nn.Embedding,
|
805 |
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
|
806 |
dtype=self.params_dtype
|
@@ -819,6 +837,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
819 |
use_bias=True,
|
820 |
params_dtype=self.params_dtype,
|
821 |
position_encoding_2d=self.position_encoding_2d,
|
|
|
822 |
)
|
823 |
|
824 |
self.layers = torch.nn.ModuleList(
|
@@ -894,12 +913,18 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
894 |
)
|
895 |
use_cache = False
|
896 |
|
897 |
-
if input_ids is not None and inputs_embeds is not None:
|
898 |
-
|
899 |
-
elif input_ids is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
900 |
batch_size, seq_length = input_ids.shape[:2]
|
901 |
elif inputs_embeds is not None:
|
902 |
-
# NOTE: fix
|
903 |
batch_size, seq_length = inputs_embeds.shape[:2]
|
904 |
else:
|
905 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
@@ -923,15 +948,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
923 |
|
924 |
if position_ids is None:
|
925 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
926 |
-
|
927 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
928 |
|
929 |
-
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
930 |
position_ids = self.get_position_ids(
|
931 |
input_ids,
|
932 |
mask_positions=mask_positions,
|
933 |
device=input_ids.device,
|
934 |
-
|
935 |
)
|
936 |
|
937 |
if self.pre_seq_len is not None and attention_mask is not None:
|
@@ -950,10 +980,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
950 |
if attention_mask is None:
|
951 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
952 |
|
953 |
-
|
954 |
-
|
955 |
-
|
956 |
-
|
957 |
|
958 |
for i, layer in enumerate(self.layers):
|
959 |
|
@@ -1009,8 +1039,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
1009 |
|
1010 |
|
1011 |
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
1012 |
-
def __init__(self, config: ChatGLMConfig):
|
1013 |
super().__init__(config)
|
|
|
|
|
|
|
|
|
1014 |
|
1015 |
# self.hidden_size = config.hidden_size
|
1016 |
# self.params_dtype = torch.half
|
@@ -1019,9 +1053,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1019 |
|
1020 |
self.position_encoding_2d = config.position_encoding_2d
|
1021 |
|
1022 |
-
self.transformer = ChatGLMModel(config)
|
1023 |
|
1024 |
-
self.lm_head =
|
1025 |
nn.Linear,
|
1026 |
config.hidden_size,
|
1027 |
config.vocab_size,
|
@@ -1080,7 +1114,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1080 |
def prepare_inputs_for_generation(
|
1081 |
self,
|
1082 |
input_ids: torch.LongTensor,
|
1083 |
-
inputs_embeds: Optional[torch.Tensor] = None,
|
1084 |
past: Optional[torch.Tensor] = None,
|
1085 |
past_key_values: Optional[torch.Tensor] = None,
|
1086 |
attention_mask: Optional[torch.Tensor] = None,
|
@@ -1089,10 +1122,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1089 |
) -> dict:
|
1090 |
batch_size, seq_length = input_ids.shape
|
1091 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
1092 |
-
mask_token = gMASK if gMASK in input_ids else MASK
|
1093 |
-
use_gmask = True if gMASK in input_ids else False
|
1094 |
seqs = input_ids.tolist()
|
1095 |
-
mask_positions = [
|
|
|
|
|
|
|
|
|
|
|
1096 |
|
1097 |
# only last token for input_ids if past is not None
|
1098 |
if past is not None or past_key_values is not None:
|
@@ -1135,23 +1171,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1135 |
input_ids,
|
1136 |
device=input_ids.device,
|
1137 |
mask_positions=mask_positions,
|
1138 |
-
|
1139 |
)
|
1140 |
-
|
1141 |
-
|
1142 |
-
|
1143 |
-
|
1144 |
-
|
1145 |
-
|
1146 |
-
|
1147 |
-
}
|
1148 |
-
else:
|
1149 |
-
return {
|
1150 |
-
"input_ids": input_ids,
|
1151 |
-
"past_key_values": past,
|
1152 |
-
"position_ids": position_ids,
|
1153 |
-
"attention_mask": attention_mask
|
1154 |
-
}
|
1155 |
|
1156 |
def forward(
|
1157 |
self,
|
|
|
55 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
56 |
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
57 |
scores.zero_()
|
58 |
+
scores[..., 5] = 5e4
|
59 |
return scores
|
60 |
|
61 |
|
|
|
280 |
# [sk, b, np, hn] -> [sk, b * np, hn]
|
281 |
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
282 |
|
283 |
+
matmul_result = torch.zeros(
|
284 |
+
1, 1, 1,
|
|
|
|
|
285 |
dtype=query_layer.dtype,
|
286 |
device=query_layer.device,
|
287 |
)
|
|
|
346 |
return outputs
|
347 |
|
348 |
|
349 |
+
def default_init(cls, *args, **kwargs):
|
350 |
+
return cls(*args, **kwargs)
|
351 |
+
|
352 |
+
|
353 |
class SelfAttention(torch.nn.Module):
|
354 |
def __init__(self, hidden_size, num_attention_heads,
|
355 |
layer_id, hidden_size_per_attention_head=None, bias=True,
|
356 |
+
params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
|
357 |
+
if empty_init:
|
358 |
+
init_method = skip_init
|
359 |
+
else:
|
360 |
+
init_method = default_init
|
361 |
super(SelfAttention, self).__init__()
|
362 |
|
363 |
self.layer_id = layer_id
|
|
|
385 |
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
|
386 |
|
387 |
# Strided linear layer.
|
388 |
+
self.query_key_value = init_method(
|
389 |
torch.nn.Linear,
|
390 |
hidden_size,
|
391 |
3 * self.inner_hidden_size,
|
|
|
393 |
dtype=params_dtype,
|
394 |
)
|
395 |
|
396 |
+
self.dense = init_method(
|
397 |
torch.nn.Linear,
|
398 |
self.inner_hidden_size,
|
399 |
hidden_size,
|
|
|
506 |
|
507 |
class GLU(torch.nn.Module):
|
508 |
def __init__(self, hidden_size, inner_hidden_size=None,
|
509 |
+
layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
|
510 |
super(GLU, self).__init__()
|
511 |
+
if empty_init:
|
512 |
+
init_method = skip_init
|
513 |
+
else:
|
514 |
+
init_method = default_init
|
515 |
self.layer_id = layer_id
|
516 |
self.activation_func = activation_func
|
517 |
|
|
|
520 |
if inner_hidden_size is None:
|
521 |
inner_hidden_size = 4 * hidden_size
|
522 |
self.inner_hidden_size = inner_hidden_size
|
523 |
+
self.dense_h_to_4h = init_method(
|
524 |
torch.nn.Linear,
|
525 |
self.hidden_size,
|
526 |
self.inner_hidden_size,
|
|
|
528 |
dtype=params_dtype,
|
529 |
)
|
530 |
# Project back to h.
|
531 |
+
self.dense_4h_to_h = init_method(
|
532 |
torch.nn.Linear,
|
533 |
self.inner_hidden_size,
|
534 |
self.hidden_size,
|
|
|
564 |
use_bias=True,
|
565 |
params_dtype=torch.float,
|
566 |
num_layers=28,
|
567 |
+
position_encoding_2d=True,
|
568 |
+
empty_init=True
|
569 |
):
|
570 |
super(GLMBlock, self).__init__()
|
571 |
# Set output layer initialization if not provided.
|
|
|
585 |
hidden_size_per_attention_head=hidden_size_per_attention_head,
|
586 |
bias=use_bias,
|
587 |
params_dtype=params_dtype,
|
588 |
+
position_encoding_2d=self.position_encoding_2d,
|
589 |
+
empty_init=empty_init
|
590 |
)
|
591 |
|
592 |
# Layernorm on the input data.
|
|
|
601 |
bias=use_bias,
|
602 |
layer_id=layer_id,
|
603 |
params_dtype=params_dtype,
|
604 |
+
empty_init=empty_init
|
605 |
)
|
606 |
|
607 |
def forward(
|
|
|
689 |
|
690 |
return attention_mask
|
691 |
|
692 |
+
def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
|
693 |
batch_size, seq_length = input_ids.shape
|
694 |
+
if use_gmasks is None:
|
695 |
+
use_gmasks = [False] * batch_size
|
696 |
context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
|
697 |
if self.position_encoding_2d:
|
698 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
|
|
706 |
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
707 |
else:
|
708 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
709 |
+
for i, context_length in enumerate(context_lengths):
|
710 |
+
if not use_gmasks[i]:
|
711 |
position_ids[context_length:] = mask_positions[i]
|
712 |
|
713 |
return position_ids
|
|
|
798 |
`encoder_hidden_states` is then expected as an input to the forward pass.
|
799 |
"""
|
800 |
|
801 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True):
|
802 |
super().__init__(config)
|
803 |
+
if empty_init:
|
804 |
+
init_method = skip_init
|
805 |
+
else:
|
806 |
+
init_method = default_init
|
807 |
# recording parameters
|
808 |
self.max_sequence_length = config.max_sequence_length
|
809 |
self.hidden_size = config.hidden_size
|
|
|
818 |
self.pre_seq_len = config.pre_seq_len
|
819 |
self.prefix_projection = config.prefix_projection
|
820 |
|
821 |
+
self.word_embeddings = init_method(
|
822 |
torch.nn.Embedding,
|
823 |
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
|
824 |
dtype=self.params_dtype
|
|
|
837 |
use_bias=True,
|
838 |
params_dtype=self.params_dtype,
|
839 |
position_encoding_2d=self.position_encoding_2d,
|
840 |
+
empty_init=empty_init
|
841 |
)
|
842 |
|
843 |
self.layers = torch.nn.ModuleList(
|
|
|
913 |
)
|
914 |
use_cache = False
|
915 |
|
916 |
+
# if input_ids is not None and inputs_embeds is not None:
|
917 |
+
# raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
918 |
+
# elif input_ids is not None:
|
919 |
+
# batch_size, seq_length = input_ids.shape[:2]
|
920 |
+
# elif inputs_embeds is not None:
|
921 |
+
# batch_size, seq_length. _ = inputs_embeds.shape[:2]
|
922 |
+
# else:
|
923 |
+
# raise ValueError("You have to specify either input_ids or inputs_embeds")
|
924 |
+
|
925 |
+
if input_ids is not None:
|
926 |
batch_size, seq_length = input_ids.shape[:2]
|
927 |
elif inputs_embeds is not None:
|
|
|
928 |
batch_size, seq_length = inputs_embeds.shape[:2]
|
929 |
else:
|
930 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
|
948 |
|
949 |
if position_ids is None:
|
950 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
951 |
+
seqs = input_ids.tolist()
|
952 |
+
|
953 |
+
mask_positions, use_gmasks = [], []
|
954 |
+
for seq in seqs:
|
955 |
+
mask_token = gMASK if gMASK in seq else MASK
|
956 |
+
use_gmask = mask_token == gMASK
|
957 |
+
mask_positions.append(seq.index(mask_token))
|
958 |
+
use_gmasks.append(use_gmask)
|
959 |
|
|
|
960 |
position_ids = self.get_position_ids(
|
961 |
input_ids,
|
962 |
mask_positions=mask_positions,
|
963 |
device=input_ids.device,
|
964 |
+
use_gmasks=use_gmasks
|
965 |
)
|
966 |
|
967 |
if self.pre_seq_len is not None and attention_mask is not None:
|
|
|
980 |
if attention_mask is None:
|
981 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
982 |
|
983 |
+
# NOTE: this is a hack to make the code work with the LAVIS training
|
984 |
+
# else:
|
985 |
+
# pass
|
986 |
+
# attention_mask = attention_mask.to(input_ids.device)
|
987 |
|
988 |
for i, layer in enumerate(self.layers):
|
989 |
|
|
|
1039 |
|
1040 |
|
1041 |
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
1042 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True):
|
1043 |
super().__init__(config)
|
1044 |
+
if empty_init:
|
1045 |
+
init_method = skip_init
|
1046 |
+
else:
|
1047 |
+
init_method = default_init
|
1048 |
|
1049 |
# self.hidden_size = config.hidden_size
|
1050 |
# self.params_dtype = torch.half
|
|
|
1053 |
|
1054 |
self.position_encoding_2d = config.position_encoding_2d
|
1055 |
|
1056 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init)
|
1057 |
|
1058 |
+
self.lm_head = init_method(
|
1059 |
nn.Linear,
|
1060 |
config.hidden_size,
|
1061 |
config.vocab_size,
|
|
|
1114 |
def prepare_inputs_for_generation(
|
1115 |
self,
|
1116 |
input_ids: torch.LongTensor,
|
|
|
1117 |
past: Optional[torch.Tensor] = None,
|
1118 |
past_key_values: Optional[torch.Tensor] = None,
|
1119 |
attention_mask: Optional[torch.Tensor] = None,
|
|
|
1122 |
) -> dict:
|
1123 |
batch_size, seq_length = input_ids.shape
|
1124 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
|
|
|
|
1125 |
seqs = input_ids.tolist()
|
1126 |
+
mask_positions, use_gmasks = [], []
|
1127 |
+
for seq in seqs:
|
1128 |
+
mask_token = gMASK if gMASK in seq else MASK
|
1129 |
+
use_gmask = mask_token == gMASK
|
1130 |
+
mask_positions.append(seq.index(mask_token))
|
1131 |
+
use_gmasks.append(use_gmask)
|
1132 |
|
1133 |
# only last token for input_ids if past is not None
|
1134 |
if past is not None or past_key_values is not None:
|
|
|
1171 |
input_ids,
|
1172 |
device=input_ids.device,
|
1173 |
mask_positions=mask_positions,
|
1174 |
+
use_gmasks=use_gmasks
|
1175 |
)
|
1176 |
+
|
1177 |
+
return {
|
1178 |
+
"input_ids": input_ids,
|
1179 |
+
"past_key_values": past,
|
1180 |
+
"position_ids": position_ids,
|
1181 |
+
"attention_mask": attention_mask
|
1182 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1183 |
|
1184 |
def forward(
|
1185 |
self,
|
preprocessor_config.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"do_convert_rgb": true,
|
3 |
+
"do_normalize": true,
|
4 |
+
"do_rescale": true,
|
5 |
+
"do_resize": true,
|
6 |
+
"image_mean": [
|
7 |
+
0.48145466,
|
8 |
+
0.4578275,
|
9 |
+
0.40821073
|
10 |
+
],
|
11 |
+
"image_processor_type": "BlipImageProcessor",
|
12 |
+
"image_std": [
|
13 |
+
0.26862954,
|
14 |
+
0.26130258,
|
15 |
+
0.27577711
|
16 |
+
],
|
17 |
+
"processor_class": "Blip2Processor",
|
18 |
+
"resample": 3,
|
19 |
+
"rescale_factor": 0.00392156862745098,
|
20 |
+
"size": {
|
21 |
+
"height": 224,
|
22 |
+
"width": 224
|
23 |
+
}
|
24 |
+
}
|
pytorch_model-00001-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81ec9cc9a7e6034300a115898aac9fda06c69cf15d1b3c470d633ae7ce0ad3c9
|
3 |
+
size 1995030990
|
pytorch_model-00002-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de3166cf720b1a7cf6be0872f773d1f5e587e109435540f6511c37391827f1d6
|
3 |
+
size 1983142386
|
pytorch_model-00003-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a21fec7efe30123a73cd4bc77a4f8bf58c26e808743d47c606950215afd5e6c
|
3 |
+
size 1913134013
|
pytorch_model-00004-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b31bbd7aa605cde4258795220732aa24505ef451bf7e86a434c23c7fb75207e3
|
3 |
+
size 1879578439
|
pytorch_model-00005-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d45c57fd01a8e6b10a5d31d01af88580212e991705c5308ffcfe76bce8eb9df1
|
3 |
+
size 1879571453
|
pytorch_model-00006-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d1032cc31e7f2cda475a12f3a4016934c7d1c82c35b1cec93e159f0bbbc428c
|
3 |
+
size 1980242201
|
pytorch_model-00007-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e0088535e5adf2f7b2cc2064aff91ffb979fb895a8cb2e2eee14e97a358c192a
|
3 |
+
size 1913134077
|
pytorch_model-00008-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a45b9e99ae15bad8d1722abd5f6e441c6cb4fe87bfa32f6c01e5b0a58409ec5d
|
3 |
+
size 1208293115
|
pytorch_model-00009-of-00009.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a6323bcad07ce5cc7934323c438abe9a8f45029553cd29098fe22314b14edb9a
|
3 |
+
size 1069286314
|
pytorch_model.bin.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<sop>",
|
3 |
+
"eos_token": "<eop>",
|
4 |
+
"mask_token": "[MASK]",
|
5 |
+
"pad_token": "<pad>",
|
6 |
+
"unk_token": "<unk>"
|
7 |
+
}
|
tokenization_chatglm.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Tokenization classes for ChatGLM."""
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
import os
|
4 |
+
|
5 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
6 |
+
from transformers.utils import logging, PaddingStrategy
|
7 |
+
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
8 |
+
from typing import Dict
|
9 |
+
import sentencepiece as spm
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
logger = logging.get_logger(__name__)
|
13 |
+
|
14 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
15 |
+
"THUDM/chatglm-6b": 2048,
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
class TextTokenizer:
|
20 |
+
def __init__(self, model_path):
|
21 |
+
self.sp = spm.SentencePieceProcessor()
|
22 |
+
self.sp.Load(model_path)
|
23 |
+
self.num_tokens = self.sp.vocab_size()
|
24 |
+
|
25 |
+
def encode(self, text):
|
26 |
+
return self.sp.EncodeAsIds(text)
|
27 |
+
|
28 |
+
def decode(self, ids: List[int]):
|
29 |
+
return self.sp.DecodeIds(ids)
|
30 |
+
|
31 |
+
def tokenize(self, text):
|
32 |
+
return self.sp.EncodeAsPieces(text)
|
33 |
+
|
34 |
+
def convert_tokens_to_ids(self, tokens):
|
35 |
+
return [self.sp.PieceToId(token) for token in tokens]
|
36 |
+
|
37 |
+
def convert_token_to_id(self, token):
|
38 |
+
return self.sp.PieceToId(token)
|
39 |
+
|
40 |
+
def convert_id_to_token(self, idx):
|
41 |
+
return self.sp.IdToPiece(idx)
|
42 |
+
|
43 |
+
def __len__(self):
|
44 |
+
return self.num_tokens
|
45 |
+
|
46 |
+
|
47 |
+
class SPTokenizer:
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
vocab_file,
|
51 |
+
num_image_tokens=20000,
|
52 |
+
max_blank_length=80,
|
53 |
+
byte_fallback=True,
|
54 |
+
):
|
55 |
+
assert vocab_file is not None
|
56 |
+
self.vocab_file = vocab_file
|
57 |
+
self.num_image_tokens = num_image_tokens
|
58 |
+
self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
|
59 |
+
self.max_blank_length = max_blank_length
|
60 |
+
self.byte_fallback = byte_fallback
|
61 |
+
self.text_tokenizer = TextTokenizer(vocab_file)
|
62 |
+
|
63 |
+
def _get_text_tokenizer(self):
|
64 |
+
return self.text_tokenizer
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def get_blank_token(length: int):
|
68 |
+
assert length >= 2
|
69 |
+
return f"<|blank_{length}|>"
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def get_tab_token():
|
73 |
+
return f"<|tab|>"
|
74 |
+
|
75 |
+
@property
|
76 |
+
def num_text_tokens(self):
|
77 |
+
return self.text_tokenizer.num_tokens
|
78 |
+
|
79 |
+
@property
|
80 |
+
def num_tokens(self):
|
81 |
+
return self.num_image_tokens + self.num_text_tokens
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def _encode_whitespaces(text: str, max_len: int = 80):
|
85 |
+
text = text.replace("\t", SPTokenizer.get_tab_token())
|
86 |
+
for i in range(max_len, 1, -1):
|
87 |
+
text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
|
88 |
+
return text
|
89 |
+
|
90 |
+
def _preprocess(self, text: str, linebreak=True, whitespaces=True):
|
91 |
+
if linebreak:
|
92 |
+
text = text.replace("\n", "<n>")
|
93 |
+
if whitespaces:
|
94 |
+
text = self._encode_whitespaces(text, max_len=self.max_blank_length)
|
95 |
+
return text
|
96 |
+
|
97 |
+
def encode(
|
98 |
+
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
|
99 |
+
) -> List[int]:
|
100 |
+
"""
|
101 |
+
@param text: Text to encode.
|
102 |
+
@param linebreak: Whether to encode newline (\n) in text.
|
103 |
+
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
|
104 |
+
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
|
105 |
+
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
|
106 |
+
"""
|
107 |
+
text = self._preprocess(text, linebreak, whitespaces)
|
108 |
+
if not add_dummy_prefix:
|
109 |
+
text = "<n>" + text
|
110 |
+
tmp = self._get_text_tokenizer().encode(text)
|
111 |
+
tokens = [x + self.num_image_tokens for x in tmp]
|
112 |
+
return tokens if add_dummy_prefix else tokens[2:]
|
113 |
+
|
114 |
+
def decode(self, text_ids: List[int]) -> str:
|
115 |
+
ids = [int(_id) - self.num_image_tokens for _id in text_ids]
|
116 |
+
ids = [_id for _id in ids if _id >= 0]
|
117 |
+
text = self._get_text_tokenizer().decode(ids)
|
118 |
+
text = text.replace("<n>", "\n")
|
119 |
+
text = text.replace(SPTokenizer.get_tab_token(), "\t")
|
120 |
+
for i in range(2, self.max_blank_length + 1):
|
121 |
+
text = text.replace(self.get_blank_token(i), " " * i)
|
122 |
+
return text
|
123 |
+
|
124 |
+
def tokenize(
|
125 |
+
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
|
126 |
+
) -> List[str]:
|
127 |
+
"""
|
128 |
+
@param text: Text to encode.
|
129 |
+
@param linebreak: Whether to encode newline (\n) in text.
|
130 |
+
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
|
131 |
+
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
|
132 |
+
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
|
133 |
+
"""
|
134 |
+
text = self._preprocess(text, linebreak, whitespaces)
|
135 |
+
if not add_dummy_prefix:
|
136 |
+
text = "<n>" + text
|
137 |
+
tokens = self._get_text_tokenizer().tokenize(text)
|
138 |
+
return tokens if add_dummy_prefix else tokens[2:]
|
139 |
+
|
140 |
+
def __getitem__(self, x: Union[int, str]):
|
141 |
+
if isinstance(x, int):
|
142 |
+
if x < self.num_image_tokens:
|
143 |
+
return "<image_{}>".format(x)
|
144 |
+
else:
|
145 |
+
return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
|
146 |
+
elif isinstance(x, str):
|
147 |
+
if x.startswith("<image_") and x.endswith(">") and x[7:-1].isdigit():
|
148 |
+
return int(x[7:-1])
|
149 |
+
else:
|
150 |
+
return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
|
151 |
+
else:
|
152 |
+
raise ValueError("The key should be str or int.")
|
153 |
+
|
154 |
+
|
155 |
+
class ChatGLMTokenizer(PreTrainedTokenizer):
|
156 |
+
"""
|
157 |
+
Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
vocab_file (`str`):
|
161 |
+
Path to the vocabulary file.
|
162 |
+
"""
|
163 |
+
|
164 |
+
vocab_files_names = {"vocab_file": "ice_text.model"}
|
165 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
166 |
+
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
167 |
+
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
vocab_file,
|
171 |
+
do_lower_case=False,
|
172 |
+
remove_space=False,
|
173 |
+
bos_token='<sop>',
|
174 |
+
eos_token='<eop>',
|
175 |
+
end_token='</s>',
|
176 |
+
mask_token='[MASK]',
|
177 |
+
gmask_token='[gMASK]',
|
178 |
+
padding_side="left",
|
179 |
+
num_image_tokens=20000,
|
180 |
+
**kwargs
|
181 |
+
) -> None:
|
182 |
+
super().__init__(
|
183 |
+
do_lower_case=do_lower_case,
|
184 |
+
remove_space=remove_space,
|
185 |
+
padding_side=padding_side,
|
186 |
+
bos_token=bos_token,
|
187 |
+
eos_token=eos_token,
|
188 |
+
end_token=end_token,
|
189 |
+
mask_token=mask_token,
|
190 |
+
gmask_token=gmask_token,
|
191 |
+
num_image_tokens=num_image_tokens,
|
192 |
+
**kwargs
|
193 |
+
)
|
194 |
+
|
195 |
+
self.do_lower_case = do_lower_case
|
196 |
+
self.remove_space = remove_space
|
197 |
+
self.vocab_file = vocab_file
|
198 |
+
|
199 |
+
self.bos_token = bos_token
|
200 |
+
self.eos_token = eos_token
|
201 |
+
self.end_token = end_token
|
202 |
+
self.mask_token = mask_token
|
203 |
+
self.gmask_token = gmask_token
|
204 |
+
|
205 |
+
self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens)
|
206 |
+
|
207 |
+
""" Initialisation """
|
208 |
+
|
209 |
+
@property
|
210 |
+
def gmask_token_id(self) -> Optional[int]:
|
211 |
+
if self.gmask_token is None:
|
212 |
+
return None
|
213 |
+
return self.convert_tokens_to_ids(self.gmask_token)
|
214 |
+
|
215 |
+
@property
|
216 |
+
def end_token_id(self) -> Optional[int]:
|
217 |
+
"""
|
218 |
+
`Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
|
219 |
+
set.
|
220 |
+
"""
|
221 |
+
if self.end_token is None:
|
222 |
+
return None
|
223 |
+
return self.convert_tokens_to_ids(self.end_token)
|
224 |
+
|
225 |
+
@property
|
226 |
+
def vocab_size(self):
|
227 |
+
""" Returns vocab size """
|
228 |
+
return self.sp_tokenizer.num_tokens
|
229 |
+
|
230 |
+
def get_vocab(self):
|
231 |
+
""" Returns vocab as a dict """
|
232 |
+
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
233 |
+
vocab.update(self.added_tokens_encoder)
|
234 |
+
return vocab
|
235 |
+
|
236 |
+
def preprocess_text(self, inputs):
|
237 |
+
if self.remove_space:
|
238 |
+
outputs = " ".join(inputs.strip().split())
|
239 |
+
else:
|
240 |
+
outputs = inputs
|
241 |
+
|
242 |
+
if self.do_lower_case:
|
243 |
+
outputs = outputs.lower()
|
244 |
+
|
245 |
+
return outputs
|
246 |
+
|
247 |
+
def _tokenize(self, text, **kwargs):
|
248 |
+
""" Returns a tokenized string. """
|
249 |
+
text = self.preprocess_text(text)
|
250 |
+
|
251 |
+
seq = self.sp_tokenizer.tokenize(text)
|
252 |
+
|
253 |
+
return seq
|
254 |
+
|
255 |
+
def _decode(
|
256 |
+
self,
|
257 |
+
token_ids: Union[int, List[int]],
|
258 |
+
skip_special_tokens: bool = False,
|
259 |
+
clean_up_tokenization_spaces: bool = True,
|
260 |
+
**kwargs
|
261 |
+
) -> str:
|
262 |
+
if isinstance(token_ids, int):
|
263 |
+
token_ids = [token_ids]
|
264 |
+
if len(token_ids) == 0:
|
265 |
+
return ""
|
266 |
+
if self.pad_token_id in token_ids: # remove pad
|
267 |
+
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
268 |
+
return self.sp_tokenizer.decode(token_ids)
|
269 |
+
|
270 |
+
def _convert_token_to_id(self, token):
|
271 |
+
""" Converts a token (str) in an id using the vocab. """
|
272 |
+
return self.sp_tokenizer[token]
|
273 |
+
|
274 |
+
def _convert_id_to_token(self, index):
|
275 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
276 |
+
return self.sp_tokenizer[index]
|
277 |
+
|
278 |
+
def save_vocabulary(self, save_directory, filename_prefix=None):
|
279 |
+
"""
|
280 |
+
Save the vocabulary and special tokens file to a directory.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
save_directory (`str`):
|
284 |
+
The directory in which to save the vocabulary.
|
285 |
+
filename_prefix (`str`, *optional*):
|
286 |
+
An optional prefix to add to the named of the saved files.
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
`Tuple(str)`: Paths to the files saved.
|
290 |
+
"""
|
291 |
+
if os.path.isdir(save_directory):
|
292 |
+
vocab_file = os.path.join(
|
293 |
+
save_directory, self.vocab_files_names["vocab_file"]
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
vocab_file = save_directory
|
297 |
+
|
298 |
+
with open(self.vocab_file, 'rb') as fin:
|
299 |
+
proto_str = fin.read()
|
300 |
+
|
301 |
+
with open(vocab_file, "wb") as writer:
|
302 |
+
writer.write(proto_str)
|
303 |
+
|
304 |
+
return (vocab_file,)
|
305 |
+
|
306 |
+
def build_inputs_with_special_tokens(
|
307 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
308 |
+
) -> List[int]:
|
309 |
+
"""
|
310 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
311 |
+
adding special tokens. A BERT sequence has the following format:
|
312 |
+
|
313 |
+
- single sequence: `[CLS] X [SEP]`
|
314 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
315 |
+
|
316 |
+
Args:
|
317 |
+
token_ids_0 (`List[int]`):
|
318 |
+
List of IDs to which the special tokens will be added.
|
319 |
+
token_ids_1 (`List[int]`, *optional*):
|
320 |
+
Optional second list of IDs for sequence pairs.
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
324 |
+
"""
|
325 |
+
mask_ids = self.sp_tokenizer[self.mask_token]
|
326 |
+
gmask_ids = self.sp_tokenizer[self.gmask_token]
|
327 |
+
eos_id = self.sp_tokenizer[self.eos_token]
|
328 |
+
if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
|
329 |
+
token_ids_0 += [gmask_ids]
|
330 |
+
|
331 |
+
if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
|
332 |
+
token_ids_0 += [self.sp_tokenizer[self.end_token]]
|
333 |
+
|
334 |
+
token_ids_0 += [self.sp_tokenizer[self.bos_token]]
|
335 |
+
|
336 |
+
if token_ids_1 is not None:
|
337 |
+
if not token_ids_1 or token_ids_1[-1] != eos_id:
|
338 |
+
token_ids_1 += [eos_id]
|
339 |
+
token_ids_0 += token_ids_1
|
340 |
+
|
341 |
+
return token_ids_0
|
342 |
+
|
343 |
+
def _pad(
|
344 |
+
self,
|
345 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
346 |
+
max_length: Optional[int] = None,
|
347 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
348 |
+
pad_to_multiple_of: Optional[int] = None,
|
349 |
+
return_attention_mask: Optional[bool] = None,
|
350 |
+
) -> dict:
|
351 |
+
"""
|
352 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
353 |
+
|
354 |
+
Args:
|
355 |
+
encoded_inputs:
|
356 |
+
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
357 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
358 |
+
Will truncate by taking into account the special tokens.
|
359 |
+
padding_strategy: PaddingStrategy to use for padding.
|
360 |
+
|
361 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
362 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
363 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
364 |
+
The tokenizer padding sides are defined in self.padding_side:
|
365 |
+
|
366 |
+
- 'left': pads on the left of the sequences
|
367 |
+
- 'right': pads on the right of the sequences
|
368 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
369 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
370 |
+
`>= 7.5` (Volta).
|
371 |
+
return_attention_mask:
|
372 |
+
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
373 |
+
"""
|
374 |
+
# Load from model defaults
|
375 |
+
bos_token_id = self.sp_tokenizer[self.bos_token]
|
376 |
+
mask_token_id = self.sp_tokenizer[self.mask_token]
|
377 |
+
gmask_token_id = self.sp_tokenizer[self.gmask_token]
|
378 |
+
assert self.padding_side == "left"
|
379 |
+
|
380 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
381 |
+
seq_length = len(required_input)
|
382 |
+
|
383 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
384 |
+
max_length = len(required_input)
|
385 |
+
|
386 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
387 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
388 |
+
|
389 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
390 |
+
|
391 |
+
# Initialize attention mask if not present.
|
392 |
+
if max_length is not None:
|
393 |
+
if "attention_mask" not in encoded_inputs:
|
394 |
+
if bos_token_id in required_input:
|
395 |
+
context_length = required_input.index(bos_token_id)
|
396 |
+
else:
|
397 |
+
context_length = seq_length
|
398 |
+
attention_mask = np.ones((1, seq_length, seq_length))
|
399 |
+
attention_mask = np.tril(attention_mask)
|
400 |
+
attention_mask[:, :, :context_length] = 1
|
401 |
+
attention_mask = np.bool_(attention_mask < 0.5)
|
402 |
+
encoded_inputs["attention_mask"] = attention_mask
|
403 |
+
|
404 |
+
if "position_ids" not in encoded_inputs:
|
405 |
+
position_ids = np.arange(seq_length, dtype=np.int64)
|
406 |
+
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
|
407 |
+
if mask_token in required_input:
|
408 |
+
mask_position = required_input.index(mask_token)
|
409 |
+
position_ids[context_length:] = mask_position
|
410 |
+
block_position_ids = np.concatenate(
|
411 |
+
[np.zeros(context_length, dtype=np.int64),
|
412 |
+
np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
|
413 |
+
encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
|
414 |
+
|
415 |
+
if needs_to_be_padded:
|
416 |
+
difference = max_length - len(required_input)
|
417 |
+
|
418 |
+
if "attention_mask" in encoded_inputs:
|
419 |
+
encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
|
420 |
+
pad_width=[(0, 0), (difference, 0), (difference, 0)],
|
421 |
+
mode='constant', constant_values=True)
|
422 |
+
if "token_type_ids" in encoded_inputs:
|
423 |
+
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
424 |
+
"token_type_ids"
|
425 |
+
]
|
426 |
+
if "special_tokens_mask" in encoded_inputs:
|
427 |
+
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
428 |
+
if "position_ids" in encoded_inputs:
|
429 |
+
encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
|
430 |
+
pad_width=[(0, 0), (difference, 0)])
|
431 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
432 |
+
|
433 |
+
return encoded_inputs
|
tokenizer_config.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoTokenizer": [
|
4 |
+
"tokenization_chatglm.ChatGLMTokenizer",
|
5 |
+
null
|
6 |
+
]
|
7 |
+
},
|
8 |
+
"bos_token": "<sop>",
|
9 |
+
"do_lower_case": false,
|
10 |
+
"end_token": "</s>",
|
11 |
+
"eos_token": "<eop>",
|
12 |
+
"gmask_token": "[gMASK]",
|
13 |
+
"mask_token": "[MASK]",
|
14 |
+
"model_max_length": 1000000000000000019884624838656,
|
15 |
+
"num_image_tokens": 0,
|
16 |
+
"pad_token": "<pad>",
|
17 |
+
"padding_side": "left",
|
18 |
+
"processor_class": "Blip2Processor",
|
19 |
+
"remove_space": false,
|
20 |
+
"special_tokens_map_file": null,
|
21 |
+
"tokenizer_class": "ChatGLMTokenizer",
|
22 |
+
"unk_token": "<unk>"
|
23 |
+
}
|