Mizukiluke committed on
Commit
2b356a5
1 Parent(s): eda056e

Upload 15 files

README.md ADDED
@@ -0,0 +1,241 @@
---
license: apache-2.0
language:
- en
pipeline_tag: visual-question-answering
tags:
- chat
---

# mPLUG-Owl3

## Introduction
mPLUG-Owl3 is a state-of-the-art multi-modal large language model designed to tackle the challenges of long image sequence understanding. We propose Hyper Attention, which boosts the speed of long visual sequence understanding in multimodal large language models by sixfold, allowing for processing of visual sequences that are eight times longer. Meanwhile, we maintain excellent performance on single-image, multi-image, and video tasks.

Github: [mPLUG-Owl](https://github.com/X-PLUG/mPLUG-Owl)

## What's new
```mPLUG-Owl3-7B-241101``` is an improved version of ```mPLUG-Owl3-7B-240728```.

### Fused Hyper Attention
The previous mPLUG-Owl3 computed cross-attention and self-attention separately and fused the outputs of both through an adaptive gate. Now we use a unified operation that only requires computing attention once.

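The fused path is what `HyperQwen2SdpaAttention.hyperattention` in the bundled `modeling_hyper_qwen2.py` implements. As a rough, self-contained sketch of the idea (tensor names and shapes are illustrative; the real code additionally applies MI-Rope to the visual keys, repeats grouped-query KV heads, and handles the KV cache):

```Python
import torch
import torch.nn.functional as F

def fused_hyper_attention(q, text_k, text_v, vis_k, vis_v, t2v_mask, causal_mask):
    # q:           (b, h, q_len, d)      text queries
    # text_k/v:    (b, h, kv_len, d)     text keys/values
    # vis_k/v:     (b, h, v_len, d)      visual keys/values from the shared v_kv_proj
    # t2v_mask:    (b, 1, q_len, v_len)  True where a text token may attend to a visual token
    # causal_mask: (b, 1, q_len, kv_len) standard causal mask over the text tokens
    k = torch.cat([vis_k, text_k], dim=2)             # prepend visual keys to the text keys
    v = torch.cat([vis_v, text_v], dim=2)             # prepend visual values to the text values
    mask = torch.cat([t2v_mask, causal_mask], dim=3)  # one combined mask -> one attention call
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```
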
### New template for media inputs
We now use the following format to represent split high-resolution images. In addition, image splitting can now be enabled when the input contains multiple images, which brings a further performance gain; the old version of mPLUG-Owl3 was not trained to handle this combination.
```
<|start_cut|>2*3
<|image|> <|image|> <|image|>
<|image|> <|image|> <|image|>
<|image|><|end_cut|>
```
And we use the following format to represent a video.
```
<|start_video_frame|><|image|><|image|><|image|><|end_video_frame|>
```

### Adjusted media_offset
Previously, media_offset recorded the range of images each token could see. During training, because images from multiple samples are concatenated along the batch dimension, media_offset had to be carefully adjusted, otherwise it would point to the wrong image. To prevent this, media_offset is now a List[List[int]] that records, for each sample in the batch, the position of each of its images in the original sequence. This design also makes the computation of the cross-attention mask and MI-Rope more efficient and convenient.

**All of these changes are handled by the processor, and you don't need to change the original way of calling it.**

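For illustration only, one way to picture the new layout (the positions below are made up; in practice the processor builds media_offset for you):

```Python
# Hypothetical media_offset for a batch of two samples:
# sample 0 has images at token positions 3 and 50, sample 1 has one image at position 7.
media_offset = [
    [3, 50],
    [7],
]
```
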
### High performance on video and multi-image scenarios
| Model | NextQA | MVBench | VideoMME w/o sub | LongVideoBench-val | MLVU | LVBench |
|-|-|-|-|-|-|-|
| mPLUG-Owl3-7B-240728 | 78.6 | 54.5 | 53.5 | 52.1 | 63.7 | - |
| mPLUG-Owl3-7B-241101 | 82.3 | 59.5 | 59.3 | 59.7 | 70.0 | 43.5 |

| Model | NLVR2 | Mantis-Eval | MathVerse-mv | SciVerse-mv | BLINK | Q-Bench2 |
|-|-|-|-|-|-|-|
| mPLUG-Owl3-7B-240728 | 90.8 | 63.1 | 65.0 | 86.2 | 50.3 | 74.0 |
| mPLUG-Owl3-7B-241101 | 92.7 | 67.3 | 65.1 | 82.7 | 53.8 | 77.7 |

| Model | VQAv2 | OK-VQA | GQA | VizWizQA | TextVQA |
|-|-|-|-|-|-|
| mPLUG-Owl3-7B-240728 | 82.1 | 60.1 | 65.0 | 63.5 | 69.0 |
| mPLUG-Owl3-7B-241101 | 83.2 | 61.4 | 64.7 | 62.9 | 71.4 |

| Model | MMB-EN | MMB-CN | MM-Vet | POPE | AI2D |
|-|-|-|-|-|-|
| mPLUG-Owl3-7B-240728 | 77.6 | 74.3 | 40.1 | 88.2 | 73.8 |
| mPLUG-Owl3-7B-241101 | 80.4 | 79.1 | 39.8 | 88.1 | 77.8 |

## Quickstart

Load mPLUG-Owl3. We currently only support attn_implementation in ```['sdpa', 'flash_attention_2']```.
```Python
import torch
from modelscope import AutoConfig, AutoModel

model_path = 'iic/mPLUG-Owl3-2B-241101'
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
print(config)
model = AutoModel.from_pretrained(model_path, attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True)
_ = model.eval().cuda()
device = "cuda"
```

Chat with images.
```Python
from PIL import Image
from modelscope import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = model.init_processor(tokenizer)

image = Image.new('RGB', (500, 500), color='red')

messages = [
    {"role": "user", "content": """<|image|>
Describe this image."""},
    {"role": "assistant", "content": ""}
]

inputs = processor(messages, images=[image], videos=None)

inputs.to('cuda')
inputs.update({
    'tokenizer': tokenizer,
    'max_new_tokens': 100,
    'decode_text': True,
})

g = model.generate(**inputs)
print(g)
```
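
The same call pattern extends to multiple images. The sketch below is a straightforward extension of the example above (it assumes, as in the single-image case, that `<|image|>` placeholders are matched to the entries of `images` in order):

```Python
from PIL import Image

image_a = Image.new('RGB', (500, 500), color='red')
image_b = Image.new('RGB', (500, 500), color='blue')

messages = [
    {"role": "user", "content": """<|image|>
<|image|>
What is the difference between these two images?"""},
    {"role": "assistant", "content": ""}
]

inputs = processor(messages, images=[image_a, image_b], videos=None)
inputs.to('cuda')
inputs.update({
    'tokenizer': tokenizer,
    'max_new_tokens': 100,
    'decode_text': True,
})

g = model.generate(**inputs)
print(g)
```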

Chat with a video.
```Python
from PIL import Image
from modelscope import AutoTokenizer
from decord import VideoReader, cpu  # pip install decord

tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = model.init_processor(tokenizer)

messages = [
    {"role": "user", "content": """<|video|>
Describe this video."""},
    {"role": "assistant", "content": ""}
]

videos = ['/nas-mmu-data/examples/car_room.mp4']

MAX_NUM_FRAMES = 16

def encode_video(video_path):
    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    print('num frames:', len(frames))
    return frames

video_frames = [encode_video(_) for _ in videos]
inputs = processor(messages, images=None, videos=video_frames)

inputs.to(device)
inputs.update({
    'tokenizer': tokenizer,
    'max_new_tokens': 100,
    'decode_text': True,
})

g = model.generate(**inputs)
print(g)
```

### Save memory with Liger-Kernel
mPLUG-Owl3 is based on Qwen2, which can be optimized with Liger-Kernel to reduce memory usage.
```
pip install liger-kernel
```

```python
def apply_liger_kernel_to_mplug_owl3(
    rms_norm: bool = True,
    swiglu: bool = True,
    model = None,
) -> None:
    """
    Apply Liger kernels to replace the original implementations in the HuggingFace Qwen2
    language model inside mPLUG-Owl3.

    Args:
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model
            has already been loaded. Default is None.
    """
    from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module
    from liger_kernel.transformers.monkey_patch import _bind_method_to_module
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

    base_model = model.language_model.model

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for decoder_layer in base_model.layers:
        if swiglu:
            _bind_method_to_module(
                decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
            )
        if rms_norm:
            _patch_rms_norm_module(decoder_layer.input_layernorm)
            _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
    print("Applied Liger kernels to Qwen2 in mPLUG-Owl3")


import torch
from modelscope import AutoConfig, AutoModel

model_path = 'iic/mPLUG-Owl3-2B-241101'
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
print(config)
model = AutoModel.from_pretrained(model_path, attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True)
_ = model.eval().cuda()
device = "cuda"
apply_liger_kernel_to_mplug_owl3(model=model)
```

### Save memory by setting device_map
If you have more than one GPU, you can set ```device_map='auto'``` to split mPLUG-Owl3 across multiple GPUs. However, this will slow down inference.

```python
model = AutoModel.from_pretrained(model_path, attn_implementation='flash_attention_2', device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
_ = model.eval()
first_layer_name = list(model.hf_device_map.keys())[0]
device = model.hf_device_map[first_layer_name]
```

## Citation

If you find our work helpful, feel free to cite us.

```
@misc{ye2024mplugowl3longimagesequenceunderstanding,
      title={mPLUG-Owl3: Towards Long Image-Sequence Understanding in Multi-Modal Large Language Models},
      author={Jiabo Ye and Haiyang Xu and Haowei Liu and Anwen Hu and Ming Yan and Qi Qian and Ji Zhang and Fei Huang and Jingren Zhou},
      year={2024},
      eprint={2408.04840},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2408.04840},
}
```
config.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "architectures": [
3
+ "mPLUGOwl3Model"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mplugowl3.mPLUGOwl3Config",
7
+ "AutoModel": "modeling_mplugowl3.mPLUGOwl3Model",
8
+ "AutoModelForCausalLM": "modeling_mplugowl3.mPLUGOwl3Model"
9
+ },
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": 151643,
12
+ "eos_token_id": 151645,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 3584,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 18944,
17
+ "max_position_embeddings": 32768,
18
+ "max_window_layers": 28,
19
+ "model_type": "mplugowl3",
20
+ "num_attention_heads": 28,
21
+ "num_hidden_layers": 28,
22
+ "num_key_value_heads": 4,
23
+ "rms_norm_eps": 1e-06,
24
+ "rope_theta": 1000000.0,
25
+ "sliding_window": 131072,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "4.41.2",
29
+ "use_cache": true,
30
+ "use_sliding_window": false,
31
+ "vocab_size": 151651,
32
+ "hyper_layers": [
33
+ 7,
34
+ 15,
35
+ 23,
36
+ 26
37
+ ],
38
+ "vision_config": {
39
+ "hidden_size": 1152,
40
+ "image_size": 378,
41
+ "intermediate_size": 4304,
42
+ "model_type": "siglip_vision_model",
43
+ "num_attention_heads": 16,
44
+ "num_hidden_layers": 27,
45
+ "patch_size": 14
46
+ }
47
+ }
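
The settings above appear to mirror the Qwen2-7B text backbone, with two mPLUG-Owl3-specific additions: `hyper_layers` lists the decoder layers where Hyper Attention is enabled, and `vision_config` describes the SigLIP vision tower. A quick way to inspect them (the repo id below is a placeholder; substitute this model's actual id):

```python
from modelscope import AutoConfig  # transformers.AutoConfig works the same way

model_path = 'iic/mPLUG-Owl3-7B-241101'  # placeholder repo id
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
print(config.hyper_layers)              # [7, 15, 23, 26]
print(config.vision_config.image_size)  # 378
```
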
configuration.json ADDED
@@ -0,0 +1 @@
1
+ {"framework":"Pytorch","task":"image-text-to-text"}
configuration_hyper_qwen2.py ADDED
@@ -0,0 +1,128 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+
5
+
6
+ class HyperQwen2Config(PretrainedConfig):
7
+ r"""
8
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
9
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
10
+ with the defaults will yield a similar configuration to that of
11
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
12
+
13
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
14
+ documentation from [`PretrainedConfig`] for more information.
15
+
16
+
17
+ Args:
18
+ vocab_size (`int`, *optional*, defaults to 151936):
19
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
20
+ `inputs_ids` passed when calling [`Qwen2Model`]
21
+ hidden_size (`int`, *optional*, defaults to 4096):
22
+ Dimension of the hidden representations.
23
+ intermediate_size (`int`, *optional*, defaults to 22016):
24
+ Dimension of the MLP representations.
25
+ num_hidden_layers (`int`, *optional*, defaults to 32):
26
+ Number of hidden layers in the Transformer encoder.
27
+ num_attention_heads (`int`, *optional*, defaults to 32):
28
+ Number of attention heads for each attention layer in the Transformer encoder.
29
+ num_key_value_heads (`int`, *optional*, defaults to 32):
30
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
31
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
32
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
33
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
34
+ by meanpooling all the original heads within that group. For more details checkout [this
35
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
36
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
37
+ The non-linear activation function (function or string) in the decoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
39
+ The maximum sequence length that this model might ever be used with.
40
+ initializer_range (`float`, *optional*, defaults to 0.02):
41
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
42
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
43
+ The epsilon used by the rms normalization layers.
44
+ use_cache (`bool`, *optional*, defaults to `True`):
45
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
46
+ relevant if `config.is_decoder=True`.
47
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
48
+ Whether the model's input and output word embeddings should be tied.
49
+ rope_theta (`float`, *optional*, defaults to 10000.0):
50
+ The base period of the RoPE embeddings.
51
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
52
+ Whether to use sliding window attention.
53
+ sliding_window (`int`, *optional*, defaults to 4096):
54
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
55
+ max_window_layers (`int`, *optional*, defaults to 28):
56
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
57
+ attention_dropout (`float`, *optional*, defaults to 0.0):
58
+ The dropout ratio for the attention probabilities.
59
+
60
+ ```python
61
+ >>> from transformers import Qwen2Model, Qwen2Config
62
+
63
+ >>> # Initializing a Qwen2 style configuration
64
+ >>> configuration = Qwen2Config()
65
+
66
+ >>> # Initializing a model from the Qwen2-7B style configuration
67
+ >>> model = Qwen2Model(configuration)
68
+
69
+ >>> # Accessing the model configuration
70
+ >>> configuration = model.config
71
+ ```"""
72
+
73
+ model_type = "qwen2"
74
+ keys_to_ignore_at_inference = ["past_key_values"]
75
+
76
+ def __init__(
77
+ self,
78
+ vocab_size=151936,
79
+ hidden_size=4096,
80
+ intermediate_size=22016,
81
+ num_hidden_layers=32,
82
+ num_attention_heads=32,
83
+ num_key_value_heads=32,
84
+ hidden_act="silu",
85
+ max_position_embeddings=32768,
86
+ initializer_range=0.02,
87
+ rms_norm_eps=1e-6,
88
+ use_cache=True,
89
+ tie_word_embeddings=False,
90
+ rope_theta=10000.0,
91
+ use_sliding_window=False,
92
+ sliding_window=4096,
93
+ max_window_layers=28,
94
+ attention_dropout=0.0,
95
+ hyper_layers=[1,9,17,25],
96
+ vision_batch_size=16,
97
+ rope_scaling=None,
98
+ **kwargs,
99
+ ):
100
+ self.vocab_size = vocab_size
101
+ self.max_position_embeddings = max_position_embeddings
102
+ self.hidden_size = hidden_size
103
+ self.intermediate_size = intermediate_size
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.num_attention_heads = num_attention_heads
106
+ self.use_sliding_window = use_sliding_window
107
+ self.sliding_window = sliding_window if use_sliding_window else None
108
+ self.max_window_layers = max_window_layers
109
+ self.rope_scaling = rope_scaling
110
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
111
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
112
+ # for backward compatibility
113
+ if num_key_value_heads is None:
114
+ num_key_value_heads = num_attention_heads
115
+
116
+ self.num_key_value_heads = num_key_value_heads
117
+ self.hidden_act = hidden_act
118
+ self.initializer_range = initializer_range
119
+ self.rms_norm_eps = rms_norm_eps
120
+ self.use_cache = use_cache
121
+ self.rope_theta = rope_theta
122
+ self.attention_dropout = attention_dropout
123
+ self.hyper_layers = hyper_layers
124
+ self.vision_batch_size = vision_batch_size
125
+ super().__init__(
126
+ tie_word_embeddings=tie_word_embeddings,
127
+ **kwargs,
128
+ )
configuration_mplugowl3.py ADDED
@@ -0,0 +1,47 @@
1
+ # coding=utf-8
2
+ """ mPLUGOwl3 model configuration"""
3
+
4
+ import os
5
+ from typing import Union
6
+
7
+ from transformers.utils import logging
8
+ from .configuration_hyper_qwen2 import HyperQwen2Config
9
+ from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ class mPLUGOwl3Config(HyperQwen2Config):
14
+ model_type = "mplugowl3"
15
+ keys_to_ignore_at_inference = ["past_key_values"]
16
+
17
+ default_vision_config = {
18
+ "hidden_size": 1152,
19
+ "image_size": 378,
20
+ "intermediate_size": 4304,
21
+ "model_type": "siglip_vision_model",
22
+ "num_attention_heads": 16,
23
+ "num_hidden_layers": 27,
24
+ "patch_size": 14
25
+ }
26
+
27
+
28
+ def __init__(
29
+ self,
30
+ use_cache=True,
31
+ vision_config=None,
32
+ **kwargs,
33
+ ):
34
+ self.use_cache = use_cache
35
+
36
+ # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
37
+ if vision_config is None:
38
+ self.vision_config = SiglipVisionConfig(**self.default_vision_config)
39
+ logger.info("vision_config is None, using default vision config")
40
+ elif isinstance(vision_config, dict):
41
+ self.vision_config = SiglipVisionConfig(**vision_config)
42
+ elif isinstance(vision_config, SiglipVisionConfig):
43
+ self.vision_config = vision_config
44
+ self.image_size = 378
45
+ self.patch_size = self.vision_config.patch_size
46
+
47
+ super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "pad_token_id": 151643,
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 151645,
7
+ 151643
8
+ ],
9
+ "repetition_penalty": 1.05,
10
+ "temperature": 0.7,
11
+ "top_p": 0.8,
12
+ "top_k": 20,
13
+ "transformers_version": "4.37.0"
14
+ }
image_processing_mplugowl3.py ADDED
@@ -0,0 +1,415 @@
1
+ import random
2
+ from typing import Optional, Union, Dict, Any, List
3
+
4
+ from einops import rearrange, repeat
5
+ import torch
6
+ import math
7
+ import PIL.Image
8
+ import PIL.ImageSequence
9
+ import numpy as np
10
+ import PIL
11
+ from PIL import Image
12
+
13
+ from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
14
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
15
+ from transformers import AutoImageProcessor
16
+ from transformers.image_transforms import to_channel_dimension_format
17
+ from transformers.image_utils import (
18
+ ImageInput,
19
+ make_list_of_images,
20
+ valid_images,
21
+ is_torch_tensor,
22
+ is_batched,
23
+ to_numpy_array,
24
+ infer_channel_dimension_format,
25
+ ChannelDimension
26
+ )
27
+ from torchvision.ops.boxes import box_area
28
+ from torchvision.transforms import functional as F
29
+ from torchvision.transforms.transforms import InterpolationMode
30
+ from torchvision import transforms
31
+
32
+ def recursive_converter(converter, value):
33
+ if isinstance(value, list):
34
+ new_value = []
35
+ for v in value:
36
+ new_value += [recursive_converter(converter, v)]
37
+ return new_value
38
+ else:
39
+ return converter(value)
40
+
41
+ def box_iou(boxes1, area1, boxes2, eps=1e-5):
42
+ area2 = box_area(boxes2)
43
+
44
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
45
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
46
+
47
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
48
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
49
+
50
+ union = area1[:, None] + area2 - inter
51
+
52
+ iou = inter / (union+eps)
53
+ return iou, union
54
+
55
+ available_anchor_strategy = ['docowl', 'random', 'highest', 'last', 'llava']
56
+
57
+ grid_dict = {
58
+ 'grid_33':[
59
+ (1,1),
60
+ (1,2),(2,1),
61
+ (1,3),(3,1),
62
+ (2,2),(1,4),(4,1),
63
+ (1,5),(5,1),
64
+ (1,6),(6,1),(2,3),(3,2),
65
+ (1,7),(7,1),
66
+ (4,2),(2,4),(1,8),(8,1),
67
+ (3,3),(1,9),(9,1)],
68
+ 'grid_squ_3x3':[
69
+ (1,1),(2,2),(3,3)
70
+ ],
71
+ 'grid_squ_4':[
72
+ (2,2),(1,3),(1,4),(3,1),(4,1)
73
+ ],
74
+ 'grid_squ_6':[
75
+ (2,2),(1,3),(1,4),(3,1),(4,1), (2,3),(3,2)
76
+ ],
77
+ 'grid_squ_2':[
78
+ (2,1)
79
+ ],
80
+ 'grid_squ_9':[
81
+ (1,1),
82
+ (1,2),(2,1),
83
+ (1,3),(3,1),
84
+ (2,2),(1,4),(4,1),
85
+ (1,5),(5,1),
86
+ (1,6),(6,1),(2,3),(3,2),
87
+ (1,7),(7,1),
88
+ (4,2),(2,4),(1,8),(8,1),
89
+ (3,3),(1,9),(9,1)],
90
+ }
91
+
92
+ cut_prompt_template_dict = {
93
+ 'v0': lambda img_token, h, w: f''.join([f"{img_token}" for i in range(h) for j in range(w)]),
94
+ 'v1': lambda img_token, h, w: f'Cut to {h} rows {w} columns, '+ ' '.join([f"subimg({i},{j}){img_token}"for i in range(h) for j in range(w)]),
95
+ 'v1_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view, '+ ' '.join([f"subimg({i},{j}){img_token}"for i in range(h) for j in range(w)]+[f"global_view{img_token}"]),
96
+ 'v2_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view\n'+ '\n'.join([' '.join([f"subimg({i},{j}){img_token}" for j in range(w)]) for i in range(h)])+f"\nglobal_view{img_token}",
97
+ 'v3': lambda img_token, h, w: f'<|start_cut|>{h}*{w}'+ ' '.join([f"{img_token}"for i in range(h) for j in range(w)])+'<|end_cut|>',
98
+ 'v3_global': lambda img_token, h, w: f'<|start_cut|>{h}*{w}\n'+ '\n'.join([' '.join([f"{img_token}" for j in range(w)]) for i in range(h)])+f'\n{img_token}<|end_cut|>',
99
+
100
+ }
101
+
102
+ def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
103
+ # anchors x1 y1 x2 y2
104
+
105
+ # image_size: (h, w)
106
+ # xyxy
107
+ input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)
108
+
109
+ boxes1 = anchors
110
+ boxes2 = input_image_bbox
111
+ boxes3 = anchors.clone()
112
+ # y2
113
+ boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2] # used to compute a resolution-independent IoU
114
+
115
+ area1 = anchors_areas
116
+
117
+ iou, _ = box_iou(boxes1, area1, boxes2)
118
+ iou = iou.squeeze(1)
119
+ shape_iou, _ = box_iou(boxes1, area1, boxes3)
120
+ shape_iou = shape_iou.diag()
121
+ # prefer anchors with the closest aspect ratio first, then the closest resolution
122
+ index = torch.argmax(shape_iou*100+iou,dim=0)
123
+ return index
124
+
125
+ def select_best_resolution(anchors, anchors_areas, input_image_size): # TODO For a futher check
126
+ """
127
+ Selects the best resolution from a list of possible resolutions based on the original size.
128
+
129
+ Args:
130
+ original_size (tuple): The original size of the image in the format (width, height).
131
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
132
+
133
+ Returns:
134
+ tuple: The best fit resolution in the format (width, height).
135
+ """
136
+ original_size = (input_image_size[1], input_image_size[0])
137
+ possible_resolutions = [(_[2], _[3]) for _ in anchors] # xyxy -> w,h
138
+
139
+ original_width, original_height = original_size
140
+ best_fit = None
141
+ max_effective_resolution = 0
142
+ min_wasted_resolution = float('inf')
143
+
144
+ index = 0
145
+ for i, (width, height) in enumerate(possible_resolutions):
146
+ scale = min(width / original_width, height / original_height)
147
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
148
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
149
+ wasted_resolution = (width * height) - effective_resolution
150
+
151
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
152
+ max_effective_resolution = effective_resolution
153
+ min_wasted_resolution = wasted_resolution
154
+ best_fit = (width, height)
155
+ index = i
156
+
157
+ return index
158
+
159
+ def build_cut_shape_indices(cut_shape):
160
+ # cut_shape: a list of (nh,nw)
161
+ cut_shape_indices = []
162
+ for shape in cut_shape:
163
+ n=shape[0]*shape[1]
164
+ indices = torch.cat([
165
+ repeat(torch.tensor(shape),'l -> n l',n=n),
166
+ torch.arange(n).unsqueeze(1)
167
+ ], dim=1)
168
+ assert indices.shape[0] == n
169
+ assert indices.shape[1] == 3 # nh,nw,idx
170
+
171
+ cut_shape_indices.append(indices)
172
+ cut_shape_indices = torch.cat(cut_shape_indices,dim=0).long()
173
+ return cut_shape_indices
174
+
175
+ class AnchorResize(torch.nn.Module):
176
+
177
+ def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None, anchor_strategy='docowl'):
178
+ super().__init__()
179
+ self.image_size = image_size
180
+ # xyxy
181
+ self.anchors = torch.tensor(
182
+ [[0, 0, _[1]*image_size[1], _[0]*image_size[0]]
183
+ for _ in anchors], requires_grad=False
184
+ )
185
+
186
+ self.anchor_areas = box_area(self.anchors)
187
+
188
+ self.interpolation = interpolation
189
+ self.antialias = antialias
190
+ self.anchor_strategy = anchor_strategy
191
+ assert self.anchor_strategy in available_anchor_strategy
192
+
193
+ def resize_global(self, img):
194
+ return F.resize(img, self.image_size, self.interpolation, max_size=None, antialias=self.antialias)
195
+
196
+ def forward(self, img, skip_resize=False):
197
+ """
198
+ Args:
199
+ img (PIL Image or Tensor): Image to be scaled.
200
+
201
+ Returns:
202
+ PIL Image or Tensor: Rescaled image.
203
+ """
204
+ if self.anchor_strategy == 'docowl':
205
+ selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
206
+ elif self.anchor_strategy == 'random':
207
+ selected_anchor = random.randint(0,len(self.anchors)-1)
208
+ elif self.anchor_strategy == 'highest':
209
+ # pick the anchor with the largest area, preferring the squarest shape among them
210
+ selected_anchor = torch.argmax(self.anchors[:,2]*self.anchors[:,3]*100-torch.abs(self.anchors[:,2]-self.anchors[:,3]))
211
+ elif self.anchor_strategy == 'last':
212
+ selected_anchor = len(self.anchors)-1
213
+ elif self.anchor_strategy == 'llava':
214
+ selected_anchor = select_best_resolution(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
215
+ else:
216
+ selected_anchor = None
217
+ assert selected_anchor is not None
218
+
219
+ target_size = self.anchors[selected_anchor][2:].tolist() # w,h
220
+ if skip_resize:
221
+ # for debug
222
+ return selected_anchor
223
+ return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor
224
+
225
+ def __repr__(self) -> str:
226
+ detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
227
+ return f"{self.__class__.__name__}{detail}"
228
+
229
+ class CutMixin:
230
+ def __init__(self, cut_cfg={"anchors": "grid_squ_6", "anchor_strategy": "docowl", "cut_prompt": "v3", "add_global": True, "cut_prob": 1.0}) -> None:
231
+ if cut_cfg is None:
232
+ self.cut_enable = False
233
+ return
234
+ else:
235
+ self.cut_enable = True
236
+ image_size = self.image_size
237
+ anchors = cut_cfg.get('anchors','grid_33')
238
+ anchor_strategy = cut_cfg.get('anchor_strategy','docowl')
239
+ cut_prompt = cut_cfg.get('cut_prompt','v0')
240
+ self.cut_prob = cut_cfg.get('cut_prob', 1.0)
241
+
242
+ self.force_shape_cut = cut_cfg.get('force_shape_cut', False)
243
+ force_shape_cut_anchors = cut_cfg.get('force_shape_cut_anchors', 'force_shape_cut_anchors')
244
+
245
+
246
+ self.add_global = cut_cfg.get('add_global', False)
247
+
248
+ # h,w
249
+ if isinstance(image_size, int):
250
+ image_size = (image_size, image_size)
251
+ self.image_size = image_size
252
+
253
+ if anchors in grid_dict:
254
+ anchors = grid_dict[anchors]
255
+ else:
256
+ anchors = eval(anchors)
257
+ self.anchors = [tuple(_) for _ in anchors]
258
+ self.anchor_max = max([max(_) for _ in self.anchors])
259
+ self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC, anchor_strategy=anchor_strategy)
260
+
261
+ if force_shape_cut_anchors in grid_dict:
262
+ force_shape_cut_anchors = grid_dict[force_shape_cut_anchors]
263
+ else:
264
+ force_shape_cut_anchors = eval(force_shape_cut_anchors)
265
+ self.force_shape_cut_anchors = [tuple(_) for _ in force_shape_cut_anchors]
266
+ self.force_shape_cut_anchors_max = max([max(_) for _ in self.force_shape_cut_anchors])
267
+
268
+
269
+
270
+ self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC)
271
+
272
+ # drop the image processor's resize step and keep only the subsequent transforms
273
+ self.image_transform = transforms.Compose(self.image_transform.transforms[1:])
274
+ if self.add_global:
275
+ self.cut_prompt_template = cut_prompt_template_dict[cut_prompt+'_global']
276
+ else:
277
+ self.cut_prompt_template = cut_prompt_template_dict[cut_prompt]
278
+
279
+ self.media_tokens = ["<|image|>", "<|video|>"]
280
+
281
+
282
+
283
+ def _process_image(self, images):
284
+ new_images = []
285
+ cut_shape = []
286
+ for image in images:
287
+ raw_image = image
288
+ image, selected_anchor = self.resizer(image)
289
+ image_input = self.image_transform(image) # h,w,3 -> 3,h,w
290
+ cut_shape.append((image_input.shape[1]//self.image_size[0], image_input.shape[2]//self.image_size[1])) # cut_h, cut_w
291
+ image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])
292
+
293
+ new_images.append(image_input)
294
+
295
+ if self.add_global:
296
+ new_images.append(self.image_transform(self.resizer.resize_global(raw_image)).unsqueeze(0))
297
+ cut_shape.append((1,1))
298
+
299
+ new_images = torch.cat(new_images,dim=0)
300
+ cut_shape_indices = build_cut_shape_indices(cut_shape)
301
+ return new_images, cut_shape, cut_shape_indices
302
+
303
+ class mPLUGOwl3BatchFeature(BatchFeature):
304
+ r"""
305
+ Extend from BatchFeature for supporting various image size
306
+ """
307
+ def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
308
+ super().__init__(data)
309
+ self.convert_to_tensors(tensor_type=tensor_type)
310
+
311
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
312
+ if tensor_type is None:
313
+ return self
314
+
315
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
316
+
317
+ def converter(value):
318
+ try:
319
+ if not is_tensor(value):
320
+ tensor = as_tensor(value)
321
+ return tensor
322
+ except: # noqa E722
323
+ if key == "overflowing_values":
324
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
325
+ raise ValueError(
326
+ "Unable to create tensor, you should probably activate padding "
327
+ "with 'padding=True' to have batched tensors with the same length."
328
+ )
329
+
330
+
331
+ for key, value in self.items():
332
+ self[key] = recursive_converter(converter, value)
333
+ return self
334
+
335
+ def to(self, *args, **kwargs) -> "mPLUGOwl3BatchFeature":
336
+ requires_backends(self, ["torch"])
337
+ import torch
338
+
339
+ def cast_tensor(v):
340
+ # check if v is a floating point
341
+ if torch.is_floating_point(v):
342
+ # cast and send to device
343
+ return v.to(*args, **kwargs)
344
+ elif device is not None:
345
+ return v.to(device=device)
346
+ else:
347
+ return v
348
+
349
+ new_data = {}
350
+ device = kwargs.get("device")
351
+ # Check if the args are a device or a dtype
352
+ if device is None and len(args) > 0:
353
+ # device should be always the first argument
354
+ arg = args[0]
355
+ if is_torch_dtype(arg):
356
+ # The first argument is a dtype
357
+ pass
358
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
359
+ device = arg
360
+ else:
361
+ # it's something else
362
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
363
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
364
+ for k, v in self.items():
365
+ new_data[k] = recursive_converter(cast_tensor, v)
366
+ self.data = new_data
367
+ return self
368
+
369
+
370
+ class mPLUGOwl3ImageProcessor(BaseImageProcessor, CutMixin):
371
+ model_input_names = ["pixel_values"]
372
+
373
+ def __init__(
374
+ self,
375
+ image_size,
376
+ mean=[0.5, 0.5, 0.5],
377
+ std=[0.5, 0.5, 0.5],
378
+ **kwargs):
379
+ super().__init__(**kwargs)
380
+ self.image_size = image_size
381
+ self.image_transform = transforms.Compose([
382
+ transforms.Resize((image_size, image_size), interpolation=Image.BICUBIC),
383
+ transforms.ToTensor(),
384
+ transforms.Normalize(mean, std),
385
+ ])
386
+ CutMixin.__init__(self)
387
+
388
+ def preprocess(
389
+ self,
390
+ images: Union[Image.Image, List[Image.Image]],
391
+ cut_enable=True,
392
+ **kwargs
393
+ ) -> mPLUGOwl3BatchFeature:
394
+ if isinstance(images, Image.Image):
395
+ images_list = [images]
396
+ else:
397
+ images_list = images
398
+
399
+ if self.cut_enable and cut_enable:
400
+ image_data, cut_shape, cut_shape_indices = self._process_image(images_list)
401
+ else:
402
+ image_data = [self.image_transform(self.resizer.resize_global(image)) for image in images_list]
403
+ image_data = torch.stack(image_data, dim=0)
404
+ cut_shape = cut_shape_indices = None
405
+
406
+ return mPLUGOwl3BatchFeature(data={'pixel_values': image_data, 'cut_shape':cut_shape, 'cut_shape_indices':cut_shape_indices})
407
+
408
+ def to_dict(self):
409
+ encoder_dict = super().to_dict()
410
+ pop_keys = ['image_transform', 'resizer', 'old_resizer', 'cut_prompt_template']
411
+ for pk in pop_keys:
412
+ encoder_dict.pop(pk, None)
413
+ return encoder_dict
414
+
415
+ AutoImageProcessor.register("mPLUGOwl3ImageProcessor", mPLUGOwl3ImageProcessor)
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edfc18006dff25e33c6eca595d2f13cb79694bcf71cc099e8292d6fbee1814f7
3
+ size 16145196264
modeling_hyper_qwen2.py ADDED
@@ -0,0 +1,1321 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Qwen2 model."""
21
+ import inspect
22
+ import math
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ from einops import rearrange, repeat
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2FlashAttention2, Qwen2SdpaAttention
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ is_flash_attn_greater_or_equal_2_10,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from .configuration_hyper_qwen2 import HyperQwen2Config
46
+
47
+
48
+
49
+
50
+ try:
51
+ from flash_attn.layers.rotary import apply_rotary_emb_func
52
+ from einops import rearrange
53
+
54
+ use_flash_rotary = True
55
+ print("use flash_attn rotary")
56
+ except ImportError:
57
+ use_flash_rotary = False
58
+ print("import flash_attn rotary fail")
59
+
60
+
61
+ try:
62
+ from torch.nn.attention.flex_attention import create_block_mask
63
+ from torch.nn.attention.flex_attention import flex_attention
64
+ flex_attention = torch.compile(flex_attention, dynamic=False)
65
+ except ImportError:
66
+ pass
67
+
68
+ logger = logging.get_logger(__name__)
69
+
70
+
71
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
72
+ _CONFIG_FOR_DOC = "HyperQwen2Config"
73
+
74
+
75
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
76
+ def _get_unpad_data(attention_mask):
77
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
78
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
79
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
80
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
81
+ return (
82
+ indices,
83
+ cu_seqlens,
84
+ max_seqlen_in_batch,
85
+ )
86
+
87
+
88
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
89
+ class Qwen2RMSNorm(nn.Module):
90
+ def __init__(self, hidden_size, eps=1e-6):
91
+ """
92
+ Qwen2RMSNorm is equivalent to T5LayerNorm
93
+ """
94
+ super().__init__()
95
+ self.weight = nn.Parameter(torch.ones(hidden_size))
96
+ self.variance_epsilon = eps
97
+
98
+ def forward(self, hidden_states):
99
+ input_dtype = hidden_states.dtype
100
+ hidden_states = hidden_states.to(torch.float32)
101
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
102
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
103
+ return self.weight * hidden_states.to(input_dtype)
104
+
105
+
106
+ # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
107
+ class Qwen2RotaryEmbedding(nn.Module):
108
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
109
+ super().__init__()
110
+
111
+ self.dim = dim
112
+ self.max_position_embeddings = max_position_embeddings
113
+ self.base = base
114
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
115
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
116
+
117
+ # Build here to make `torch.jit.trace` work.
118
+ self._set_cos_sin_cache(
119
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
120
+ )
121
+
122
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
123
+ self.max_seq_len_cached = seq_len
124
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
125
+
126
+ freqs = torch.outer(t, self.inv_freq)
127
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
128
+ emb = torch.cat((freqs, freqs), dim=-1)
129
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
130
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
131
+
132
+ def forward(self, x, seq_len=None):
133
+ # x: [bs, num_attention_heads, seq_len, head_size]
134
+ if seq_len > self.max_seq_len_cached:
135
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
136
+
137
+ return (
138
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
139
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
140
+ )
141
+
142
+ class RotaryEmbedding(torch.nn.Module):
143
+ def __init__(self, dim, base=10000, use_fp32=False, use_outer_in_rope=False):
144
+ super().__init__()
145
+ self.dim = dim
146
+ self.base = base
147
+ self.use_fp32 = use_fp32
148
+ if use_fp32:
149
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
150
+ else:
151
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
152
+ self.register_buffer("inv_freq", inv_freq)
153
+
154
+ self._rotary_pos_emb_cache = None
155
+ self._seq_len_cached = 0
156
+ self.use_outer_in_rope = use_outer_in_rope
157
+ self._ntk_alpha_cached = 1.0
158
+
159
+ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
160
+ seqlen = max_seq_len + offset
161
+ if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
162
+ base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
163
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim))
164
+ self._seq_len_cached = seqlen
165
+ self._ntk_alpha_cached = ntk_alpha
166
+ seq = torch.arange(seqlen, device=self.inv_freq.device)
167
+ # Don't do einsum, it converts fp32 to fp16 # TODO: CHECK this
168
+ if self.use_outer_in_rope:
169
+ freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
170
+ else:
171
+ freqs = torch.einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
172
+ # first part even vector components, second part odd vector components,
173
+ # 2 * dim in dimension size
174
+ emb = torch.cat((freqs, freqs), dim=-1)
175
+ # emb [seq_length, .., dim]
176
+ from einops import rearrange
177
+ self._rotary_pos_emb_cache = rearrange(emb, 'n d -> n 1 1 d')
178
+
179
+ def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
180
+ self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
181
+ return self._rotary_pos_emb_cache[offset:offset + max_seq_len]
182
+
183
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
184
+ def rotate_half(x):
185
+ """Rotates half the hidden dims of the input."""
186
+ x1 = x[..., : x.shape[-1] // 2]
187
+ x2 = x[..., x.shape[-1] // 2 :]
188
+ return torch.cat((-x2, x1), dim=-1)
189
+
190
+
191
+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
192
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
193
+ """Applies Rotary Position Embedding to the query and key tensors.
194
+
195
+ Args:
196
+ q (`torch.Tensor`): The query tensor.
197
+ k (`torch.Tensor`): The key tensor.
198
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
199
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
200
+ position_ids (`torch.Tensor`):
201
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
202
+ used to pass offsetted position ids when working with a KV-cache.
203
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
204
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
205
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
206
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
207
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
208
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
209
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
210
+ Returns:
211
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
212
+ """
213
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
214
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
215
+ q_embed = (q * cos) + (rotate_half(q) * sin)
216
+ k_embed = (k * cos) + (rotate_half(k) * sin)
217
+ return q_embed, k_embed
218
+
219
+
220
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
221
+ class Qwen2MLP(nn.Module):
222
+ def __init__(self, config):
223
+ super().__init__()
224
+ self.config = config
225
+ self.hidden_size = config.hidden_size
226
+ self.intermediate_size = config.intermediate_size
227
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
228
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
229
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
230
+ self.act_fn = ACT2FN[config.hidden_act]
231
+
232
+ def forward(self, x):
233
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
234
+
235
+
236
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
237
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
238
+ """
239
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
240
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
241
+ """
242
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
243
+ if n_rep == 1:
244
+ return hidden_states
245
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
246
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
247
+
248
+
249
+
250
+
251
+ def _rotate_half(x):
252
+ """
253
+ change sign so the last dimension becomes [-odd, +even]
254
+ """
255
+ from einops import rearrange
256
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
257
+ x1, x2 = x.unbind(dim=-2)
258
+ return torch.cat((-x2, x1), dim=-1)
259
+
260
+ def apply_rotary_pos_emb_core(t, freqs, use_fp32=False, debug=False):
261
+ """
262
+ input tensor t is of shape [seq_length, ..., dim]
263
+ rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
264
+ check https://kexue.fm/archives/8265 for detailed formulas
265
+ """
266
+
267
+ if use_flash_rotary and use_fp32:
268
+ t_ = rearrange(t, 's b ... -> b s ...').contiguous()
269
+ if use_fp32:
270
+ t_ = t_.float()
271
+ freqs = freqs.squeeze(1).squeeze(1)
272
+ cos = freqs[:, :freqs.shape[-1] // 2].cos()
273
+ sin = freqs[:, :freqs.shape[-1] // 2].sin()
274
+ output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
275
+ return rearrange(output, 'b s ... -> s b ...')
276
+
277
+ rot_dim = freqs.shape[-1]
278
+ # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
279
+ t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
280
+
281
+ if use_fp32:
282
+ t_ = t_.float()
283
+ t_pass_ = t_pass_.float()
284
+ # first part is cosine component
285
+ # second part is sine component, need to change signs with _rotate_half method
286
+ t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
287
+ return torch.cat((t_, t_pass_), dim=-1).type_as(t)
288
+
289
+ class HyperQwen2Attention(nn.Module):
290
+ """
291
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
292
+ and "Generating Long Sequences with Sparse Transformers".
293
+ """
294
+
295
+ def __init__(self, config: HyperQwen2Config, layer_idx: Optional[int] = None, is_hyper_enabled=False):
296
+ super().__init__()
297
+ self.config = config
298
+ self.layer_idx = layer_idx
299
+ if layer_idx is None:
300
+ logger.warning_once(
301
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
302
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
303
+ "when creating this class."
304
+ )
305
+
306
+ self.hidden_size = config.hidden_size
307
+ self.num_heads = config.num_attention_heads
308
+ self.head_dim = self.hidden_size // self.num_heads
309
+ self.num_key_value_heads = config.num_key_value_heads
310
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
311
+ self.max_position_embeddings = config.max_position_embeddings
312
+ self.rope_theta = config.rope_theta
313
+ self.is_causal = True
314
+ self.attention_dropout = config.attention_dropout
315
+
316
+ if (self.head_dim * self.num_heads) != self.hidden_size:
317
+ raise ValueError(
318
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
319
+ f" and `num_heads`: {self.num_heads})."
320
+ )
321
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
322
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
323
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
324
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
325
+
326
+ self.rotary_emb = Qwen2RotaryEmbedding(
327
+ self.head_dim,
328
+ max_position_embeddings=self.max_position_embeddings,
329
+ base=self.rope_theta,
330
+ )
331
+ self.rotary_emb_core = RotaryEmbedding(
332
+ self.head_dim, base=self.rope_theta, use_fp32=True, use_outer_in_rope=True
333
+ )
334
+ # Hyper Attention Modules
335
+ self.is_hyper_enabled = is_hyper_enabled
336
+ if self.is_hyper_enabled:
337
+ self.v_kv_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim * 2, bias=True)
338
+
339
+ self.visual_cache={}
340
+
341
+ self.use_flexattention = True
342
+
343
+
344
+
345
+ def apply_mi_rope(self, key_layer, image_pos, length_each_img):
346
+ # input shape should be [s b h d]
347
+ key_layer = rearrange(key_layer, 'b h s d -> s b h d')
348
+ if self.rotary_emb_core.inv_freq.device!=key_layer.device:
349
+ self.rotary_emb_core.inv_freq = self.rotary_emb_core.inv_freq.to(key_layer.device)
350
+ rotary_pos_emb_max_seq_len = self.config.max_position_embeddings
351
+ ntk_alpha = 1
352
+ rotary_pos_emb = self.rotary_emb_core(rotary_pos_emb_max_seq_len, ntk_alpha=ntk_alpha)
353
+ assert rotary_pos_emb is not None
354
+
355
+ if isinstance(rotary_pos_emb, tuple):
356
+ rotary_pos_emb = rotary_pos_emb
357
+ else:
358
+ rotary_pos_emb = ((rotary_pos_emb,) * 2)
359
+
360
+
361
+ if rotary_pos_emb is not None:
362
+ q_pos_emb, k_pos_emb = rotary_pos_emb
363
+ # ic(key_layer.shape, k_pos_emb.shape)
364
+
365
+
366
+ k_pos_emb = repeat(k_pos_emb[image_pos], 'N_img b h d -> (N_img L) b h d', L=length_each_img) # N_img, dim
367
+
368
+ key_layer = apply_rotary_pos_emb_core(key_layer, k_pos_emb, use_fp32=True) # TODO difference
369
+ key_layer = rearrange(key_layer, 's b h d -> b h s d')
370
+ return key_layer
371
+
372
+
373
+ # def hyper_mask_always_true(b, h, q_idx, kv_idx):
374
+ # return q_idx>=0
375
+
376
+ # def causal(b, h, q_idx, kv_idx):
377
+ # return q_idx >= kv_idx
378
+
379
+
380
+ # def create_hyper_attention(media_starts_extend, q_len, kv_len, each_visual_len):
381
+
382
+ # visual_len = kv_len - q_len
383
+ # def hyper_mask_dynamic(b, h, q_idx, kv_idx):
384
+ # return torch.where(kv_idx<visual_len, q_idx>=media_starts_extend[kv_idx], causal(b, h, q_idx, kv_idx-visual_len))
385
+
386
+ # return create_block_mask(hyper_mask_dynamic, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, BLOCK_SIZE=128, _compile=True)
387
+
388
+ # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
389
+ class HyperQwen2SdpaAttention(HyperQwen2Attention):
390
+ """
391
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
392
+ `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
393
+ SDPA API.
394
+ """
395
+
396
+ def hyperattention(self,hidden_states: torch.Tensor,
397
+ attention_mask: Optional[torch.Tensor] = None,
398
+ position_ids: Optional[torch.LongTensor] = None,
399
+ image_embeds=None,
400
+ media_offset=None,
401
+ past_key_value: Optional[Cache] = None,
402
+ output_attentions: bool = False,
403
+ use_cache: bool = False,
404
+ )-> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
405
+ bsz, q_len, _ = hidden_states.size()
406
+
407
+ query_states = self.q_proj(hidden_states)
408
+ key_states = self.k_proj(hidden_states)
409
+ value_states = self.v_proj(hidden_states)
410
+
411
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
412
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
413
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
414
+
415
+ kv_seq_len = key_states.shape[-2]
416
+ if past_key_value is not None:
417
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
418
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
419
+
420
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
421
+
422
+ if past_key_value is not None:
423
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
424
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
425
+
426
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
427
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
428
+
429
+ # add visual to kv
430
+ length_each_img = image_embeds.shape[1]
431
+ image_embeds = self.v_kv_proj(image_embeds)
432
+ image_start = 0
433
+ context_layer = []
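+ # Per-sample fused attention: project this sample's image embeddings to K/V, rotate the keys with
+ # MI-Rope, prepend them to the text K/V, and run a single SDPA call whose mask combines the
+ # text-to-visual visibility constraint with the usual causal mask over text tokens.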
434
+ for bi, media_starts in enumerate(media_offset):
435
+ num_images = media_starts.shape[0]
436
+ if num_images > 0:
437
+ if q_len == 1:
438
+ full_mask = torch.ones((1,1,1, num_images*length_each_img + kv_seq_len)).bool().to(query_states.device)
439
+ else:
440
+ causal_mask = torch.tril(torch.ones(q_len, kv_seq_len, dtype=torch.bool, device=query_states.device)).bool()
441
+ # Expand dims to match (bsz, 1, q_len, kv_seq_len)
442
+ causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
443
+
444
+ matrix = torch.arange(q_len, device=media_offset[0].device).reshape(-1,1)
445
+ t2vmask = ~(matrix<media_starts.view(1, -1))
446
+ t2vmask = repeat(t2vmask, 'seq_t seq_v -> 1 1 seq_t (seq_v v_token)', v_token=length_each_img).to(query_states.device)
447
+ full_mask = torch.cat([t2vmask, causal_mask], dim=3) # unsqueeze batch dim (batch, 1, seq_q, seq_k)
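+ # Example (toy sizes): q_len = kv_seq_len = 4 text tokens, one image of 2 visual tokens whose
+ # <|image|> marker sits at text position 1 (media_starts = [1], length_each_img = 2).
+ #   t2vmask (text -> visual):        causal_mask (text -> text):
+ #     [[F F],                          [[T F F F],
+ #      [T T],                           [T T F F],
+ #      [T T],                           [T T T F],
+ #      [T T]]                           [T T T T]]
+ # full_mask = cat([t2vmask, causal_mask], dim=3) has shape (1, 1, 4, 6): a text token may attend
+ # to an image's tokens only if that image was inserted at or before the token's own position.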
448
+
449
+ curr_query_layer = query_states[bi:bi+1]
450
+ # order is sbhd
451
+ curr_visual_key_layer, curr_visual_value_layer = rearrange(image_embeds[image_start:image_start+num_images], 'BL Lv (H KV D) -> KV 1 H (BL Lv) D', KV=2, H=self.num_key_value_heads) # b h s d
452
+ image_start += num_images
453
+ # ic(media_starts)
454
+ curr_visual_key_layer = self.apply_mi_rope(curr_visual_key_layer, media_starts, length_each_img=length_each_img)
455
+
456
+ curr_visual_key_layer = repeat_kv(curr_visual_key_layer, self.num_key_value_groups)
457
+ curr_visual_value_layer = repeat_kv(curr_visual_value_layer, self.num_key_value_groups)
458
+
459
+ curr_key_layer = torch.cat([curr_visual_key_layer, key_states[bi:bi+1]], dim=2)
460
+ curr_value_layer = torch.cat([curr_visual_value_layer, value_states[bi:bi+1]], dim=2)
461
+ is_causal = False
462
+ else:
463
+ # Attention without any image context
464
+ curr_query_layer = query_states[bi:bi+1]
465
+ curr_key_layer = key_states[bi:bi+1]
466
+ curr_value_layer = value_states[bi:bi+1]
467
+ full_mask = causal_mask
468
+ is_causal = True if causal_mask is None and q_len > 1 else False
469
+ if is_causal:
470
+ full_mask = None
471
+ else:
472
+ causal_mask = torch.tril(torch.ones(q_len, kv_seq_len, dtype=torch.bool, device=query_states.device)).bool()
473
+ # Expand dims to match (bsz, 1, q_len, kv_seq_len)
474
+ causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
475
+ full_mask = causal_mask
476
+
477
+
478
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
479
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
480
+ if curr_query_layer.device.type == "cuda" and full_mask is not None:
481
+ curr_query_layer = curr_query_layer.contiguous()
482
+ curr_key_layer = curr_key_layer.contiguous()
483
+ curr_value_layer = curr_value_layer.contiguous()
484
+
485
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
486
+ curr_query_layer, # (batch, ..., sequence, dim)
487
+ curr_key_layer,
488
+ curr_value_layer,
489
+ attn_mask=full_mask, # (N, ..., L, S) A boolean mask where a value of True indicates that the element *should* take part in attention.
490
+ dropout_p=self.attention_dropout if self.training else 0.0,
491
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
492
+ is_causal=is_causal,
493
+ # enable_gqa=True, # gqa can not be used because mask requires XFORMERS and not support gqa
494
+ ) # -> (N, ..., L, Ev)
495
+ assert attn_output.shape[0] == 1
496
+ context_layer.append(attn_output)
497
+ attn_output = context_layer = torch.cat(context_layer, dim=0)
498
+
499
+ attn_output = attn_output.transpose(1, 2).contiguous()
500
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
501
+
502
+ attn_output = self.o_proj(attn_output)
503
+
504
+ return attn_output, None, past_key_value
505
+
506
+ # def hyperattention_flex(self,hidden_states: torch.Tensor,
507
+ # attention_mask: Optional[torch.Tensor] = None,
508
+ # position_ids: Optional[torch.LongTensor] = None,
509
+ # image_embeds=None,
510
+ # media_offset=None,
511
+ # past_key_value: Optional[Cache] = None,
512
+ # output_attentions: bool = False,
513
+ # use_cache: bool = False,
514
+ # )-> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
515
+ # bsz, q_len, _ = hidden_states.size()
516
+
517
+ # query_states = self.q_proj(hidden_states)
518
+ # key_states = self.k_proj(hidden_states)
519
+ # value_states = self.v_proj(hidden_states)
520
+
521
+ # query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
522
+ # key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
523
+ # value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
524
+
525
+ # kv_seq_len = key_states.shape[-2]
526
+ # if past_key_value is not None:
527
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
528
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
529
+
530
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
531
+
532
+ # if past_key_value is not None:
533
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
534
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
535
+
536
+ # key_states = repeat_kv(key_states, self.num_key_value_groups)
537
+ # value_states = repeat_kv(value_states, self.num_key_value_groups)
538
+
539
+ # # add visual to kv
540
+ # length_each_img = image_embeds.shape[1]
541
+ # image_embeds = self.v_kv_proj(image_embeds)
542
+ # image_start = 0
543
+ # context_layer = []
544
+
545
+
546
+ # for bi, media_starts in enumerate(media_offset):
547
+ # num_images = media_starts.shape[0]
548
+ # if num_images > 0:
549
+ # if q_len == 1:
550
+ # hyper_mask = create_block_mask(hyper_mask_always_true, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_seq_len+len(media_starts)*length_each_img)
551
+ # else:
552
+ # media_starts_extend = repeat(media_starts, 'seq_v -> (seq_v v_token)', v_token=length_each_img)
553
+ # extend_len = media_starts_extend.shape[0]+kv_seq_len
554
+ # if extend_len%128!=0:
555
+ # extend_len = (extend_len//128+1)*128
556
+ # extend_len = extend_len-media_starts_extend.shape[0]
557
+ # media_starts_extend = torch.cat([media_starts_extend, torch.zeros(extend_len, device=media_starts_extend.device, dtype=media_starts_extend.dtype)],dim=0)
558
+ # hyper_mask = create_hyper_attention(media_starts_extend, q_len, kv_seq_len+len(media_starts)*length_each_img, length_each_img)
559
+
560
+ # curr_query_layer = query_states[bi:bi+1]
561
+ # # order is sbhd
562
+ # curr_visual_key_layer, curr_visual_value_layer = rearrange(image_embeds[image_start:image_start+num_images], 'BL Lv (H KV D) -> KV 1 H (BL Lv) D', KV=2, H=self.num_key_value_heads) # b h s d
563
+ # image_start += num_images
564
+ # # ic(media_starts)
565
+ # curr_visual_key_layer = self.apply_mi_rope(curr_visual_key_layer, media_starts, length_each_img=length_each_img)
566
+
567
+ # curr_visual_key_layer = repeat_kv(curr_visual_key_layer, self.num_key_value_groups)
568
+ # curr_visual_value_layer = repeat_kv(curr_visual_value_layer, self.num_key_value_groups)
569
+
570
+ # curr_key_layer = torch.cat([curr_visual_key_layer, key_states[bi:bi+1]], dim=2)
571
+ # curr_value_layer = torch.cat([curr_visual_value_layer, value_states[bi:bi+1]], dim=2)
572
+ # is_causal = False
573
+ # else:
574
+ # # Attention without any image context
575
+ # curr_query_layer = query_states[bi:bi+1]
576
+ # curr_key_layer = key_states[bi:bi+1]
577
+ # curr_value_layer = value_states[bi:bi+1]
578
+ # full_mask = causal_mask
579
+ # is_causal = True if causal_mask is None and q_len > 1 else False
580
+ # if is_causal:
581
+ # hyper_mask = create_block_mask(hyper_mask_always_true, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_seq_len, _compile=True)
582
+ # else:
583
+ # hyper_mask = create_block_mask(causal, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_seq_len, _compile=True)
584
+
585
+
586
+ # # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
587
+ # # Reference: https://github.com/pytorch/pytorch/issues/112577.
588
+ # if curr_query_layer.device.type == "cuda" and attention_mask is not None:
589
+ # curr_query_layer = curr_query_layer.contiguous()
590
+ # curr_key_layer = curr_key_layer.contiguous()
591
+ # curr_value_layer = curr_value_layer.contiguous()
592
+
593
+
594
+ # attn_output = flex_attention(
595
+ # curr_query_layer,
596
+ # curr_key_layer,
597
+ # curr_value_layer,
598
+ # block_mask=hyper_mask
599
+ # )
600
+ # # attn_output = torch.nn.functional.scaled_dot_product_attention(
601
+ # # curr_query_layer, # (batch, ..., sequence, dim)
602
+ # # curr_key_layer,
603
+ # # curr_value_layer,
604
+ # # attn_mask=full_mask, # (N, ..., L, S) A boolean mask where a value of True indicates that the element *should* take part in attention.
605
+ # # dropout_p=self.attention_dropout if self.training else 0.0,
606
+ # # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
607
+ # # is_causal=is_causal,
608
+ # # ) # -> (N, ..., L, Ev)
609
+ # assert attn_output.shape[0] == 1
610
+ # context_layer.append(attn_output)
611
+ # attn_output = context_layer = torch.cat(context_layer, dim=0)
612
+
613
+ # attn_output = attn_output.transpose(1, 2).contiguous()
614
+ # attn_output = attn_output.view(bsz, q_len, self.hidden_size)
615
+
616
+ # attn_output = self.o_proj(attn_output)
617
+
618
+ # return attn_output, None, past_key_value
619
+
620
+ # Adapted from Qwen2Attention.forward
621
+ def forward(
622
+ self,
623
+ hidden_states: torch.Tensor,
624
+ attention_mask: Optional[torch.Tensor] = None,
625
+ position_ids: Optional[torch.LongTensor] = None,
626
+ image_embeds=None,
627
+ media_offset=None,
628
+ past_key_value: Optional[Cache] = None,
629
+ output_attentions: bool = False,
630
+ use_cache: bool = False,
631
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
632
+ if output_attentions:
633
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
634
+ logger.warning_once(
635
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
636
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
637
+ )
638
+ return super().forward(
639
+ hidden_states=hidden_states,
640
+ attention_mask=attention_mask,
641
+ position_ids=position_ids,
642
+ past_key_value=past_key_value,
643
+ output_attentions=output_attentions,
644
+ use_cache=use_cache,
645
+ )
646
+ if self.is_hyper_enabled and image_embeds is not None:
647
+ # return self.hyperattention_flex(hidden_states, attention_mask, position_ids, image_embeds, media_offset, past_key_value, output_attentions, use_cache)
648
+ return self.hyperattention(hidden_states, attention_mask, position_ids, image_embeds, media_offset, past_key_value, output_attentions, use_cache)
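+ # Without image embeddings (e.g. text-only inputs), fall through to the standard Qwen2 SDPA attention below.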
649
+
650
+ bsz, q_len, _ = hidden_states.size()
651
+
652
+ query_states = self.q_proj(hidden_states)
653
+ key_states = self.k_proj(hidden_states)
654
+ value_states = self.v_proj(hidden_states)
655
+
656
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
657
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
658
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
659
+
660
+ kv_seq_len = key_states.shape[-2]
661
+ if past_key_value is not None:
662
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
663
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
664
+
665
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
666
+
667
+ if past_key_value is not None:
668
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
669
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
670
+
671
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
672
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
673
+
674
+ if attention_mask is not None:
675
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
676
+ raise ValueError(
677
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
678
+ )
679
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
680
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
681
+ if query_states.device.type == "cuda" and attention_mask is not None:
682
+ query_states = query_states.contiguous()
683
+ key_states = key_states.contiguous()
684
+ value_states = value_states.contiguous()
685
+
686
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
687
+ query_states,
688
+ key_states,
689
+ value_states,
690
+ attn_mask=attention_mask,
691
+ dropout_p=self.attention_dropout if self.training else 0.0,
692
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
693
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
694
+ )
695
+
696
+ attn_output = attn_output.transpose(1, 2).contiguous()
697
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
698
+
699
+ attn_output = self.o_proj(attn_output)
700
+
701
+ return attn_output, None, past_key_value
702
+
703
+
704
+ # Original Attention of Qwen2
705
+ QWEN2_ATTENTION_CLASSES = {
706
+ "eager": Qwen2Attention,
707
+ "flash_attention_2": Qwen2FlashAttention2,
708
+ "sdpa": Qwen2SdpaAttention,
709
+ }
710
+
711
+
712
+ class HyperQwen2DecoderLayer(nn.Module):
713
+ def __init__(self, config: HyperQwen2Config, layer_idx: int):
714
+ super().__init__()
715
+ self.hidden_size = config.hidden_size
716
+
717
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
718
+ logger.warning_once(
719
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
720
+ "unexpected results may be encountered."
721
+ )
722
+ self.is_hyper_enabled = (layer_idx+1) in config.hyper_layers
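+ # Only layers listed in config.hyper_layers get the fused Hyper Attention; all other layers keep
+ # the stock Qwen2 attention for the configured implementation.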
723
+ if self.is_hyper_enabled:
724
+ self.self_attn = HyperQwen2SdpaAttention(config, layer_idx, is_hyper_enabled=self.is_hyper_enabled)
725
+ else:
726
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
727
+
728
+
729
+ self.mlp = Qwen2MLP(config)
730
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
731
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
732
+
733
+ @property
734
+ def device(self):
735
+ return self.input_layernorm.weight.device
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.Tensor,
740
+ attention_mask: Optional[torch.Tensor] = None,
741
+ position_ids: Optional[torch.LongTensor] = None,
742
+ image_embeds=None,
743
+ media_offset=None,
744
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
745
+ output_attentions: Optional[bool] = False,
746
+ use_cache: Optional[bool] = False,
747
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
748
+ """
749
+ Args:
750
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
751
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
752
+ `(batch, sequence_length)` where padding elements are indicated by 0.
753
+ output_attentions (`bool`, *optional*):
754
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
755
+ returned tensors for more detail.
756
+ use_cache (`bool`, *optional*):
757
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
758
+ (see `past_key_values`).
759
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
760
+ """
761
+ # if hidden_states.device != self.device:
762
+ # hidden_states = hidden_states.to(self.device)
763
+ residual = hidden_states
764
+ hidden_states = self.input_layernorm(hidden_states)
765
+
766
+ # Shared LayerNorm
767
+ if image_embeds is not None and self.is_hyper_enabled:
768
+ image_embeds = self.input_layernorm(image_embeds)
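+ # The text input_layernorm is reused for the visual features, so both modalities are normalized
+ # the same way before entering the fused attention.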
769
+ media_kwargs = {"image_embeds": image_embeds, "media_offset": media_offset}
770
+ else:
771
+ image_embeds = media_offset = None
772
+ media_kwargs = {}
773
+ # Self Attention
774
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
775
+ hidden_states=hidden_states,
776
+ attention_mask=attention_mask,
777
+ position_ids=position_ids,
778
+ past_key_value=past_key_value,
779
+ output_attentions=output_attentions,
780
+ use_cache=use_cache,
781
+ **media_kwargs,
782
+ )
783
+ hidden_states = residual + hidden_states
784
+
785
+ # Fully Connected
786
+ residual = hidden_states
787
+ hidden_states = self.post_attention_layernorm(hidden_states)
788
+ hidden_states = self.mlp(hidden_states)
789
+ hidden_states = residual + hidden_states
790
+
791
+ outputs = (hidden_states,)
792
+
793
+ if output_attentions:
794
+ outputs += (self_attn_weights,)
795
+
796
+ if use_cache:
797
+ outputs += (present_key_value,)
798
+
799
+ return outputs
800
+
801
+
802
+ QWEN2_START_DOCSTRING = r"""
803
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
804
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
805
+ etc.)
806
+
807
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
808
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
809
+ and behavior.
810
+
811
+ Parameters:
812
+ config ([`HyperQwen2Config`]):
813
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
814
+ load the weights associated with the model, only the configuration. Check out the
815
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
816
+ """
817
+
818
+
819
+ @add_start_docstrings(
820
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
821
+ QWEN2_START_DOCSTRING,
822
+ )
823
+ class Qwen2PreTrainedModel(PreTrainedModel):
824
+ config_class = HyperQwen2Config
825
+ base_model_prefix = "model"
826
+ supports_gradient_checkpointing = True
827
+ _no_split_modules = ["HyperQwen2DecoderLayer"]
828
+ _skip_keys_device_placement = "past_key_values"
829
+ _supports_flash_attn_2 = True
830
+ _supports_sdpa = True
831
+ _supports_cache_class = True
832
+
833
+ def _init_weights(self, module):
834
+ std = self.config.initializer_range
835
+ if isinstance(module, nn.Linear):
836
+ module.weight.data.normal_(mean=0.0, std=std)
837
+ if module.bias is not None:
838
+ module.bias.data.zero_()
839
+ elif isinstance(module, nn.Embedding):
840
+ module.weight.data.normal_(mean=0.0, std=std)
841
+ if module.padding_idx is not None:
842
+ module.weight.data[module.padding_idx].zero_()
843
+
844
+
845
+ QWEN2_INPUTS_DOCSTRING = r"""
846
+ Args:
847
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
848
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
849
+ it.
850
+
851
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
852
+ [`PreTrainedTokenizer.__call__`] for details.
853
+
854
+ [What are input IDs?](../glossary#input-ids)
855
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
856
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
857
+
858
+ - 1 for tokens that are **not masked**,
859
+ - 0 for tokens that are **masked**.
860
+
861
+ [What are attention masks?](../glossary#attention-mask)
862
+
863
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
864
+ [`PreTrainedTokenizer.__call__`] for details.
865
+
866
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
867
+ `past_key_values`).
868
+
869
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
870
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
871
+ information on the default strategy.
872
+
873
+ - 1 indicates the head is **not masked**,
874
+ - 0 indicates the head is **masked**.
875
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
876
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
877
+ config.n_positions - 1]`.
878
+
879
+ [What are position IDs?](../glossary#position-ids)
880
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
881
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
882
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
883
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
884
+
885
+ Two formats are allowed:
886
+ - a [`~cache_utils.Cache`] instance;
887
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
888
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
889
+ cache format.
890
+
891
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
892
+ legacy cache format will be returned.
893
+
894
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
895
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
896
+ of shape `(batch_size, sequence_length)`.
897
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
898
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
899
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
900
+ model's internal embedding lookup matrix.
901
+ use_cache (`bool`, *optional*):
902
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
903
+ `past_key_values`).
904
+ output_attentions (`bool`, *optional*):
905
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
906
+ tensors for more detail.
907
+ output_hidden_states (`bool`, *optional*):
908
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
909
+ more detail.
910
+ return_dict (`bool`, *optional*):
911
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
912
+ """
913
+
914
+
915
+ @add_start_docstrings(
916
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
917
+ QWEN2_START_DOCSTRING,
918
+ )
919
+ class HyperQwen2Model(Qwen2PreTrainedModel):
920
+ """
921
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HyperQwen2DecoderLayer`]
922
+
923
+ Args:
924
+ config: HyperQwen2Config
925
+ """
926
+
927
+ def __init__(self, config: HyperQwen2Config):
928
+ super().__init__(config)
929
+ self.padding_idx = config.pad_token_id
930
+ self.vocab_size = config.vocab_size
931
+
932
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
933
+ self.layers = nn.ModuleList(
934
+ [HyperQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
935
+ )
936
+ self._attn_implementation = config._attn_implementation
937
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
938
+
939
+ self.gradient_checkpointing = False
940
+ # Initialize weights and apply final processing
941
+ self.post_init()
942
+
943
+ def get_input_embeddings(self):
944
+ return self.embed_tokens
945
+
946
+ def set_input_embeddings(self, value):
947
+ self.embed_tokens = value
948
+
949
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
950
+ def forward(
951
+ self,
952
+ input_ids: torch.LongTensor = None,
953
+ attention_mask: Optional[torch.Tensor] = None,
954
+ position_ids: Optional[torch.LongTensor] = None,
955
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
956
+ inputs_embeds: Optional[torch.FloatTensor] = None,
957
+ image_embeds=None,
958
+ media_offset=None,
959
+ use_cache: Optional[bool] = None,
960
+ output_attentions: Optional[bool] = None,
961
+ output_hidden_states: Optional[bool] = None,
962
+ return_dict: Optional[bool] = None,
963
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
964
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
965
+ output_hidden_states = (
966
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
967
+ )
968
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
969
+
970
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
971
+
972
+ # retrieve input_ids and inputs_embeds
973
+ if input_ids is not None and inputs_embeds is not None:
974
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
975
+ elif input_ids is not None:
976
+ batch_size, seq_length = input_ids.shape
977
+ elif inputs_embeds is not None:
978
+ batch_size, seq_length, _ = inputs_embeds.shape
979
+ else:
980
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
981
+
982
+ if self.gradient_checkpointing and self.training:
983
+ if use_cache:
984
+ logger.warning_once(
985
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
986
+ )
987
+ use_cache = False
988
+
989
+ past_key_values_length = 0
990
+
991
+ if use_cache:
992
+ use_legacy_cache = not isinstance(past_key_values, Cache)
993
+ if use_legacy_cache:
994
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
995
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
996
+
997
+ if position_ids is None:
998
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
999
+ position_ids = torch.arange(
1000
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1001
+ )
1002
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1003
+ else:
1004
+ position_ids = position_ids.view(-1, seq_length).long()
1005
+
1006
+ if inputs_embeds is None:
1007
+ inputs_embeds = self.embed_tokens(input_ids)
1008
+
1009
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1010
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1011
+ if is_padding_right:
1012
+ raise ValueError(
1013
+ "You are attempting to perform batched generation with padding_side='right'"
1014
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
1015
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1016
+ )
1017
+
1018
+ if self._attn_implementation == "flash_attention_2":
1019
+ # 2d mask is passed through the layers
1020
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1021
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1022
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1023
+ # the manual implementation that requires a 4D causal mask in all cases.
1024
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1025
+ attention_mask,
1026
+ (batch_size, seq_length),
1027
+ inputs_embeds,
1028
+ past_key_values_length,
1029
+ sliding_window=self.config.sliding_window,
1030
+ )
1031
+ else:
1032
+ # 4d mask is passed through the layers
1033
+ attention_mask = _prepare_4d_causal_attention_mask(
1034
+ attention_mask,
1035
+ (batch_size, seq_length),
1036
+ inputs_embeds,
1037
+ past_key_values_length,
1038
+ sliding_window=self.config.sliding_window,
1039
+ )
1040
+
1041
+ hidden_states = inputs_embeds
1042
+
1043
+ # beam search
1044
+ if batch_size != len(media_offset):
1045
+ # The model is performing beamsearch, repeat the visual content
1046
+ beam_factor = batch_size // len(media_offset)
1047
+ assert batch_size % len(media_offset) == 0
1048
+ media_offset = media_offset * beam_factor
1049
+ image_embeds = repeat(image_embeds, 'B L D -> (factor B) L D', factor=beam_factor)
1050
+
1051
+ # # Flex mask
1052
+
1053
+ # expected_q_len = hidden_states.shape[1]
1054
+ # expected_kv_len = expected_q_len
1055
+ # if past_key_value is not None:
1056
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, 1)
1057
+
1058
+ # length_each_img = image_embeds.shape[1]
1059
+ # expected_kv_len = [expected_kv_len+len(_)*length_each_img for _ in media_offset]
1060
+ # flex_mask_block = []
1061
+
1062
+ # decoder layers
1063
+ all_hidden_states = () if output_hidden_states else None
1064
+ all_self_attns = () if output_attentions else None
1065
+ next_decoder_cache = None
1066
+
1067
+ for decoder_layer in self.layers:
1068
+ if output_hidden_states:
1069
+ all_hidden_states += (hidden_states,)
1070
+ if self.gradient_checkpointing and self.training:
1071
+ layer_outputs = self._gradient_checkpointing_func(
1072
+ decoder_layer.__call__,
1073
+ hidden_states,
1074
+ attention_mask,
1075
+ position_ids,
1076
+ image_embeds,
1077
+ media_offset,
1078
+ past_key_values,
1079
+ output_attentions,
1080
+ use_cache,
1081
+ )
1082
+ else:
1083
+ layer_outputs = decoder_layer(
1084
+ hidden_states,
1085
+ attention_mask=attention_mask,
1086
+ position_ids=position_ids,
1087
+ image_embeds=image_embeds,
1088
+ media_offset=media_offset,
1089
+ past_key_value=past_key_values,
1090
+ output_attentions=output_attentions,
1091
+ use_cache=use_cache,
1092
+ )
1093
+
1094
+ hidden_states = layer_outputs[0]
1095
+
1096
+ if use_cache:
1097
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1098
+
1099
+ if output_attentions:
1100
+ all_self_attns += (layer_outputs[1],)
1101
+
1102
+ hidden_states = self.norm(hidden_states)
1103
+
1104
+ # add hidden states from the last decoder layer
1105
+ if output_hidden_states:
1106
+ all_hidden_states += (hidden_states,)
1107
+
1108
+ next_cache = None
1109
+ if use_cache:
1110
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1111
+
1112
+ if not return_dict:
1113
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1114
+ return BaseModelOutputWithPast(
1115
+ last_hidden_state=hidden_states,
1116
+ past_key_values=next_cache,
1117
+ hidden_states=all_hidden_states,
1118
+ attentions=all_self_attns,
1119
+ )
1120
+
1121
+
1122
+ class HyperQwen2ForCausalLM(Qwen2PreTrainedModel):
1123
+ _tied_weights_keys = ["lm_head.weight"]
1124
+
1125
+ def __init__(self, config):
1126
+ super().__init__(config)
1127
+ self.model = HyperQwen2Model(config)
1128
+ self.vocab_size = config.vocab_size
1129
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1130
+
1131
+ # Initialize weights and apply final processing
1132
+ self.post_init()
1133
+
1134
+ def get_input_embeddings(self):
1135
+ return self.model.embed_tokens
1136
+
1137
+ def set_input_embeddings(self, value):
1138
+ self.model.embed_tokens = value
1139
+
1140
+ def get_output_embeddings(self):
1141
+ return self.lm_head
1142
+
1143
+ def set_output_embeddings(self, new_embeddings):
1144
+ self.lm_head = new_embeddings
1145
+
1146
+ def set_decoder(self, decoder):
1147
+ self.model = decoder
1148
+
1149
+ def get_decoder(self):
1150
+ return self.model
1151
+
1152
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1153
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1154
+ def forward(
1155
+ self,
1156
+ input_ids: torch.LongTensor = None,
1157
+ attention_mask: Optional[torch.Tensor] = None,
1158
+ position_ids: Optional[torch.LongTensor] = None,
1159
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1160
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1161
+ image_embeds=None,
1162
+ media_offset=None,
1163
+ labels: Optional[torch.LongTensor] = None,
1164
+ use_cache: Optional[bool] = None,
1165
+ output_attentions: Optional[bool] = None,
1166
+ output_hidden_states: Optional[bool] = None,
1167
+ return_dict: Optional[bool] = None,
1168
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1169
+ r"""
1170
+ Args:
1171
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1172
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1173
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1174
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1175
+
1176
+ Returns:
1177
+
1178
+ Example:
1179
+
1180
+ ```python
1181
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1182
+
1183
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1184
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1185
+
1186
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1187
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1188
+
1189
+ >>> # Generate
1190
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1191
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1192
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1193
+ ```"""
1194
+
1195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1196
+ output_hidden_states = (
1197
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1198
+ )
1199
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1200
+
1201
+ # # media_offset to 3.5 format
1202
+ # if media_offset is not None:
1203
+ # bs = media_offset.shape[0]
1204
+ # pad_media_offset = torch.cat([torch.zeros(media_offset.shape[0], 1,device=media_offset.device, dtype=media_offset.dtype), media_offset], dim=1)
1205
+ # pad_media_offset = (pad_media_offset[:,1:] - pad_media_offset[:,:-1]).nonzero()
1206
+ # media_offset = [[] for bi in range(bs)]
1207
+ # for i, (bi, li) in enumerate(pad_media_offset):
1208
+ # media_offset[bi].append(li)
1209
+
1210
+
1211
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1212
+ outputs = self.model(
1213
+ input_ids=input_ids,
1214
+ attention_mask=attention_mask,
1215
+ position_ids=position_ids,
1216
+ past_key_values=past_key_values,
1217
+ inputs_embeds=inputs_embeds,
1218
+ image_embeds=image_embeds,
1219
+ media_offset=media_offset,
1220
+ use_cache=use_cache,
1221
+ output_attentions=output_attentions,
1222
+ output_hidden_states=output_hidden_states,
1223
+ return_dict=return_dict,
1224
+ )
1225
+
1226
+ hidden_states = outputs[0]
1227
+ logits = self.lm_head(hidden_states)
1228
+ logits = logits.float()
1229
+
1230
+ loss = None
1231
+ if labels is not None:
1232
+ # Shift so that tokens < n predict n
1233
+ shift_logits = logits[..., :-1, :].contiguous()
1234
+ shift_labels = labels[..., 1:].contiguous()
1235
+ # Flatten the tokens
1236
+ loss_fct = CrossEntropyLoss()
1237
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1238
+ shift_labels = shift_labels.view(-1)
1239
+ # Enable model parallelism
1240
+ shift_labels = shift_labels.to(shift_logits.device)
1241
+ loss = loss_fct(shift_logits, shift_labels)
1242
+
1243
+ if not return_dict:
1244
+ output = (logits,) + outputs[1:]
1245
+ return (loss,) + output if loss is not None else output
1246
+
1247
+ return CausalLMOutputWithPast(
1248
+ loss=loss,
1249
+ logits=logits,
1250
+ past_key_values=outputs.past_key_values,
1251
+ hidden_states=outputs.hidden_states,
1252
+ attentions=outputs.attentions,
1253
+ )
1254
+
1255
+ def prepare_inputs_for_generation(
1256
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1257
+ ):
1258
+ # Omit tokens covered by past_key_values
1259
+ if past_key_values is not None:
1260
+ if isinstance(past_key_values, Cache):
1261
+ cache_length = past_key_values.get_seq_length()
1262
+ past_length = past_key_values.seen_tokens
1263
+ max_cache_length = past_key_values.get_max_length()
1264
+ else:
1265
+ cache_length = past_length = past_key_values[0][0].shape[2]
1266
+ max_cache_length = None
1267
+
1268
+ # Keep only the unprocessed tokens:
1269
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1270
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1271
+ # input)
1272
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1273
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1274
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1275
+ # input_ids based on the past_length.
1276
+ elif past_length < input_ids.shape[1]:
1277
+ input_ids = input_ids[:, past_length:]
1278
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1279
+
1280
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1281
+ if (
1282
+ max_cache_length is not None
1283
+ and attention_mask is not None
1284
+ and cache_length + input_ids.shape[1] > max_cache_length
1285
+ ):
1286
+ attention_mask = attention_mask[:, -max_cache_length:]
1287
+
1288
+ position_ids = kwargs.get("position_ids", None)
1289
+ if attention_mask is not None and position_ids is None:
1290
+ # create position_ids on the fly for batch generation
1291
+ position_ids = attention_mask.long().cumsum(-1) - 1
1292
+ position_ids.masked_fill_(attention_mask == 0, 1)
1293
+ if past_key_values:
1294
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1295
+
1296
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1297
+ if inputs_embeds is not None and past_key_values is None:
1298
+ model_inputs = {"inputs_embeds": inputs_embeds}
1299
+ else:
1300
+ model_inputs = {"input_ids": input_ids}
1301
+
1302
+ model_inputs.update(
1303
+ {
1304
+ "position_ids": position_ids,
1305
+ "past_key_values": past_key_values,
1306
+ "use_cache": kwargs.get("use_cache"),
1307
+ "attention_mask": attention_mask,
1308
+ 'image_embeds': kwargs.get('image_embeds'),
1309
+ 'media_offset': kwargs.get('media_offset'),
1310
+ }
1311
+ )
1312
+ return model_inputs
1313
+
1314
+ @staticmethod
1315
+ def _reorder_cache(past_key_values, beam_idx):
1316
+ reordered_past = ()
1317
+ for layer_past in past_key_values:
1318
+ reordered_past += (
1319
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1320
+ )
1321
+ return reordered_past
modeling_mplugowl3.py ADDED
@@ -0,0 +1,254 @@
1
+ import math
2
+ from typing import List, Optional
3
+ import json
4
+ import torch
5
+ import torchvision
6
+
7
+ from threading import Thread
8
+ from copy import deepcopy
9
+ from PIL import Image
10
+ from transformers import AutoProcessor, Qwen2PreTrainedModel, Qwen2ForCausalLM, TextIteratorStreamer
11
+ from .processing_mplugowl3 import mPLUGOwl3Processor
12
+ from .image_processing_mplugowl3 import mPLUGOwl3ImageProcessor
13
+ from .configuration_mplugowl3 import mPLUGOwl3Config
14
+ # from .modeling_navit_siglip import SiglipVisionTransformer
15
+ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
16
+ from .modeling_hyper_qwen2 import HyperQwen2ForCausalLM
17
+ from torch import nn
18
+
19
+
20
+ class mPLUGOwl3PreTrainedModel(Qwen2PreTrainedModel):
21
+ config_class = mPLUGOwl3Config
22
+ _no_split_modules = ["HyperQwen2DecoderLayer", "SiglipVisionTransformer"]
23
+
24
+
25
+ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
26
+ def __init__(self, config):
27
+ super().__init__(config)
28
+
29
+ self.vision_model = self.init_vision_module()
30
+ self.vision_dim = self.vision_model.embed_dim
31
+ self.embed_dim = self.config.hidden_size
32
+ self.vision2text_model = nn.Sequential(
33
+ nn.Linear(self.vision_dim, self.embed_dim),
34
+ nn.GELU(),
35
+ nn.Linear(self.embed_dim, self.embed_dim)
36
+ )
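+ # Two-layer MLP (Linear -> GELU -> Linear) that projects SigLIP patch features into the language
+ # model's hidden size before they are fused into the decoder layers.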
37
+ self.language_model = HyperQwen2ForCausalLM(config)
38
+
39
+
40
+
41
+ self.processor = None
42
+
43
+ self.terminators = ['<|im_end|>', '<|endoftext|>']
44
+ self.vision_batch_size = config.vision_batch_size
45
+
46
+ def init_vision_module(self):
47
+
48
+ self.config.vision_config._attn_implementation = self.config.vision_config._attn_implementation
49
+ model = SiglipVisionTransformer(self.config.vision_config)
50
+
51
+ setattr(model, 'embed_dim', model.embeddings.embed_dim)
52
+ setattr(model, 'patch_size', model.embeddings.patch_size)
53
+ return model
54
+
55
+
56
+ def get_input_embeddings(self):
57
+ return self.language_model.get_input_embeddings()
58
+
59
+ def set_input_embeddings(self, value):
60
+ self.language_model.embed_tokens = value
61
+
62
+ def get_output_embeddings(self):
63
+ return self.language_model.lm_head
64
+
65
+ def set_output_embeddings(self, new_embeddings):
66
+ self.language_model.lm_head = new_embeddings
67
+
68
+ def set_decoder(self, decoder):
69
+ self.language_model = decoder
70
+
71
+ def get_decoder(self):
72
+ return self.language_model
73
+
74
+ def _small_batched_forward(self, pixel_values):
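+ # Run the vision encoder in chunks of vision_batch_size to bound peak memory; the second-to-last
+ # hidden layer is used as the patch feature map.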
75
+ vision_batch_size = self.vision_batch_size
76
+ image_forward_out = []
77
+ B = len(pixel_values)
78
+ for i in range(0, B, vision_batch_size):
79
+ start_idx = i
80
+ end_idx = min(B, i + vision_batch_size)
81
+ tmp_hs = self.vision_model(pixel_values[start_idx:end_idx], output_hidden_states=True).hidden_states[-2]
82
+ image_forward_out.append(tmp_hs)
83
+ vision_embedding = torch.cat(image_forward_out, dim=0)
84
+ assert vision_embedding.shape[0] == B
85
+ return vision_embedding
86
+
87
+ def forward_image(self, pixel_values):
88
+ if pixel_values is None:
89
+ return None
90
+ dtype = self.language_model.model.embed_tokens.weight.dtype
91
+ with torch.inference_mode():
92
+ image_embeds = self._small_batched_forward(pixel_values.to(dtype))
93
+ # image_embeds = self.vision_model(pixel_values.to(dtype), output_hidden_states=True).hidden_states[-2]
94
+
95
+ if self.vision2text_model is not None:
96
+ image_embeds = self.vision2text_model(image_embeds)
97
+ else:
98
+ pass
99
+
100
+ return image_embeds
101
+
102
+ def forward(self, pixel_values=None, **kwargs):
103
+ image_embeds = self.forward_image(pixel_values)
104
+
105
+ return self.language_model(
106
+ image_embeds=image_embeds,
107
+ **kwargs
108
+ )
109
+
110
+ def _decode(self, input_ids, image_embeds, media_offset, tokenizer, attention_mask, decode_text=False, **kwargs):
111
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
112
+ output = self.language_model.generate(
113
+ input_ids=input_ids,
114
+ image_embeds=image_embeds,
115
+ media_offset=media_offset,
116
+ pad_token_id=0,
117
+ eos_token_id=terminators,
118
+ attention_mask=attention_mask,
119
+ **kwargs
120
+ )
121
+
122
+ output = output[:,input_ids.shape[1]:]
123
+ # print(output)  # debug print of raw output ids
124
+ if decode_text:
125
+ return self._decode_text(output, tokenizer)
126
+ return output
127
+
128
+ def _decode_stream(self, input_ids, image_embeds, media_offset, tokenizer, **kwargs):
129
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
130
+ streamer = TextIteratorStreamer(tokenizer=tokenizer)
131
+ generation_kwargs = {
132
+ 'input_ids': input_ids,
133
+ 'image_embeds': image_embeds,
134
+ 'media_offset': media_offset,
135
+ 'pad_token_id': 0,
136
+ 'eos_token_id': terminators,
137
+ 'streamer': streamer
138
+ }
139
+ generation_kwargs.update(kwargs)
140
+
141
+ thread = Thread(target=self.language_model.generate, kwargs=generation_kwargs)
142
+ thread.start()
143
+
144
+ return streamer
145
+
146
+ def _decode_text(self, result_ids, tokenizer):
147
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
148
+ result_text = []
149
+ for result in result_ids:
150
+ result = result[result != 0]
151
+ if result[-1] in terminators:
152
+ result = result[:-1]
153
+ result_text.append(tokenizer.decode(result).strip())
154
+ return result_text
155
+
156
+ def init_processor(self, tokenizer):
157
+ ip = mPLUGOwl3ImageProcessor(image_size=378)
158
+ self.processor = mPLUGOwl3Processor(image_processor=ip, tokenizer=tokenizer)
159
+ processor = self.processor
160
+ return processor
161
+
162
+ def generate(
163
+ self,
164
+ input_ids=None,
165
+ pixel_values=None,
166
+ media_offset=None,
167
+ attention_mask=None,
168
+ tokenizer=None,
169
+ stream=False,
170
+ decode_text=False,
171
+ **kwargs
172
+ ):
173
+ assert input_ids is not None
174
+
175
+ with torch.inference_mode():
176
+ image_embeds = self.forward_image(pixel_values)
177
+
178
+ if stream:
179
+ result = self._decode_stream(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, **kwargs)
180
+ else:
181
+ result = self._decode(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, attention_mask=attention_mask, decode_text=decode_text, **kwargs)
182
+
183
+ return result
184
+
185
+ def chat(
186
+ self,
187
+ images,
188
+ videos,
189
+ messages,
190
+ tokenizer,
191
+ processor=None,
192
+ max_new_tokens=2048,
193
+ min_new_tokens=0,
194
+ sampling=True,
195
+ max_inp_length=8192,
196
+ system_prompt='',
197
+ stream=False,
198
+ max_slice_nums=None,
199
+ use_image_id=None,
200
+ **kwargs
201
+ ):
202
+ cut_flag = kwargs.get('cut_enable', True) # whether to split high-resolution images into crops
203
+ if processor is None:
204
+ if self.processor is None:
205
+ processor = self.init_processor(tokenizer)
206
+ else:
207
+ processor = self.processor
208
+ inputs = processor(messages, images=images, videos=videos, cut_enable=cut_flag)
209
+ inputs.to('cuda')
210
+ inputs.update({
211
+ 'tokenizer': tokenizer,
212
+ 'max_new_tokens': max_new_tokens,
213
+ # 'stream':True,
214
+ })
215
+ if sampling:
216
+ generation_config = {
217
+ "top_p": 0.8,
218
+ "top_k": 100,
219
+ "temperature": 0.7,
220
+ "do_sample": True,
221
+ # "repetition_penalty": 1.05
222
+ }
223
+ else:
224
+ generation_config = {
225
+ "num_beams": 3,
226
+ # "repetition_penalty": 1.2,
227
+ }
228
+
229
+ if min_new_tokens > 0:
230
+ generation_config['min_new_tokens'] = min_new_tokens
231
+
232
+ generation_config.update(
233
+ (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
234
+ )
235
+ with torch.inference_mode():
236
+ res = self.generate(
237
+ **inputs,
238
+ stream=stream,
239
+ decode_text=True,
240
+ **generation_config
241
+ )
242
+
243
+ if stream:
244
+ def stream_gen():
245
+ for text in res:
246
+ for term in self.terminators:
247
+ text = text.replace(term, '')
248
+ yield text
249
+ return stream_gen()
250
+
251
+ else:
252
+ answer = res[0]
253
+ return answer
254
+
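+ # Minimal usage sketch for `chat` (illustrative; the checkpoint id and image path below are
+ # assumptions for the example, not values defined in this file):
+ if __name__ == "__main__":
+     from transformers import AutoModel, AutoTokenizer
+ 
+     model_path = "mPLUG/mPLUG-Owl3-7B-241101"  # assumed repo id
+     model = AutoModel.from_pretrained(model_path, attn_implementation="sdpa", torch_dtype=torch.half, trust_remote_code=True)
+     model = model.eval().cuda()
+     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ 
+     messages = [
+         {"role": "user", "content": "<|image|> Describe this image."},
+         {"role": "assistant", "content": ""},
+     ]
+     answer = model.chat(
+         images=[Image.open("example.jpg").convert("RGB")],  # assumed local image
+         videos=None,
+         messages=messages,
+         tokenizer=tokenizer,
+         max_new_tokens=128,
+     )
+     print(answer)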
processing_mplugowl3.py ADDED
@@ -0,0 +1,397 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for mPLUGOwl3.
17
+ """
18
+
19
+ from typing import List, Optional, Union, Dict, Any
20
+ import warnings
21
+ import torch
22
+ import re
23
+
24
+ from transformers.image_processing_utils import BatchFeature
25
+ from transformers.image_utils import ImageInput
26
+ from transformers.processing_utils import ProcessorMixin
27
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
28
+ from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
29
+ from icecream import ic
30
+ from .image_processing_mplugowl3 import mPLUGOwl3BatchFeature, mPLUGOwl3ImageProcessor
31
+
32
+ OWL_MEDIA_TOKEN=['<|image|>']
33
+
34
+ class MediaIndicesHelper():
35
+ def __init__(self, tokenizer) -> None:
36
+ self.media_position = []
37
+ self.tokenizer = tokenizer
38
+
39
+
40
+ def has_media(self, text, media_tokens=None):
41
+ if media_tokens is None:
42
+ media_tokens = OWL_MEDIA_TOKEN
43
+ has_media_flag = any([media_token == text for media_token in media_tokens])
44
+ if any([media_token in text for media_token in media_tokens]):
45
+ # A chunk that contains a media token must contain nothing but that media token; mixed text/media chunks are not allowed
46
+ assert has_media_flag, text
47
+ return has_media_flag
48
+
49
+ def add_media(self, text_chunk, text=None, tokenize_fn=None):
50
+
51
+ # cross
52
+ assert tokenize_fn is not None
53
+ assert text is not None
54
+ assert text in OWL_MEDIA_TOKEN
55
+ media_token_ids = tokenize_fn(text)
56
+ start = len(text_chunk)
57
+ end = start + len(media_token_ids)
58
+ self.media_position.append([start, end])
59
+ text_chunk.extend(media_token_ids)
60
+ return len(media_token_ids)
61
+
62
+ def cal_media_offset(self, input_ids):
63
+ if len(self.media_position) == 0:
64
+ return torch.ones_like(input_ids)*(-1000000)
65
+
66
+ media_starts = torch.tensor([_[0] for _ in self.media_position]).reshape(1,-1)
67
+ rng = torch.arange(input_ids.shape[0]).reshape(-1,1)
68
+ matrix = (rng > media_starts).sum(dim=1)
69
+
70
+ return matrix
71
+
72
+ def len_images(self,):
73
+ return len(self.media_position)
74
+
75
+ class mPLUGOwl3Processor(ProcessorMixin):
76
+ r"""
77
+ Args:
78
+ image_processor ([`mPLUGOwl3ImageProcessor`], *optional*):
79
+ The image processor is a required input.
80
+ tokenizer ([`LlamaTokenizerWrapper`], *optional*):
81
+ The tokenizer is a required input.
82
+ """
83
+ attributes = ["image_processor", "tokenizer"]
84
+ image_processor_class = "AutoImageProcessor"
85
+ tokenizer_class = "AutoTokenizer"
86
+
87
+ def __init__(self, image_processor: mPLUGOwl3ImageProcessor = None, tokenizer=None, prompt_style='chatml', inference_mode=True, addition_eod="<|endoftext|>"):
88
+ super().__init__(image_processor, tokenizer)
89
+ self.image_processor: mPLUGOwl3ImageProcessor
90
+ self.prompt_style = prompt_style
91
+ self.inference_mode = inference_mode
92
+ self.media_tokens = ["<|image|>"]
93
+ self.addition_eod = addition_eod
94
+
95
+ def build_text_qwen(self, messages):
96
+ # role should be within ['system', 'user', 'assistant']
97
+ im_start, im_end = '<|im_start|>', '<|im_end|>'
98
+
99
+ text = []
100
+ for num_turn, message in enumerate(messages):
101
+ if num_turn == 0 and message['role'] != 'system':
102
+ if self.prompt_style != 'plain':
103
+ text.append({
104
+ "text": f"{im_start}system\n{im_end}",
105
+ "label": 0
106
+ })
107
+ if message['role'] == 'system':
108
+ if self.prompt_style != 'plain':
109
+ text.append({
110
+ "text": f"{im_start}system\n{message['content']}{im_end}",
111
+ "label": 0
112
+ })
113
+ elif message['role'] == 'user':
114
+ if self.prompt_style != 'plain':
115
+ content = f"\n{im_start}user\n{message['content']}{im_end}"
116
+ else:
117
+ content = message['content']
118
+ pattern = '|'.join(map(re.escape, self.media_tokens))
119
+ chunk_strs = re.split(f'({pattern})', content)
120
+ for chunk_str in chunk_strs:
121
+ text.append({
122
+ "text": chunk_str,
123
+ "label": 0
124
+ })
125
+
126
+ elif message['role'] == 'assistant':
127
+ if self.prompt_style != 'plain':
128
+ text.append({"text": f"\n{im_start}assistant\n", "label": 0})
129
+ text.append({"text": f"{message['content']}{im_end}", "label": 1})
130
+ else:
131
+ text.append({"text": f"{message['content']}", "label": 1})
132
+ text.append({"text": self.addition_eod, "label": 1})
133
+ else:
134
+ raise NotImplementedError
135
+ if self.inference_mode:
136
+ while text and text[-1]['label']==1: # while the list is non-empty and its last element is a label-1 (assistant) chunk
137
+ text.pop() # remove that last element, so inference keeps only the prompt part
138
+ return text
139
+
140
+ def wrapped_tokenize(self, text):
141
+ return self.tokenizer(text).input_ids
142
+
143
+     def encode_text_sft(self, texts):
+         # output enc_chunk
+
+         enc_chunk = []
+         label_chunk = []
+         enc_length = 0
+
+         num_images = 0
+
+         media_helper = MediaIndicesHelper(tokenizer=self.tokenizer)
+         for current_ti, text_chunk in enumerate(texts):
+
+             text = text_chunk["text"]
+             label = text_chunk["label"]
+
+             if not media_helper.has_media(text):
+                 # Plain text chunk: tokenize and append (handling is identical for label 0 and 1).
+                 curr_chunk = self.wrapped_tokenize(text)
+                 enc_length += len(curr_chunk)
+                 enc_chunk += curr_chunk
+                 label_chunk += [label] * len(curr_chunk)
+             # For media tokens
+             else:
+                 add_length = media_helper.add_media(
+                     enc_chunk,
+                     text=text,
+                     tokenize_fn=self.wrapped_tokenize)
+                 enc_length += add_length
+                 label_chunk += [label] * add_length
+                 # enc_chunk.extend([self.media_tokens[text]] * self.media_lengths[text])
+                 # enc_length += self.media_lengths[text]
+                 # label_chunk += [label] * self.media_lengths[text]
+                 num_images += 1
+
+         enc_chunk = torch.tensor(enc_chunk).long()
+         # media_offset = []
+         # media_before = 0
+         # for i, _ in enumerate([media_helper]):
+         #     mo = _.cal_media_offset(enc_chunk)
+         #     media_offset.append(torch.cat([(torch.ones(mo.shape[0], 1)*media_before).long().to(mo.device), (mo+media_before).unsqueeze(1)], dim=1))  # L 2
+
+         #     media_before += _.len_images()
+         # media_offset = torch.stack(media_offset, dim=0)
+         media_offset = [torch.tensor([_[0] for _ in media_helper.media_position]).long()]
+         return {
+             'input_ids': enc_chunk.unsqueeze(0),
+             'media_offset': media_offset,
+         }
+
+
+     def __call__(
+         self,
+         messages,
+         images=None,
+         videos=None,
+         max_length: Optional[int] = None,
+         cut_enable=True,
+         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+         **kwargs
+     ) -> mPLUGOwl3BatchFeature:
+         medias = []
+         if videos is not None:
+             medias.extend([{'type': 'video', 'content': video, 'use_video_span': True} for video in videos])
+         if images is not None:
+             medias.extend([{'type': 'image', 'content': image} for image in images])
+
+         if len(medias):
+             image_tensor_list = []
+             pattern = r"(<\|image\|>|<\|video\|>)"
+             # media are present
+             image_token_ptr = 0
+             media_layout = []
+             for message in messages:
+                 text_list = re.split(pattern, message['content'])
+                 text = ''
+                 for text_content in text_list:
+                     if text_content in ['<|image|>', '<|video|>']:
+                         media_item = medias[image_token_ptr]
+                         image_token_ptr += 1
+                         if text_content == '<|image|>':
+                             assert media_item['type'] == 'image'
+                             image = media_item['content']
+
+                             image_inputs = self.image_processor([image], cut_enable=cut_enable, return_tensors=return_tensors)
+                             if image_inputs.get('cut_shape', None) is not None:
+                                 cut_shape = image_inputs['cut_shape']
+                                 cut_text = self.image_processor.cut_prompt_template(img_token='<|image|>', h=cut_shape[0][0], w=cut_shape[0][1])
+                                 text += cut_text
+                                 image_tensor_list.append(image_inputs['pixel_values'])
+                             else:
+                                 text += text_content
+                                 image_tensor_list.append(image_inputs['pixel_values'])
+                         elif text_content == '<|video|>':
+                             assert media_item['type'] == 'video'
+                             video = media_item['content']
+                             use_video_span = media_item['use_video_span']
+                             image_tensor = self.image_processor(video, cut_enable=False)['pixel_values']
+                             image_tensor_list.append(image_tensor)
+                             num_video_frame = image_tensor.shape[0]
+                             if use_video_span:
+                                 text_content = '<|start_video_frame|>' + '<|image|>' * num_video_frame + '<|end_video_frame|>'
+                             else:
+                                 text_content = '<|image|>' * num_video_frame
+                             text += text_content
+                     else:
+                         text += text_content
+                 message['content'] = text
+             assert image_token_ptr == len(medias), (image_token_ptr, len(medias))  # the number of media must match the number of media tokens
+             assert all(len(_.shape) == 4 for _ in image_tensor_list), [_.shape for _ in image_tensor_list]
+             num_image_tokens = sum([_['content'].count('<|image|>') for _ in messages])
+             num_image_shapes = sum([_.shape[0] for _ in image_tensor_list])
+             assert num_image_tokens == num_image_shapes, (messages, [_.shape for _ in image_tensor_list])
+
+             image_tensor_list = torch.cat(image_tensor_list, dim=0)
+
+         text = self.build_text_qwen(messages)
+         model_inputs = self.encode_text_sft(text)
+
+         if len(medias):  # only attach pixel_values when media inputs were actually provided
+             model_inputs.update({'pixel_values': image_tensor_list})
+             # if 'cut_shape' in model_inputs:
+             #     model_inputs.pop('cut_shape')
+             # if 'cut_shape_indices' in model_inputs:
+             #     model_inputs.pop('cut_shape_indices')
+         return mPLUGOwl3BatchFeature(model_inputs)
+
+     def check_media(self, images, messages):
+         media_num = 0 if images is None else len(images)
+         media_count = sum([message['content'].count('<|image|>') for message in messages])
+         assert media_num == media_count
+
+
+     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
+     def batch_decode(self, *args, **kwargs):
+         """
+         Decode a batch of output id sequences. Padding (id 0), a leading BOS and a trailing EOS are stripped
+         before the remaining ids are passed to the tokenizer's [`~PreTrainedTokenizer.decode`].
+         """
+         output_ids = args[0]
+         result_text = []
+         for result in output_ids:
+             result = result[result != 0]
+             if result[0] == self.tokenizer.bos_id:
+                 result = result[1:]
+             if result[-1] == self.tokenizer.eos_id:
+                 result = result[:-1]
+             result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
+         return result_text
+         # return self.tokenizer.batch_decode(*args, **kwargs)
+
+     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
+     def decode(self, *args, **kwargs):
+         """
+         Decode a single output id sequence. Padding (id 0), a leading BOS and a trailing EOS/EOT are stripped
+         before the remaining ids are passed to the tokenizer's [`~PreTrainedTokenizer.decode`].
+         """
+         result = args[0]
+         result = result[result != 0]
+         if result[0] == self.tokenizer.bos_id:
+             result = result[1:]
+         if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id):
+             result = result[:-1]
+         return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
+
+     def _convert(
+         self, input_str, max_inp_length: Optional[int] = None
+     ):
+         if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
+             input_ids = self.tokenizer.encode(input_str)
+         else:
+             input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_str)
+         if max_inp_length is not None:
+             input_ids = input_ids[:max_inp_length]
+         input_ids = torch.tensor(input_ids, dtype=torch.int32)
+
+         start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
+         end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
+
+         image_start_tokens = torch.where(start_cond)[0]
+         image_start_tokens += 1
+         image_end_tokens = torch.where(end_cond)[0]
+
+         valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+
+         image_bounds = torch.hstack(
+             [
+                 image_start_tokens[:valid_image_nums].unsqueeze(-1),
+                 image_end_tokens[:valid_image_nums].unsqueeze(-1),
+             ]
+         )
+         return input_ids, image_bounds
+
+
+     @property
+     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+     def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
+         items = []
+         if isinstance(inputs[0], list):
+             assert isinstance(inputs[0][0], torch.Tensor)
+             for it in inputs:
+                 for tr in it:
+                     items.append(tr)
+         else:
+             assert isinstance(inputs[0], torch.Tensor)
+             items = inputs
+
+         batch_size = len(items)
+         shape = items[0].shape
+         dim = len(shape)
+         assert dim <= 2
+         if max_length is None:
+             max_length = 0
+         max_length = max(max_length, max(item.shape[-1] for item in items))
+         min_length = min(item.shape[-1] for item in items)
+         dtype = items[0].dtype
+
+         if dim == 0:
+             return torch.stack([item for item in items], dim=0), [0]
+         elif dim == 1:
+             if max_length == min_length:
+                 return torch.stack([item for item in items], dim=0), [0] * batch_size
+             tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
+         else:
+             tensor = (
+                 torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
+                 + padding_value
+             )
+
+         padding_length = []
+         for i, item in enumerate(items):
+             if dim == 1:
+                 if padding_side == "left":
+                     tensor[i, -len(item):] = item.clone()
+                 else:
+                     tensor[i, :len(item)] = item.clone()
+             elif dim == 2:
+                 if padding_side == "left":
+                     tensor[i, -len(item):, :] = item.clone()
+                 else:
+                     tensor[i, :len(item), :] = item.clone()
+             padding_length.append(tensor.shape[-1] - len(item))
+
+         return tensor, padding_length
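
For reference, a minimal usage sketch of the processor defined above, assuming it is exposed through `AutoProcessor` with `trust_remote_code=True`; the repo id and the image path are illustrative, and the processor can equally be constructed directly as `mPLUGOwl3Processor(image_processor=..., tokenizer=...)` from the classes in this file:

```python
# Minimal sketch (assumptions: repo id, AutoProcessor registration, image path).
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "mPLUG/mPLUG-Owl3-7B-241101", trust_remote_code=True  # repo id assumed
)

messages = [
    {"role": "user", "content": "<|image|>\nDescribe this picture."},
    {"role": "assistant", "content": ""},  # empty turn: inference_mode trims trailing assistant chunks
]
image = Image.open("example.jpg").convert("RGB")  # illustrative path

# The processor expands <|image|> into the cut template, builds the ChatML prompt
# via build_text_qwen, and packs input_ids, media_offset and pixel_values.
inputs = processor(messages, images=[image], videos=None)
print(inputs["input_ids"].shape, inputs["media_offset"][0], inputs["pixel_values"].shape)
```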
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
+   "bos_token": null,
+   "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
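
To see what the ChatML `chat_template` above renders, a small sketch using the standard `transformers` chat-template API (the repo id and example message are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mPLUG/mPLUG-Owl3-7B-241101")  # repo id assumed

messages = [{"role": "user", "content": "Hello"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Rendered by the template above:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
```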
vocab.json ADDED
The diff for this file is too large to render. See raw diff