pyx9913 commited on
Commit
aa60bbf
1 Parent(s): 4f1e38f

feat: 🎸 add chat model code

Browse files
README.md CHANGED
@@ -1,58 +1,45 @@
1
- ---
2
- language:
3
- - en
4
- - zh
5
- ---
6
- <div align="center">
7
-
8
- **VisCPM**
9
-
10
- **Chinese-English bilingual multi-modal large model series based on CPM (Chinese Pretrained Models) basic model**
11
 
12
  <p align="center">
13
- <a href="https://github.com/OpenBMB/VisCPM">Github</a> •
14
- <a href="https://huggingface.co/openbmb/VisCPM-Paint">VisCPM-Paint</a>
 
15
  </p>
16
 
17
- </div>
18
 
19
- `VisCPM` is a family of open-source large multimodal models, which support multimodal conversational capabilities (`VisCPM-Chat` model) and text-to-image generation capabilities (`VisCPM-Paint` model) in both Chinese and English, achieving state-of-the-art peformance among Chinese open-source multimodal models. VisCPM is trained based on the large language model [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) with 10B parameters, fusing visual encoder (Q-Former) and visual decoder (Diffusion-UNet) to support visual inputs and outputs. Thanks to the good bilingual capability of CPM-Bee, `VisCPM` can be pre-trained with English multimodal data only and well generalize to achieve promising Chinese multimodal capabilities.
20
-
21
- - **👐 Open-source Usage**: VisCPM is free to be used for personal and research purposes. By open-sourcing the VisCPM model family, we hope to promote the development of the open-source community of large multimodal models and related research.
22
- - **🌟 Image and text generation coverage**: VisCPM models provide relatively comprehensive support for image and text multimodal capabilities, covering both multimodal conversation (image-to-text generation) capabilities and text-to-image generation capabilities.
23
- - **💫 Excellent bilingual performance**: Thanks to the excellent bilingual capability of the base language model CPM-Bee, VisCPM achieves outstanding results in both bilingual multimodal conversation and text-to-image generation.
24
 
25
  ## VisCPM-Chat
26
- `VisCPM-Chat` supports bilingual multimodal conversations involving images in both Chinese and English. The model utilizes `Q-Former` as the visual encoder and `CPM-Bee` (10B) as the base LLM. It combines visual and language models and is optimized with the language modeling training objective. The model training consists of two stages: pretraining and instruction tuning.
27
 
28
- * Pretraining: `VisCPM-Chat` is pretrained using approximately 100M high-quality English text-image pairs. The data sources include CC3M, CC12M, COCO, Visual Genome, Laion, etc. In this stage, the language model parameters remain fixed, and only the parameters of the `Q-Former` are updated to enable efficient alignment of vision and language representations.
29
 
30
- * Instruction Tuning: We utilize the [LLaVA-150K](https://llava-vl.github.io/) dataset that contains English multimodal instruction-following data. We mix this data with corresponding translated Chinese data to fine-tune the model and align its multimodal capabilities with user intents. In this stage, we update all model parameters to improve the data efficiency of instruction tuning. Interestingly, we observe that even when using only English instruction data for fine-tuning, the model can well comprehend Chinese questions but can only respond in English. This indicates that the model has achieved good generalization in terms of its multilingual and multimodal capabilities. By incorporating a small amount of translated Chinese data during the instruction tuning stage, we can align the model's response language with the user's question language.
31
 
32
- We evaluate the model on the standard [LLaVA English test set](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and the translated [Chinese test set](data/translated_LLaVA_qa90) from the standard English test set. The evaluation benchmark examines the model's performance in conversation, detailed description, and complex reasoning, and uses GPT-4 for scoring. It can be observed that `VisCPM-Chat` achieves the best average performance in Chinese multimodal capabilities, excelling in conversation and complex reasoning, while also demonstrating good English multimodal capabilities. We provide two versions of the model, namely `VisCPM-Chat-balance` and `VisCPM-Chat-zhplus`. The former has a balanced ability in both English and Chinese, while the latter has a stronger emphasis on Chinese proficiency. Both models use the same data during the instruction tuning stage. `VisCPM-Chat-zhplus` additionally incorporates 20M cleaned native Chinese text-image pairs and 120M translated text-image pairs in Chinese during the pretraining stage.
33
 
34
  <table>
35
  <tr>
36
- <td align="center" rowspan="2" colspan="2">Model</td>
37
- <td align="center" rowspan="2">LLM Backbone</td>
38
- <td align="center" colspan="4">English</td>
39
- <td align="center" colspan="4">Chinese</td>
40
  </tr>
41
  <tr>
42
- <td align="center">Conversation</td>
43
- <td align="center">Detailed Description</td>
44
- <td align="center">Complex Reasoning</td>
45
- <td align="center">Avg</td>
46
- <td align="center">Conversation</td>
47
- <td align="center">Detailed Description</td>
48
- <td align="center">Complex Reasoning</td>
49
- <td align="center">Avg</td>
50
  </tr>
51
  <tr>
52
- <td align="center" rowspan="3">English Model</td>
53
  <td align="center">MiniGPT4</td>
54
- <td align="center">Vicuna-13B</td>
55
- <td align="center">65.0</td>
56
  <td align="center">67.3</td>
57
  <td align="center">76.6</td>
58
  <td align="center">69.7</td>
@@ -63,9 +50,8 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
63
  </tr>
64
  <tr>
65
  <td align="center">InstructBLIP</td>
66
- <td align="center">Vicuna-13B</td>
67
  <td align="center">81.9</td>
68
- <td align="center">68.0</td>
69
  <td align="center">91.2</td>
70
  <td align="center">80.5</td>
71
  <td align="center">-</td>
@@ -75,20 +61,18 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
75
  </tr>
76
  <tr>
77
  <td align="center">LLaVA</td>
78
- <td align="center">Vicuna-13B</td>
79
- <td align="center"><b>89.5</b></td>
80
- <td align="center"><b>70.4</b></td>
81
- <td align="center"><b>96.2</b></td>
82
- <td align="center"><b>85.6</b></td>
83
  <td align="center">-</td>
84
  <td align="center">-</td>
85
  <td align="center">-</td>
86
  <td align="center">-</td>
87
  </tr>
88
  <tr>
89
- <td align="center" rowspan="5">En-Zh Bilingual Model</td>
90
  <td align="center">mPLUG-Owl </td>
91
- <td align="center">LLaMA-7B</td>
92
  <td align="center">64.6</td>
93
  <td align="center">47.7</td>
94
  <td align="center">80.1</td>
@@ -96,61 +80,132 @@ We evaluate the model on the standard [LLaVA English test set](https://huggingfa
96
  <td align="center">76.3</td>
97
  <td align="center">61.2</td>
98
  <td align="center">77.8</td>
99
- <td align="center">72.0</td>
100
  </tr>
101
  <tr>
102
  <td align="center">VisualGLM</td>
103
- <td align="center">ChatGLM-6B</td>
104
  <td align="center">62.4</td>
105
- <td align="center">63.0</td>
106
  <td align="center">80.6</td>
107
  <td align="center">68.7</td>
108
  <td align="center">76.6</td>
109
- <td align="center"><b>87.8</b></td>
110
  <td align="center">83.6</td>
111
  <td align="center">82.7</td>
112
  </tr>
113
  <tr>
114
- <td align="center">Ziya-Visual </td>
115
- <td align="center">Ziya-LLaMA-13B-v1</td>
116
  <td align="center">82.7</td>
117
  <td align="center">69.9</td>
118
  <td align="center">92.1</td>
119
  <td align="center">81.7</td>
120
- <td align="center">85.0</td>
121
  <td align="center">74.7</td>
122
  <td align="center">82.4</td>
123
  <td align="center">80.8</td>
124
  </tr>
125
  <tr>
126
- <td align="center">VisCPM-Chat-balance</td>
127
- <td align="center">CPMBee-10B</td>
128
  <td align="center">83.3</td>
129
  <td align="center">68.9</td>
130
  <td align="center">90.5</td>
131
  <td align="center">81.1</td>
132
- <td align="center"><b>92.7</b></td>
133
  <td align="center">76.1</td>
134
  <td align="center">89.2</td>
135
  <td align="center">86.3</td>
136
  </tr>
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  <tr>
138
- <td align="center">VisCPM-Chat-zhplus</td>
139
- <td align="center">CPMBee-10B</td>
140
- <td align="center">80.1</td>
141
- <td align="center">65.7</td>
142
- <td align="center">92.5</td>
143
- <td align="center">79.6</td>
144
- <td align="center">90.3</td>
145
- <td align="center">81.4</td>
146
- <td align="center"><b>92.1</b></td>
147
- <td align="center"><b>88.2</b></td>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  </tr>
149
  </table>
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- ## 📝 License
 
 
153
 
154
- VisCPM is governed by the [GML License](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E9%9D%9E%E5%95%86%E4%B8%9A%E5%8C%96.md), and permits individual and research usages. If you intend to utilize the model for commercial purposes, please reach out to cpm@modelbest.cn to negotiate commercial licensing.
 
 
155
 
156
- The CPM-Bee base, governed by the [General Model License (GML)](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E5%95%86%E4%B8%9A%E6%8E%88%E6%9D%83.md), permits commercial usage. If you intend to utilize the model for commercial purposes, please reach out to cpm@modelbest.cn to obtain the certificate of authorization.
 
 
 
 
 
 
 
 
 
 
 
1
+ # VisCPM
2
+ 简体中文 | [English](README_en.md)
 
 
 
 
 
 
 
 
3
 
4
  <p align="center">
5
+ <p align="left">
6
+ <a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
7
+ <a href=""><img src="https://img.shields.io/badge/python-3.8+-aff.svg"></a>
8
  </p>
9
 
10
+ `VisCPM` is a family of open-source large multimodal models, which support multimodal conversational capabilities (`VisCPM-Chat` model) and text-to-image generation capabilities (`VisCPM-Paint` model) in both Chinese and English, achieving state-of-the-art peformance among Chinese open-source multimodal models. `VisCPM` is trained based on the large language model [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) with 10B parameters, fusing visual encoder (Q-Former) and visual decoder (Diffusion-UNet) to support visual inputs and outputs. Thanks to the good bilingual capability of CPM-Bee, `VisCPM` can be pre-trained with English multimodal data only and well generalize to achieve promising Chinese multimodal capabilities.
11
 
12
+ `VisCPM`是一个开源的多模态大模型系列,支持中英双语的多模态对话能力(`VisCPM-Chat`模型)和文到图生成能力(`VisCPM-Paint`模型),在中文多模态开源模型中达到最佳水平。`VisCPM`基于百亿参数量语言大模型[CPM-Bee](https://github.com/OpenBMB/CPM-Bee)10B)训练,融合视觉编码器(`Q-Former`)和视觉解码器(`Diffusion-UNet`)以支持视觉信号的输入和输出。得益于`CPM-Bee`底座优秀的双语能力,`VisCPM`可以仅通过英文多模态数据预训练,泛化实现优秀的中文多模态能力。
 
 
 
 
13
 
14
  ## VisCPM-Chat
15
+ `VisCPM-Chat`支持面向图像进行中英双语多模态对话。该模型使用`Q-Former`作为视觉编码器,使用CPM-Bee10B)作为语言交互基底模型,并通过语言建模训练目标融合视觉和语言模型。模型训练包括预训练和指令精调两阶段:
16
 
17
+ * 预训练:我们使用约100M高质量英文图文对数据对`VisCPM-Chat`进行了预训练,数据包括CC3MCC12MCOCOVisual GenomeLaion等。在预训练阶段,语言模型参数保持固定,仅更新`Q-Former`部分参数,以支持大规模视觉-语言表示的高效对齐。
18
 
19
+ * 指令精调:我们采用[LLaVA-150K](https://llava-vl.github.io/)英文指令精调数据,并混合相应翻译后的中文数据对模型进行指令精调,以对齐模型多模态基础能力和用户使用意图。在指令精调阶段,我们更新全部模型参数,以提升指令精调数据的利用效率。有趣的是,我们发现即使仅采用英文指令数据进行指令精调,模型也可以理解中文问题,但仅能用英文回答。这表明模型的多语言多模态能力已经得到良好的泛化。在指令精调阶段进一步加入少量中文翻译数据,可以将模型回复语言和用户问题语言对齐。
20
 
21
+ 我们在LLaVA英文测试集和翻译的中文测试集对模型进行了评测,该评测基准考察模型在开放域对话、图像细节描述、复杂推理方面的表现,并使用GPT-4进行打分。可以观察到,`VisCPM-Chat`在中文多模态能力方面取得了最佳的平均性能,在通用域对话和复杂推理表现出色,同时也表现出了不错的英文多模态能力。
22
 
23
  <table>
24
  <tr>
25
+ <td align="center" rowspan="2" colspan="2">模型</td>
26
+ <td align="center" colspan="4">英文</td>
27
+ <td align="center" colspan="4">中文</td>
 
28
  </tr>
29
  <tr>
30
+ <td align="center">多模态对话</td>
31
+ <td align="center">细节描述</td>
32
+ <td align="center">复杂推理</td>
33
+ <td align="center">平均</td>
34
+ <td align="center">多模态对话</td>
35
+ <td align="center">细节描述</td>
36
+ <td align="center">复杂推理</td>
37
+ <td align="center">平均</td>
38
  </tr>
39
  <tr>
40
+ <td align="center" rowspan="3">英文模型</td>
41
  <td align="center">MiniGPT4</td>
42
+ <td align="center">65</td>
 
43
  <td align="center">67.3</td>
44
  <td align="center">76.6</td>
45
  <td align="center">69.7</td>
 
50
  </tr>
51
  <tr>
52
  <td align="center">InstructBLIP</td>
 
53
  <td align="center">81.9</td>
54
+ <td align="center">68</td>
55
  <td align="center">91.2</td>
56
  <td align="center">80.5</td>
57
  <td align="center">-</td>
 
61
  </tr>
62
  <tr>
63
  <td align="center">LLaVA</td>
64
+ <td align="center">89.5</td>
65
+ <td align="center">70.4</td>
66
+ <td align="center">96.2</td>
67
+ <td align="center">85.6</td>
 
68
  <td align="center">-</td>
69
  <td align="center">-</td>
70
  <td align="center">-</td>
71
  <td align="center">-</td>
72
  </tr>
73
  <tr>
74
+ <td align="center" rowspan="4">中英双语</td>
75
  <td align="center">mPLUG-Owl </td>
 
76
  <td align="center">64.6</td>
77
  <td align="center">47.7</td>
78
  <td align="center">80.1</td>
 
80
  <td align="center">76.3</td>
81
  <td align="center">61.2</td>
82
  <td align="center">77.8</td>
83
+ <td align="center">72</td>
84
  </tr>
85
  <tr>
86
  <td align="center">VisualGLM</td>
 
87
  <td align="center">62.4</td>
88
+ <td align="center">63</td>
89
  <td align="center">80.6</td>
90
  <td align="center">68.7</td>
91
  <td align="center">76.6</td>
92
+ <td align="center">87.8</td>
93
  <td align="center">83.6</td>
94
  <td align="center">82.7</td>
95
  </tr>
96
  <tr>
97
+ <td align="center">Ziya (LLaMA 13B)</td>
 
98
  <td align="center">82.7</td>
99
  <td align="center">69.9</td>
100
  <td align="center">92.1</td>
101
  <td align="center">81.7</td>
102
+ <td align="center">85</td>
103
  <td align="center">74.7</td>
104
  <td align="center">82.4</td>
105
  <td align="center">80.8</td>
106
  </tr>
107
  <tr>
108
+ <td align="center">VisCPM-Chat</td>
 
109
  <td align="center">83.3</td>
110
  <td align="center">68.9</td>
111
  <td align="center">90.5</td>
112
  <td align="center">81.1</td>
113
+ <td align="center">92.7</td>
114
  <td align="center">76.1</td>
115
  <td align="center">89.2</td>
116
  <td align="center">86.3</td>
117
  </tr>
118
+ </table>
119
+
120
+ ## VisCPM-Paint
121
+ `VisCPM-Paint`支持中英双语的文到图生成。该模型使用CPM-Bee(10B)作为文本编码器,使用`UNet`作为图像解码器,并通过扩散模型训练目标融合语言和视觉模型。在训练过程中,语言模型参数始终保持固定。我们使用[Stable Diffusion 2.1](https://github.com/Stability-AI/stablediffusion)的UNet参数初始化视觉解码器,并通过逐步解冻其中关键的桥接参数将其与语言模型融合:首先训练文本表示映射到视觉模型的线性层,然后进一步解冻`UNet`的交叉注意力层。该模型在[LAION 2B](https://laion.ai/)英文图文对数据上进行了训练。
122
+
123
+ 与`VisCPM-Chat`类似,我们发现得益于CPM-Bee的双语能力,`VisCPM-Paint`可以仅通过英文图文对训练,泛化实现良好的中文文到图生成能力,达到中文开源模型的最佳效果。通过进一步加入20M清洗后的原生中文图文对数据,以及120M翻译到中文的图文对数据,模型的中文文到图生成能力可以获得进一步提升。我们在MSCOCO上采样了3万张图片,计算了FID(Fréchet Inception Distance)和Clip Score,前者用于评估生成图片的质量,后面用于评估生成的图片与输入的匹配程度。
124
+
125
+ <table>
126
+ <tr>
127
+ <td align="center" rowspan="2">模型</td>
128
+ <td align="center" colspan="2">英文</td>
129
+ <td align="center" colspan="2">中文</td>
130
+ </tr>
131
  <tr>
132
+ <td align="center">FID↓</td>
133
+ <td align="center">CLIP Score↑</td>
134
+ <td align="center">FID↓</td>
135
+ <td align="center">CLIP Score↑</td>
136
+ </tr>
137
+ <tr>
138
+ <td align="center">AltDiffusion</td>
139
+ <td align="center">17.16</td>
140
+ <td align="center">25.24</td>
141
+ <td align="center">16.09</td>
142
+ <td align="center">24.05</td>
143
+ </tr>
144
+ <tr>
145
+ <td align="center">TaiyiDiffusion</td>
146
+ <td align="center">-</td>
147
+ <td align="center">-</td>
148
+ <td align="center">15.58</td>
149
+ <td align="center">22.69</td>
150
+ </tr>
151
+ <tr>
152
+ <td align="center">Stable Diffusion</td>
153
+ <td align="center">9.08</td>
154
+ <td align="center">26.22</td>
155
+ <td align="center">-</td>
156
+ <td align="center">-</td>
157
+ </tr>
158
+ <tr>
159
+ <td align="center">VisCPM-Paint-en</td>
160
+ <td align="center">9.51</td>
161
+ <td align="center">25.35</td>
162
+ <td align="center">10.86</td>
163
+ <td align="center">23.38</td>
164
+ </tr>
165
+ <tr>
166
+ <td align="center">VisCPM-Paint-zh</td>
167
+ <td align="center">9.98</td>
168
+ <td align="center">25.04</td>
169
+ <td align="center">9.65</td>
170
+ <td align="center">24.17</td>
171
  </tr>
172
  </table>
173
 
174
+ # 安装
175
+
176
+ ```Shell
177
+ conda create -n viscpm python=3.10 -y
178
+ conda activate viscpm
179
+ pip install setuptools
180
+ pip install diffusers jieba matplotlib numpy opencv_python
181
+ pip install pandas Pillow psutil pydantic scipy
182
+ pip install torch==1.13.1 torchscale==0.2.0 torchvision==0.14.1 timm
183
+ pip install transformers==4.28.0
184
+ pip install tqdm typing_extensions
185
+ pip install git+https://github.com/thunlp/OpenDelta.git
186
+ pip install git+https://github.com/OpenBMB/CPM-Bee.git#egg=cpm-live&subdirectory=src
187
+ ```
188
+
189
+ VisCPM需要单卡40GB以上的GPU运行,我们会在尽快更新更加节省显存的推理方式。
190
+
191
+ ## 使用
192
 
193
+ ```python
194
+ >>> from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
195
+ >>> from PIL import Image
196
 
197
+ >>> tokenizer = AutoTokenizer.from_pretrained('viscpm', trust_remote_code=True)
198
+ >>> processor = AutoImageProcessor.from_pretrained('viscpm', trust_remote_code=True)
199
+ >>> model = AutoModel.from_pretrained('viscpm', trust_remote_code=True).to('cuda')
200
 
201
+ >>> data = [{
202
+ >>> 'context': '',
203
+ >>> 'question': 'describe this image in detail.',
204
+ >>> 'image': tokenizer.unk_token * model.query_num,
205
+ >>> '<ans>': ''
206
+ >>> }]
207
+ >>> image = Image.open('case.jpg')
208
+ >>> result = model.generate(data, tokenizer, processor, image)
209
+ >>> print(result[0]['<ans>'])
210
+ 这幅图片显示了一群热气球在天空中飞行。这些热气球漂浮在不同的地方,包括山脉、城市和乡村地区。
211
+ ```
README_en.md ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VisCPM
2
+ [简体中文](README.md) | English
3
+
4
+ <p align="center">
5
+ <p align="left">
6
+ <a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
7
+ <a href=""><img src="https://img.shields.io/badge/python-3.8+-aff.svg"></a>
8
+ </p>
9
+
10
+ `VisCPM` is a family of open-source large multimodal models, which support multimodal conversational capabilities (`VisCPM-Chat` model) and text-to-image generation capabilities (`VisCPM-Paint` model) in both Chinese and English, achieving state-of-the-art peformance among Chinese open-source multimodal models. `VisCPM` is trained based on the large language model [CPM-Bee](https://github.com/OpenBMB/CPM-Bee) with 10B parameters, fusing visual encoder (`Q-Former`) and visual decoder (`Diffusion-UNet`) to support visual inputs and outputs. Thanks to the good bilingual capability of `CPM-Bee`, `VisCPM` can be pre-trained with English multimodal data only and well generalize to achieve promising Chinese multimodal capabilities.
11
+
12
+ ## VisCPM-Chat
13
+ `VisCPM-Chat` supports bilingual multimodal conversations involving images in both Chinese and English. The model utilizes `Q-Former` as the visual encoder and CPM-Bee (10B) as the base LLM. It combines visual and language models through language modeling training objectives. The model training consists of two stages: pretraining and instruction fine-tuning.
14
+
15
+ * Pretrain: `VisCPM-Chat` was pretrained using approximately 100 million high-quality English multimodal data pairs. The data sources include CC3M, CC12M, COCO, Visual Genome, Laion, and others. In this stage, the language model parameters remain fixed, and only the parameters of the `Q-Former` are updated to enable efficient alignment of large-scale visual-language representations.
16
+
17
+ * Instruction fine-tuning: We utilized the [LLaVA-150K](https://llava-vl.github.io/) dataset, which consists of English multimodal instruction-following dataset. We mixed this data with corresponding translated Chinese data to fine-tune the model and align its multimodal capabilities with user intents. In this phase, we updated all model parameters to improve the utilization efficiency of the instruction fine-tuning data. Interestingly, we observed that even when using only English instruction data for fine-tuning, the model can comprehend Chinese questions but can only respond in English. This indicates that the model has achieved good generalization in terms of its multilingual and multimodal capabilities. By incorporating a small amount of translated Chinese data during the instruction fine-tuning phase, we can align the model's response language with the user's question language.
18
+
19
+ We evaluated the model on the LLaVA English test set and the translated Chinese test set. The evaluation benchmark examined the model's performance in open-domain conversations, image detail descriptions, and complex reasoning tasks, using GPT-4 for scoring. It is evident that `VisCPM-Chat` achieved the best average performance in Chinese multimodal capabilities, excelling in general-domain conversations and complex reasoning. It also demonstrated commendable English multimodal abilities.
20
+
21
+ <table>
22
+ <tr>
23
+ <td align="center" rowspan="2" colspan="2">Model</td>
24
+ <td align="center" colspan="4">English</td>
25
+ <td align="center" colspan="4">Chinese</td>
26
+ </tr>
27
+ <tr>
28
+ <td align="center">Conversation</td>
29
+ <td align="center">Detailed Description</td>
30
+ <td align="center">Complex Reasoning</td>
31
+ <td align="center">All</td>
32
+ <td align="center">Conversation</td>
33
+ <td align="center">Detailed Description</td>
34
+ <td align="center">Complex Reasoning</td>
35
+ <td align="center">All</td>
36
+ </tr>
37
+ <tr>
38
+ <td align="center" rowspan="3">English Model</td>
39
+ <td align="center">MiniGPT4</td>
40
+ <td align="center">65</td>
41
+ <td align="center">67.3</td>
42
+ <td align="center">76.6</td>
43
+ <td align="center">69.7</td>
44
+ <td align="center">-</td>
45
+ <td align="center">-</td>
46
+ <td align="center">-</td>
47
+ <td align="center">-</td>
48
+ </tr>
49
+ <tr>
50
+ <td align="center">InstructBLIP</td>
51
+ <td align="center">81.9</td>
52
+ <td align="center">68</td>
53
+ <td align="center">91.2</td>
54
+ <td align="center">80.5</td>
55
+ <td align="center">-</td>
56
+ <td align="center">-</td>
57
+ <td align="center">-</td>
58
+ <td align="center">-</td>
59
+ </tr>
60
+ <tr>
61
+ <td align="center">LLaVA</td>
62
+ <td align="center">89.5</td>
63
+ <td align="center">70.4</td>
64
+ <td align="center">96.2</td>
65
+ <td align="center">85.6</td>
66
+ <td align="center">-</td>
67
+ <td align="center">-</td>
68
+ <td align="center">-</td>
69
+ <td align="center">-</td>
70
+ </tr>
71
+ <tr>
72
+ <td align="center" rowspan="4">En-Zh Bilingual Model</td>
73
+ <td align="center">mPLUG-Owl </td>
74
+ <td align="center">64.6</td>
75
+ <td align="center">47.7</td>
76
+ <td align="center">80.1</td>
77
+ <td align="center">64.2</td>
78
+ <td align="center">76.3</td>
79
+ <td align="center">61.2</td>
80
+ <td align="center">77.8</td>
81
+ <td align="center">72</td>
82
+ </tr>
83
+ <tr>
84
+ <td align="center">VisualGLM</td>
85
+ <td align="center">62.4</td>
86
+ <td align="center">63</td>
87
+ <td align="center">80.6</td>
88
+ <td align="center">68.7</td>
89
+ <td align="center">76.6</td>
90
+ <td align="center">87.8</td>
91
+ <td align="center">83.6</td>
92
+ <td align="center">82.7</td>
93
+ </tr>
94
+ <tr>
95
+ <td align="center">Ziya (LLaMA 13B)</td>
96
+ <td align="center">82.7</td>
97
+ <td align="center">69.9</td>
98
+ <td align="center">92.1</td>
99
+ <td align="center">81.7</td>
100
+ <td align="center">85</td>
101
+ <td align="center">74.7</td>
102
+ <td align="center">82.4</td>
103
+ <td align="center">80.8</td>
104
+ </tr>
105
+ <tr>
106
+ <td align="center">VisCPM-Chat</td>
107
+ <td align="center">83.3</td>
108
+ <td align="center">68.9</td>
109
+ <td align="center">90.5</td>
110
+ <td align="center">81.1</td>
111
+ <td align="center">92.7</td>
112
+ <td align="center">76.1</td>
113
+ <td align="center">89.2</td>
114
+ <td align="center">86.3</td>
115
+ </tr>
116
+ </table>
117
+
118
+ # Install
119
+
120
+ 1. Clone this repository and navigate to source folder
121
+ ```bash
122
+ git clone <github repo URL>
123
+ cd viscpm
124
+ ```
125
+
126
+ 2. Install Package
127
+ ```Shell
128
+ conda create -n viscpm python=3.10 -y
129
+ conda activate viscpm
130
+ pip install setuptools
131
+ pip install diffusers jieba matplotlib numpy opencv_python
132
+ pip install pandas Pillow psutil pydantic scipy
133
+ pip install torch==1.13.1 torchscale==0.2.0 torchvision==0.14.1 timm
134
+ pip install transformers==4.28.0
135
+ pip install tqdm typing_extensions
136
+ pip install git+https://github.com/thunlp/OpenDelta.git
137
+ pip install git+https://github.com/OpenBMB/CPM-Bee.git#egg=cpm-live&subdirectory=src
138
+ ```
139
+
140
+ `VisCPM` require GPUs with more than 40GB memory. We will soon update more memory-friendly inference methods.
141
+
142
+ ## How to use
143
+
144
+ ```python
145
+ >>> from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
146
+ >>> from PIL import Image
147
+
148
+ >>> tokenizer = AutoTokenizer.from_pretrained('viscpm', trust_remote_code=True)
149
+ >>> processor = AutoImageProcessor.from_pretrained('viscpm', trust_remote_code=True)
150
+ >>> model = AutoModel.from_pretrained('viscpm', trust_remote_code=True).to('cuda')
151
+
152
+ >>> data = [{
153
+ >>> 'context': '',
154
+ >>> 'question': 'describe this image in detail.',
155
+ >>> 'image': tokenizer.unk_token * model.query_num,
156
+ >>> '<ans>': ''
157
+ >>> }]
158
+ >>> image = Image.open('case.jpg')
159
+ >>> result = model.generate(data, tokenizer, processor, image)
160
+ >>> print(result[0]['<ans>'])
161
+ 这幅图片显示了一群热气球在天空中飞行。这些热气球漂浮在不同的地方,包括山脉、城市和乡村地区。
162
+ ```
beit3.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
12
+ from timm.models.registry import register_model
13
+
14
+ from torchscale.model.BEiT3 import BEiT3
15
+ from torchscale.architecture.config import EncoderConfig
16
+
17
+
18
+ def trunc_normal_(tensor, mean=0., std=1.):
19
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
20
+
21
+
22
+ def _get_base_config(
23
+ img_size=224, patch_size=16, drop_path_rate=0,
24
+ checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
25
+ ):
26
+ return EncoderConfig(
27
+ img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
28
+ layernorm_embedding=False, normalize_output=True, no_output_layer=True,
29
+ drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12,
30
+ encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12,
31
+ checkpoint_activations=checkpoint_activations,
32
+ )
33
+
34
+
35
+ def _get_large_config(
36
+ img_size=224, patch_size=16, drop_path_rate=0,
37
+ checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
38
+ ):
39
+ return EncoderConfig(
40
+ img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
41
+ layernorm_embedding=False, normalize_output=True, no_output_layer=True,
42
+ drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16,
43
+ encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24,
44
+ checkpoint_activations=checkpoint_activations,
45
+ )
46
+
47
+
48
+ class BEiT3Wrapper(nn.Module):
49
+ def __init__(self, args, **kwargs):
50
+ super().__init__()
51
+ self.args = args
52
+ self.beit3 = BEiT3(args)
53
+ self.apply(self._init_weights)
54
+ self.mim_head = nn.Linear(1024, 8192)
55
+ self.num_img_patches = self.beit3.vision_embed.num_position_embeddings()
56
+ self.hidden_size = args.encoder_embed_dim
57
+
58
+ def fix_init_weight(self):
59
+ def rescale(param, layer_id):
60
+ param.div_(math.sqrt(2.0 * layer_id))
61
+
62
+ for layer_id, layer in enumerate(self.blocks):
63
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
64
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
65
+
66
+ def get_num_layers(self):
67
+ return self.beit3.encoder.num_layers
68
+
69
+ @torch.jit.ignore
70
+ def no_weight_decay(self):
71
+ return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}
72
+
73
+ def _init_weights(self, m):
74
+ if isinstance(m, nn.Linear):
75
+ trunc_normal_(m.weight, std=.02)
76
+ if isinstance(m, nn.Linear) and m.bias is not None:
77
+ nn.init.constant_(m.bias, 0)
78
+ elif isinstance(m, nn.LayerNorm):
79
+ nn.init.constant_(m.bias, 0)
80
+ nn.init.constant_(m.weight, 1.0)
81
+
82
+ def forward(self, pixel_values, query_embed=None):
83
+ B = pixel_values.size(0)
84
+ dtype = self.beit3.vision_embed.proj.weight.dtype
85
+ pixel_values = pixel_values.to(dtype)
86
+ token_embeddings = self.beit3.vision_embed(pixel_values)
87
+ multiway_split_position = -1
88
+ if query_embed is not None:
89
+ query_embed = torch.stack([query_embed] * B)
90
+ multiway_split_position = token_embeddings.size(1)
91
+ token_embeddings = torch.cat([token_embeddings, query_embed], dim=1)
92
+
93
+ outputs = self.beit3.encoder(
94
+ src_tokens=None,
95
+ token_embeddings=token_embeddings,
96
+ multiway_split_position=multiway_split_position
97
+ )
98
+ vision_hidden_states = outputs["encoder_out"]
99
+ if query_embed is not None:
100
+ vision_hidden_states = vision_hidden_states[:, self.num_img_patches:]
101
+ return vision_hidden_states
102
+
103
+
104
+ @register_model
105
+ def beit3_large_patch16_224(pretrained=False, **kwargs):
106
+ args = _get_large_config(img_size=224, **kwargs)
107
+ model = BEiT3Wrapper(args, **kwargs)
108
+ return model
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "_name_or_path": "openbmb/viscpmchat-bee-10b",
4
+ "architectures": [
5
+ "VisCpmBeeForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_viscpmchatbee.VisCpmChatBeeConfig",
9
+ "AutoModel": "modeling_cpmbee.VisCpmBeeForCausalLM",
10
+ "AutoModelForCausalLM": "modeling_cpmbee.VisCpmBeeForCausalLM"
11
+ },
12
+ "vocab_size": 86583,
13
+ "hidden_size": 4096,
14
+ "dim_ff" : 10240,
15
+ "num_hidden_layers" : 48,
16
+ "num_attention_heads": 32,
17
+ "dim_head" : 128,
18
+ "dropout_p" : 0.0,
19
+ "position_bias_num_buckets" : 256,
20
+ "position_bias_num_segment_buckets": 256,
21
+ "position_bias_max_distance" : 2048,
22
+ "vision_dim": 1024,
23
+ "query_num": 64,
24
+ "eps" : 1e-6,
25
+ "half" : true,
26
+ "model_type": "viscpmchatbee"
27
+ }
configuration_viscpmchatbee.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ CpmBee model configuration"""
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ CPMBEE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/resolve/main/config.json",
27
+ # See all VisCpmBee models at https://huggingface.co/models?filter=viscpmbee
28
+ }
29
+
30
+
31
+ class VisCpmChatBeeConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`CpmBeeModel`]. It is used to instbeeiate an
34
+ CPMBee model according to the specified arguments, defining the model architecture. Instantiating a configuration
35
+ with the defaults will yield a similar configuration to that of the CPMBee
36
+ [openbmb/cpm-bee-10b](https://huggingface.co/openbmb/cpm-bee-10b) architecture.
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 30720):
43
+ Vocabulary size of the CPMBee model. Defines the number of different tokens that can be represented by the
44
+ `input` passed when calling [`CpmBeeModel`].
45
+ hidden_size (`int`, *optional*, defaults to 4096):
46
+ Dimension of the encoder layers.
47
+ num_attention_heads (`int`, *optional*, defaults to 32):
48
+ Number of attention heads in the Transformer encoder.
49
+ dim_head (`int`, *optional*, defaults to 128):
50
+ Dimension of attention heads for each attention layer in the Transformer encoder.
51
+ dim_ff (`int`, *optional*, defaults to 10240):
52
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
53
+ num_hidden_layers (`int`, *optional*, defaults to 48):
54
+ Number of layers of the Transformer encoder.
55
+ dropout_p (`float`, *optional*, defaults to 0.1):
56
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder.
57
+ position_bias_num_buckets (`int`, *optional*, defaults to 512):
58
+ The number of position_bias buckets.
59
+ position_bias_num_segment_buckets (`int`, *optional*, defaults to 32):
60
+ The number of segment buckets.
61
+ position_bias_max_distance (`int`, *optional*, defaults to 2048):
62
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
63
+ just in case (e.g., 512 or 1024 or 2048).
64
+ eps (`float`, *optional*, defaults to 1e-6):
65
+ The epsilon used by the layer normalization layers.
66
+ init_std (`float`, *optional*, defaults to 1.0):
67
+ Initialize parameters with std = init_std.
68
+ use_cache (`bool`, *optional*, defaults to `True`):
69
+ Whether to use cache.
70
+ distance_scale (`float` or `int`, *optional*, defaults to 16):
71
+ Scale the rotary embedding.
72
+ mask_modules (`list` or `tuple`, *optional*, defaults to None):
73
+ Decides which feedforward block or attention block is pruned.
74
+ half (`bool`, *optional*, defaults to `False`):
75
+ Decides the model parameters are half-precision or not.
76
+
77
+ Example:
78
+
79
+ ```python
80
+ >>> from transformers import CpmBeeModel, CpmBeeConfig
81
+
82
+ >>> # Initializing a CPMBee cpm-bee-10b style configuration
83
+ >>> configuration = CpmBeeConfig()
84
+
85
+ >>> # Initializing a model from the cpm-bee-10b style configuration
86
+ >>> model = CpmBeeModel(configuration)
87
+
88
+ >>> # Accessing the model configuration
89
+ >>> configuration = model.config
90
+ ```"""
91
+ model_type = "viscpmchatbee"
92
+
93
+ def __init__(
94
+ self,
95
+ vocab_size: int = 30720,
96
+ hidden_size: int = 4096,
97
+ num_attention_heads: int = 64,
98
+ dim_head: int = 64,
99
+ dim_ff: int = 10240,
100
+ num_hidden_layers: int = 32,
101
+ dropout_p: int = 0.0,
102
+ position_bias_num_buckets: int = 256,
103
+ position_bias_num_segment_buckets: int = 32,
104
+ position_bias_max_distance: int = 2048,
105
+ eps: int = 1e-6,
106
+ init_std: float = 1.0,
107
+ use_cache: bool = True,
108
+ distance_scale: Union[int, float] = 16,
109
+ mask_modules: Optional[Union[List, Tuple]] = None,
110
+ half: bool = False,
111
+ vision_dim: int = 1024,
112
+ query_num: int = 64,
113
+ **kwargs,
114
+ ):
115
+ super().__init__(**kwargs)
116
+ self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
117
+ self.hidden_size = hidden_size
118
+ self.num_attention_heads = num_attention_heads
119
+ self.dim_head = dim_head
120
+ self.dim_ff = dim_ff
121
+ self.num_hidden_layers = num_hidden_layers
122
+ self.position_bias_num_buckets = position_bias_num_buckets
123
+ self.position_bias_max_distance = position_bias_max_distance
124
+ self.dropout_p = dropout_p
125
+ self.eps = eps
126
+ self.use_cache = use_cache
127
+ self.vocab_size = vocab_size
128
+ self.init_std = init_std
129
+ self.distance_scale = distance_scale
130
+ self.half = half
131
+ self.mask_modules = mask_modules
132
+ self.vision_dim = vision_dim
133
+ self.query_num = query_num
feature_extraction_viscpmchatbee.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ from transformers.utils import logging
4
+ from processing_viscpmchatbee import VisCpmChatBeeImageProcessor
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class VisCpmChatBeeFeatureExtractor(VisCpmChatBeeImageProcessor):
11
+ def __init__(self, *args, **kwargs) -> None:
12
+ warnings.warn(
13
+ "The class VisCpmBeeFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
14
+ " use CLIPImageProcessor instead.",
15
+ FutureWarning,
16
+ )
17
+ super().__init__(*args, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_beams": 3,
3
+ "num_beam_groups": 1,
4
+ "do_sample": false,
5
+ "is_constraint_gen_mode": false,
6
+ "is_contrastive_search_gen_mode": false,
7
+ "pad_token_id": 0,
8
+ "eos_token_id": 7,
9
+ "bos_token_id": 6,
10
+ "max_new_tokens": 100,
11
+ "vocab_size": 86583
12
+ }
modeling_cpmbee.py ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "image_processor_type": "VisCpmChatBeeImageProcessor",
3
+ "is_train": false,
4
+ "randaug": false,
5
+ "input_size": 224,
6
+ "interpolation": "bicubic",
7
+ "auto_map": {
8
+ "AutoImageProcessor": "processing_viscpmchatbee.VisCpmChatBeeImageProcessor"
9
+ }
10
+ }
processing_viscpmchatbee.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
5
+ from timm.data.transforms import RandomResizedCropAndInterpolation
6
+ from torchvision import transforms
7
+ import urllib
8
+ from tqdm import tqdm
9
+ from cpm_live.tokenizers import CPMBeeTokenizer
10
+ from torch.utils.data import default_collate
11
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
+ from typing_extensions import TypedDict
13
+ from numpy.typing import NDArray
14
+ import importlib.machinery
15
+ import importlib.util
16
+ import types
17
+ import random
18
+ from transformers.image_utils import make_list_of_images
19
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
20
+ from transformers import TensorType
21
+ import json
22
+
23
+
24
+ # aug functions
25
+ def identity_func(img):
26
+ return img
27
+
28
+
29
+ def autocontrast_func(img, cutoff=0):
30
+ '''
31
+ same output as PIL.ImageOps.autocontrast
32
+ '''
33
+ n_bins = 256
34
+
35
+ def tune_channel(ch):
36
+ n = ch.size
37
+ cut = cutoff * n // 100
38
+ if cut == 0:
39
+ high, low = ch.max(), ch.min()
40
+ else:
41
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
42
+ low = np.argwhere(np.cumsum(hist) > cut)
43
+ low = 0 if low.shape[0] == 0 else low[0]
44
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
45
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
46
+ if high <= low:
47
+ table = np.arange(n_bins)
48
+ else:
49
+ scale = (n_bins - 1) / (high - low)
50
+ table = np.arange(n_bins) * scale - low * scale
51
+ table[table < 0] = 0
52
+ table[table > n_bins - 1] = n_bins - 1
53
+ table = table.clip(0, 255).astype(np.uint8)
54
+ return table[ch]
55
+
56
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
57
+ out = cv2.merge(channels)
58
+ return out
59
+
60
+
61
+ def equalize_func(img):
62
+ '''
63
+ same output as PIL.ImageOps.equalize
64
+ PIL's implementation is different from cv2.equalize
65
+ '''
66
+ n_bins = 256
67
+
68
+ def tune_channel(ch):
69
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
70
+ non_zero_hist = hist[hist != 0].reshape(-1)
71
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
72
+ if step == 0:
73
+ return ch
74
+ n = np.empty_like(hist)
75
+ n[0] = step // 2
76
+ n[1:] = hist[:-1]
77
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
78
+ return table[ch]
79
+
80
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
81
+ out = cv2.merge(channels)
82
+ return out
83
+
84
+
85
+ def rotate_func(img, degree, fill=(0, 0, 0)):
86
+ '''
87
+ like PIL, rotate by degree, not radians
88
+ '''
89
+ H, W = img.shape[0], img.shape[1]
90
+ center = W / 2, H / 2
91
+ M = cv2.getRotationMatrix2D(center, degree, 1)
92
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
93
+ return out
94
+
95
+
96
+ def solarize_func(img, thresh=128):
97
+ '''
98
+ same output as PIL.ImageOps.posterize
99
+ '''
100
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
101
+ table = table.clip(0, 255).astype(np.uint8)
102
+ out = table[img]
103
+ return out
104
+
105
+
106
+ def color_func(img, factor):
107
+ '''
108
+ same output as PIL.ImageEnhance.Color
109
+ '''
110
+ # implementation according to PIL definition, quite slow
111
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
112
+ # out = blend(degenerate, img, factor)
113
+ # M = (
114
+ # np.eye(3) * factor
115
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
116
+ # )[np.newaxis, np.newaxis, :]
117
+ M = (
118
+ np.float32([
119
+ [0.886, -0.114, -0.114],
120
+ [-0.587, 0.413, -0.587],
121
+ [-0.299, -0.299, 0.701]]) * factor
122
+ + np.float32([[0.114], [0.587], [0.299]])
123
+ )
124
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
125
+ return out
126
+
127
+
128
+ def contrast_func(img, factor):
129
+ """
130
+ same output as PIL.ImageEnhance.Contrast
131
+ """
132
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
133
+ table = np.array([(
134
+ el - mean) * factor + mean
135
+ for el in range(256)
136
+ ]).clip(0, 255).astype(np.uint8)
137
+ out = table[img]
138
+ return out
139
+
140
+
141
+ def brightness_func(img, factor):
142
+ '''
143
+ same output as PIL.ImageEnhance.Contrast
144
+ '''
145
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
146
+ out = table[img]
147
+ return out
148
+
149
+
150
+ def sharpness_func(img, factor):
151
+ '''
152
+ The differences the this result and PIL are all on the 4 boundaries, the center
153
+ areas are same
154
+ '''
155
+ kernel = np.ones((3, 3), dtype=np.float32)
156
+ kernel[1][1] = 5
157
+ kernel /= 13
158
+ degenerate = cv2.filter2D(img, -1, kernel)
159
+ if factor == 0.0:
160
+ out = degenerate
161
+ elif factor == 1.0:
162
+ out = img
163
+ else:
164
+ out = img.astype(np.float32)
165
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
166
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
167
+ out = out.astype(np.uint8)
168
+ return out
169
+
170
+
171
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
172
+ H, W = img.shape[0], img.shape[1]
173
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
174
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
175
+ return out
176
+
177
+
178
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
179
+ '''
180
+ same output as PIL.Image.transform
181
+ '''
182
+ H, W = img.shape[0], img.shape[1]
183
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
184
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
185
+ return out
186
+
187
+
188
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
189
+ '''
190
+ same output as PIL.Image.transform
191
+ '''
192
+ H, W = img.shape[0], img.shape[1]
193
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
194
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
195
+ return out
196
+
197
+
198
+ def posterize_func(img, bits):
199
+ '''
200
+ same output as PIL.ImageOps.posterize
201
+ '''
202
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
203
+ return out
204
+
205
+
206
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
207
+ H, W = img.shape[0], img.shape[1]
208
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
209
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
210
+ return out
211
+
212
+
213
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
214
+ replace = np.array(replace, dtype=np.uint8)
215
+ H, W = img.shape[0], img.shape[1]
216
+ rh, rw = np.random.random(2)
217
+ pad_size = pad_size // 2
218
+ ch, cw = int(rh * H), int(rw * W)
219
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
220
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
221
+ out = img.copy()
222
+ out[x1:x2, y1:y2, :] = replace
223
+ return out
224
+
225
+
226
+ # level to args
227
+ def enhance_level_to_args(MAX_LEVEL):
228
+ def level_to_args(level):
229
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
230
+ return level_to_args
231
+
232
+
233
+ def shear_level_to_args(MAX_LEVEL, replace_value):
234
+ def level_to_args(level):
235
+ level = (level / MAX_LEVEL) * 0.3
236
+ if np.random.random() > 0.5:
237
+ level = -level
238
+ return (level, replace_value)
239
+
240
+ return level_to_args
241
+
242
+
243
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
244
+ def level_to_args(level):
245
+ level = (level / MAX_LEVEL) * float(translate_const)
246
+ if np.random.random() > 0.5:
247
+ level = -level
248
+ return (level, replace_value)
249
+
250
+ return level_to_args
251
+
252
+
253
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
254
+ def level_to_args(level):
255
+ level = int((level / MAX_LEVEL) * cutout_const)
256
+ return (level, replace_value)
257
+
258
+ return level_to_args
259
+
260
+
261
+ def solarize_level_to_args(MAX_LEVEL):
262
+ def level_to_args(level):
263
+ level = int((level / MAX_LEVEL) * 256)
264
+ return (level, )
265
+ return level_to_args
266
+
267
+
268
+ def none_level_to_args(level):
269
+ return ()
270
+
271
+
272
+ def posterize_level_to_args(MAX_LEVEL):
273
+ def level_to_args(level):
274
+ level = int((level / MAX_LEVEL) * 4)
275
+ return (level, )
276
+ return level_to_args
277
+
278
+
279
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
280
+ def level_to_args(level):
281
+ level = (level / MAX_LEVEL) * 30
282
+ if np.random.random() < 0.5:
283
+ level = -level
284
+ return (level, replace_value)
285
+
286
+ return level_to_args
287
+
288
+
289
+ func_dict = {
290
+ 'Identity': identity_func,
291
+ 'AutoContrast': autocontrast_func,
292
+ 'Equalize': equalize_func,
293
+ 'Rotate': rotate_func,
294
+ 'Solarize': solarize_func,
295
+ 'Color': color_func,
296
+ 'Contrast': contrast_func,
297
+ 'Brightness': brightness_func,
298
+ 'Sharpness': sharpness_func,
299
+ 'ShearX': shear_x_func,
300
+ 'TranslateX': translate_x_func,
301
+ 'TranslateY': translate_y_func,
302
+ 'Posterize': posterize_func,
303
+ 'ShearY': shear_y_func,
304
+ }
305
+
306
+ translate_const = 10
307
+ MAX_LEVEL = 10
308
+ replace_value = (128, 128, 128)
309
+ arg_dict = {
310
+ 'Identity': none_level_to_args,
311
+ 'AutoContrast': none_level_to_args,
312
+ 'Equalize': none_level_to_args,
313
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
314
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
315
+ 'Color': enhance_level_to_args(MAX_LEVEL),
316
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
317
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
318
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
319
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
320
+ 'TranslateX': translate_level_to_args(
321
+ translate_const, MAX_LEVEL, replace_value
322
+ ),
323
+ 'TranslateY': translate_level_to_args(
324
+ translate_const, MAX_LEVEL, replace_value
325
+ ),
326
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
327
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
328
+ }
329
+
330
+
331
+ class RandomAugment(object):
332
+
333
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
334
+ self.N = N
335
+ self.M = M
336
+ self.isPIL = isPIL
337
+ if augs:
338
+ self.augs = augs
339
+ else:
340
+ self.augs = list(arg_dict.keys())
341
+
342
+ def get_random_ops(self):
343
+ sampled_ops = np.random.choice(self.augs, self.N)
344
+ return [(op, 0.5, self.M) for op in sampled_ops]
345
+
346
+ def __call__(self, img):
347
+ if self.isPIL:
348
+ img = np.array(img)
349
+ ops = self.get_random_ops()
350
+ for name, prob, level in ops:
351
+ if np.random.random() > prob:
352
+ continue
353
+ args = arg_dict[name](level)
354
+ img = func_dict[name](img, *args)
355
+ return img
356
+
357
+
358
+ def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
359
+ if is_train:
360
+ t = [
361
+ RandomResizedCropAndInterpolation(
362
+ input_size, scale=(0.5, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
363
+ transforms.RandomHorizontalFlip(),
364
+ ]
365
+ if randaug:
366
+ t.append(
367
+ RandomAugment(
368
+ 2, 7, isPIL=True,
369
+ augs=[
370
+ 'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
371
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
372
+ ]))
373
+ t += [
374
+ transforms.ToTensor(),
375
+ transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
376
+ ]
377
+ t = transforms.Compose(t)
378
+ else:
379
+ t = transforms.Compose([
380
+ transforms.Resize((input_size, input_size),
381
+ interpolation=transforms.InterpolationMode.BICUBIC),
382
+ transforms.ToTensor(),
383
+ transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
384
+ ])
385
+
386
+ return t
387
+
388
+
389
+ class VisCpmChatBeeImageProcessor(BaseImageProcessor):
390
+ def __init__(self, is_train, randaug=True, input_size=224, interpolation='bicubic', **kwargs):
391
+ super().__init__(**kwargs)
392
+ self.is_train = is_train
393
+ self.randaug = randaug
394
+ self.input_size = input_size
395
+ self.interpolation = interpolation
396
+ self._transform = build_transform(is_train, randaug=randaug, input_size=input_size, interpolation=interpolation)
397
+
398
+ def preprocess(self, images, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs) -> BatchFeature:
399
+ images = make_list_of_images(images)
400
+ images = [self._transform(image) for image in images]
401
+ images = torch.tensor([image.numpy() for image in images])
402
+
403
+ data = {"pixel_values": images}
404
+ return BatchFeature(data=data, tensor_type=return_tensors)
405
+
406
+ def to_json_string(self) -> str:
407
+ """
408
+ Serializes this instance to a JSON string.
409
+
410
+ Returns:
411
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
412
+ """
413
+ dictionary = self.to_dict()
414
+
415
+ for key, value in dictionary.items():
416
+ if isinstance(value, np.ndarray):
417
+ dictionary[key] = value.tolist()
418
+
419
+ # make sure private name "_processor_class" is correctly
420
+ # saved as "processor_class"
421
+ _processor_class = dictionary.pop("_processor_class", None)
422
+ if _processor_class is not None:
423
+ dictionary["processor_class"] = _processor_class
424
+ _transform = dictionary.pop("_transform", None)
425
+ if _transform is not None:
426
+ dictionary["_transform"] = str(type(_transform))
427
+
428
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
tokenization_viscpmchatbee.py ADDED
@@ -0,0 +1,1007 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for CpmBee."""
16
+ import json
17
+ import os
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ from numpy.typing import NDArray
22
+ from typing_extensions import TypedDict
23
+
24
+ from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
25
+ from transformers.tokenization_utils_base import AddedToken, BatchEncoding, TextInput, TruncationStrategy
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
32
+
33
+ PRETRAINED_VOCAB_FILES_MAP = {
34
+ "vocab_file": {
35
+ "openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/blob/main/vocab.txt",
36
+ },
37
+ }
38
+
39
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
40
+ "openbmb/viscpmchat-bee-10b": 4096,
41
+ }
42
+
43
+
44
+ class _PrevExtTableStates(TypedDict):
45
+ ext_table: Dict[int, str]
46
+ token_id_table: Dict[str, Dict[int, int]]
47
+
48
+
49
+ CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
50
+
51
+
52
+ def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
53
+ ret = n_up * max_depth + n_down
54
+ if ret == 0:
55
+ return ret
56
+ else:
57
+ # bucket 1 is reserved for incontext samples
58
+ return ret + 1
59
+
60
+
61
+ class _DictTree(TypedDict):
62
+ value: str
63
+ children: List["_DictTree"]
64
+ depth: int
65
+ segment_id: int
66
+ need_predict: bool
67
+ is_image: bool
68
+
69
+
70
+ class VisCpmChatBeeTokenizer(PreTrainedTokenizer):
71
+ """
72
+ Construct a CPMBee tokenizer.
73
+
74
+ Args:
75
+ vocab_file (`str`):
76
+ Path to the vocabulary file.
77
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
78
+ The beginning of sequence token.
79
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
80
+ The end of sequence token.
81
+ line_token (`str`, *optional*, defaults to `"\n"`):
82
+ The line token.
83
+ space_token (`str`, *optional*, defaults to `" "`):
84
+ The space token.
85
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
86
+ The unknown token.
87
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
88
+ The mask token.
89
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
90
+ The token used for padding.
91
+ padding_side (`str`, *optional*, defaults to `"left"`):
92
+ The padding side. CPM-Bee will use left padding by default.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ model_input_names: List[str] = [
99
+ "input_ids",
100
+ "attention_mask",
101
+ "input_id_sub",
102
+ "position",
103
+ "context",
104
+ "sample_ids",
105
+ "num_segments",
106
+ "segment",
107
+ "segment_rel_offset",
108
+ "segment_rel",
109
+ ]
110
+ add_prefix_space = False
111
+
112
+ def __init__(
113
+ self,
114
+ vocab_file,
115
+ bos_token="<s>",
116
+ eos_token="</s>",
117
+ line_token="\n",
118
+ space_token=" ",
119
+ unk_token="<unk>",
120
+ mask_token="<mask>",
121
+ pad_token="<pad>",
122
+ padding_side="left",
123
+ **kwargs,
124
+ ):
125
+ super().__init__(
126
+ bos_token=bos_token,
127
+ eos_token=eos_token,
128
+ line_token=line_token,
129
+ space_token=space_token,
130
+ unk_token=unk_token,
131
+ mask_token=mask_token,
132
+ pad_token=pad_token,
133
+ padding_side=padding_side,
134
+ **kwargs,
135
+ )
136
+
137
+ self.encoder: Dict[str, int] = {}
138
+
139
+ with open(vocab_file, "r", encoding="utf-8") as reader:
140
+ for token in reader.readlines():
141
+ token = token.rstrip("\n")
142
+ if len(token) == 0:
143
+ continue
144
+ self.encoder[token] = len(self.encoder)
145
+
146
+ self.encoder[" "] = self.encoder["</_>"]
147
+ self.encoder["\n"] = self.encoder["</n>"]
148
+ del self.encoder["</_>"]
149
+ del self.encoder["</n>"]
150
+
151
+ self.decoder = {v: k for k, v in self.encoder.items()}
152
+
153
+ self._max_word_len = max([len(x) for x in self.encoder.keys()])
154
+ self.cpmbee_special_tokens = {k: v for k, v in self.encoder.items() if k.startswith("<") and k.endswith(">")}
155
+
156
+ self.ext_table: Dict[int, str] = {}
157
+ self.ext_table_rev: Dict[str, int] = {}
158
+
159
+ self.token_id_table: Dict[str, Dict[int, int]] = {}
160
+ self.ext_special_tokens = []
161
+
162
+ self.ext_args_for_model = [
163
+ "input_id_subs",
164
+ "input_pos",
165
+ "context",
166
+ "segment_ids",
167
+ "segment_rel_offset",
168
+ "segment_rel",
169
+ "sample_ids",
170
+ "num_segments",
171
+ "predict_segments",
172
+ "answer_placeholders",
173
+ "ext_table",
174
+ "token_id_table",
175
+ "image_bound"
176
+ ]
177
+
178
+ @property
179
+ def bod_token_id(self):
180
+ return self.encoder[self.bod_token]
181
+
182
+ @property
183
+ def eod_token_id(self):
184
+ return self.encoder[self.eod_token]
185
+
186
+ @property
187
+ def newline_id(self):
188
+ return self.encoder[self.line_token]
189
+
190
+ @property
191
+ def vocab_size(self) -> int:
192
+ return len(self.encoder)
193
+
194
+ def __len__(self):
195
+ """
196
+ Size of the full vocabulary with the added tokens.
197
+ """
198
+ return self.vocab_size + len(self.added_tokens_encoder)
199
+
200
+ def get_vocab(self):
201
+ return dict(self.encoder, **self.added_tokens_encoder)
202
+
203
+ def get_piece(self, text: str) -> str:
204
+ """
205
+ Match with maximum length.
206
+ """
207
+ len_text = len(text)
208
+ for i in range(len(text)):
209
+ sub = text[: len_text - i]
210
+ if (sub in self.encoder) or (sub in self.added_tokens_encoder):
211
+ return sub
212
+ return text[0]
213
+
214
+ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
215
+ r"""
216
+ Override the `tokenize` to meet the needs of CPMBee:
217
+ 1. Mark the special token with `<` and `>`. The `<>` will be ignored.
218
+ 2. Split sentences by the marked special tokens.
219
+ 3. Record the marked special token by `ext_table` and `ext_table_rev`.
220
+ 4. Tokenize the sentence without special tokens.
221
+ """
222
+ for_cpmbee = kwargs.get("for_cpmbee", False)
223
+ all_special_tokens_extended = {
224
+ str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
225
+ }
226
+
227
+ sentence_split = [""]
228
+ is_special_token = False
229
+ for i, c in enumerate(text):
230
+ if is_special_token:
231
+ if c == "<":
232
+ tail = sentence_split.pop(-1)
233
+ sentence_split[-1] += tail
234
+ sentence_split.append(c)
235
+ elif c == ">":
236
+ # end of special token
237
+ sentence_split[-1] += c
238
+ if sentence_split[-1] == "<>":
239
+ continue
240
+ is_special_token = False
241
+ sentence_split.append("")
242
+ else:
243
+ sentence_split[-1] += c
244
+ else:
245
+ if c == "<":
246
+ is_special_token = True
247
+ sentence_split.append(c)
248
+ else:
249
+ sentence_split[-1] += c
250
+ if is_special_token:
251
+ tail = sentence_split.pop(-1)
252
+ sentence_split[-1] += tail
253
+
254
+ output_tokens = []
255
+ for i, part in enumerate(sentence_split):
256
+ if (i & 1) == 1:
257
+ # special token
258
+ output_tokens.append(part)
259
+ if for_cpmbee and (part not in self.encoder) and (part not in self.ext_table_rev):
260
+ self.ext_table_rev[part] = len(self.ext_table_rev) + self.vocab_size
261
+ self.ext_table[self.ext_table_rev[part]] = part
262
+ else:
263
+ output_tokens.extend(self._tokenize(part, for_cpmbee=for_cpmbee))
264
+
265
+ # drop spaces
266
+ for i, token in enumerate(output_tokens):
267
+ if token in self.added_tokens_encoder:
268
+ token = all_special_tokens_extended.get(token, None)
269
+ left = output_tokens[i - 1] if i > 0 else None
270
+ right = output_tokens[i + 1] if i < len(output_tokens) - 1 else None
271
+ if isinstance(token, AddedToken):
272
+ if token.rstrip and right:
273
+ # A bit counter-intuitive but we strip the left of the string
274
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
275
+ output_tokens[i + 1] = right.lstrip()
276
+ # Strip white spaces on the left
277
+ if token.lstrip and left:
278
+ output_tokens[i - 1] = left.rstrip() # Opposite here
279
+ else:
280
+ if right:
281
+ output_tokens[i + 1] = right.lstrip()
282
+ if left:
283
+ output_tokens[i - 1] = left.rstrip()
284
+
285
+ skipped_tokens = []
286
+ for token in output_tokens:
287
+ if not token:
288
+ continue
289
+ else:
290
+ skipped_tokens.append(token)
291
+
292
+ return skipped_tokens
293
+
294
+ def _tokenize(self, text, **kwargs):
295
+ """
296
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
297
+ vocabulary.
298
+
299
+ Do NOT take care of added tokens. Record the unk tokens and special tokens in `ext_table` and `ext_table_rev`.
300
+ """
301
+ for_cpmbee = kwargs.get("for_cpmbee", False)
302
+ output_tokens = []
303
+
304
+ part_st = 0
305
+ last_unk = None
306
+ while part_st < len(text):
307
+ piece = self.get_piece(text[part_st:])
308
+ if piece in self.encoder or self.added_tokens_encoder:
309
+ if last_unk is None:
310
+ output_tokens.append(piece)
311
+ else:
312
+ if for_cpmbee and (last_unk not in self.ext_table_rev):
313
+ self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
314
+ self.ext_table[self.ext_table_rev[last_unk]] = last_unk
315
+ output_tokens.append(last_unk)
316
+ output_tokens.append(piece)
317
+ last_unk = None
318
+ else:
319
+ if last_unk is None:
320
+ last_unk = piece
321
+ else:
322
+ last_unk += piece
323
+ part_st += len(piece)
324
+ if last_unk is not None:
325
+ # part end with UNK
326
+ if for_cpmbee and (last_unk not in self.ext_table_rev):
327
+ self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
328
+ self.ext_table[self.ext_table_rev[last_unk]] = last_unk
329
+ output_tokens.append(last_unk)
330
+
331
+ return output_tokens
332
+
333
+ def check(self, token):
334
+ return token in self.encoder
335
+
336
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
337
+ return "".join(tokens)
338
+
339
+ def _convert_token_to_id(self, token: str):
340
+ """Converts a token (str) in an id using the vocab and ext_table."""
341
+ if token in self.encoder:
342
+ return self.encoder.get(token)
343
+ elif token in self.ext_table_rev:
344
+ return self.ext_table_rev[token]
345
+ elif token in self.added_tokens_encoder:
346
+ return self.added_tokens_encoder[token]
347
+ else:
348
+ return self.unk_token_id
349
+
350
+ def _convert_id_to_token(self, index):
351
+ """Converts an index (integer) in a token (str) using the vocab and ext_table."""
352
+ if index in self.ext_table:
353
+ return self.ext_table[index]
354
+ elif index in self.added_tokens_decoder:
355
+ return self.added_tokens_decoder[index]
356
+ else:
357
+ if index >= 0:
358
+ return self.decoder[index]
359
+
360
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
361
+ if os.path.isdir(save_directory):
362
+ vocab_file = os.path.join(
363
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
364
+ )
365
+ else:
366
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
367
+ index = 0
368
+ self.encoder["</n>"] = self.encoder["\n"]
369
+ del self.encoder["\n"]
370
+ self.encoder["</_>"] = self.encoder[" "]
371
+ del self.encoder[" "]
372
+ with open(vocab_file, "w", encoding="utf-8") as writer:
373
+ for token, token_index in sorted(self.encoder.items(), key=lambda x: x[1]):
374
+ if index != token_index:
375
+ logger.warning(
376
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
377
+ " Please check that the vocabulary is not corrupted!"
378
+ )
379
+ index = token_index
380
+ writer.write(token + "\n")
381
+ index += 1
382
+ return (vocab_file,)
383
+
384
+ def __call__(self, text, *args, **kwargs):
385
+ r"""
386
+ CPMBee `call` method will use `_tokenize_cpmbee` when the input type is dict.
387
+ """
388
+ if isinstance(text, dict):
389
+ return self._batch_tokenize_cpmbee([text], *args, **kwargs)
390
+ elif isinstance(text, (list, tuple)):
391
+ if isinstance(text[0], dict):
392
+ return self._batch_tokenize_cpmbee(text, *args, **kwargs)
393
+ else:
394
+ return super().__call__(text, *args, **kwargs)
395
+ else:
396
+ return super().__call__(text, *args, **kwargs)
397
+
398
+ # 分词
399
+ def _tokenize_cpmbee(self, data: TextInput, *args, **kwargs) -> List[str]:
400
+ """
401
+ A tokenize method to process dict data. Exclusive for CPMBee.
402
+ """
403
+ if isinstance(data, str):
404
+ data = json.loads(data)
405
+ if not isinstance(data, Dict):
406
+ raise TypeError(
407
+ "CpmBeeTokenizer input data should be dict or str in dict format, but got {}".format(type(data))
408
+ )
409
+
410
+ # 1. prepare answer placeholder
411
+ answer_placeholders = []
412
+
413
+ def _put_placeholder(data: Any, path: List[str] = []):
414
+ if isinstance(data, dict):
415
+ ret = {}
416
+ for k, v in data.items():
417
+ ret[k] = _put_placeholder(v, path + [k])
418
+ return ret
419
+ else:
420
+ answer_placeholders.append(path)
421
+ return "<ans_{}>".format(len(answer_placeholders))
422
+
423
+ data["<ans>"] = _put_placeholder(data["<ans>"])
424
+
425
+ (
426
+ input_ids,
427
+ input_id_subs,
428
+ context,
429
+ segment_ids,
430
+ segment_rel,
431
+ n_segments,
432
+ table_states,
433
+ image_bound
434
+ ) = self.convert_data_to_id(data, shuffle_answer=False, max_depth=8)
435
+
436
+ # <ans> mapping from sub to id
437
+ sub_ans_map: Dict[int, int] = {}
438
+ for fake_id, token_sub in table_states["token_id_table"]["<ans>"].items():
439
+ token = table_states["ext_table"][fake_id]
440
+ if token.startswith("<ans_") and token.endswith(">"):
441
+ ans_id = int(token[5:-1])
442
+ sub_ans_map[token_sub] = ans_id
443
+
444
+ tmp_input_ids = []
445
+ tmp_input_sub = []
446
+ tmp_input_seg = []
447
+
448
+ # get predict segments
449
+ predict_segments: List[Tuple[int, int]] = []
450
+ for i in range(input_ids.shape[0]):
451
+ if context[i] == 0:
452
+ if input_ids[i] == self.encoder["<ans>"]:
453
+ # is ans
454
+ # (segment_id, ans_id)
455
+ predict_segments.append((segment_ids[i], sub_ans_map[input_id_subs[i]]))
456
+ else:
457
+ tmp_input_ids.append(input_ids[i])
458
+ tmp_input_sub.append(input_id_subs[i])
459
+ tmp_input_seg.append(segment_ids[i])
460
+
461
+ if len(predict_segments) == 0:
462
+ raise ValueError("No answer to predict")
463
+
464
+ input_ids = np.array(tmp_input_ids, dtype=np.int32) # all context
465
+ input_id_subs = np.array(tmp_input_sub, dtype=np.int32) # [0, 0, 0, 0, 1, 0, 0, 2, 0, ...]
466
+ context = np.full_like(tmp_input_ids, 1, dtype=np.int8) # [1, 1, 1, ...]
467
+ segment_ids = np.array(tmp_input_seg, dtype=np.int32) # [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, ...]
468
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, 0, ...]
469
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, ...]
470
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32) # [n_seg, n_seg, n_seg, ...]
471
+ input_pos = np.arange(input_ids.shape[0], dtype=np.int32) # [0, 1, 2, 3, 4, ...]
472
+ image_bound = np.array(image_bound)
473
+
474
+ return (
475
+ self.prepare_for_model(
476
+ input_ids.tolist(),
477
+ input_id_subs=input_id_subs.tolist(),
478
+ input_pos=input_pos.tolist(),
479
+ context=context.tolist(),
480
+ segment_ids=segment_ids.tolist(),
481
+ segment_rel_offset=segment_rel_offset.tolist(),
482
+ segment_rel=segment_rel.tolist(),
483
+ sample_ids=sample_ids.tolist(),
484
+ num_segments=num_segments.tolist(),
485
+ image_bound=image_bound,
486
+ **kwargs,
487
+ ),
488
+ predict_segments,
489
+ answer_placeholders,
490
+ table_states["ext_table"],
491
+ table_states["token_id_table"],
492
+ )
493
+
494
+ def _batch_tokenize_cpmbee(self, data_lst, *args, **kwargs):
495
+ """
496
+ Batched _token_cpmbee.
497
+ """
498
+ device = kwargs.get("device", "cpu")
499
+ return_tensors = kwargs.get("return_tensors", None)
500
+ batch_outputs = {}
501
+ segment_rel_pack = []
502
+ other_info = []
503
+
504
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
505
+ batch_ext_table_ids: List[int] = []
506
+ batch_ext_table_sub: List[int] = []
507
+
508
+ for data in data_lst:
509
+ self.ext_table = {}
510
+ self.ext_table_rev = {}
511
+ self.token_id_table = {}
512
+ (outputs, predict_segments, answer_placeholders, ext_table, token_id_table) = self._tokenize_cpmbee(
513
+ data,
514
+ truncation=None,
515
+ padding=PaddingStrategy.DO_NOT_PAD.value,
516
+ max_length=None,
517
+ pad_to_multiple_of=None,
518
+ return_attention_mask=False,
519
+ return_tensors=None,
520
+ )
521
+ rev_ext_table = {}
522
+ for token, mp in token_id_table.items():
523
+ if token == "<ans>":
524
+ continue
525
+ token_id = self.encoder[token]
526
+ for fake_id, token_sub in mp.items():
527
+ if token_sub > 0:
528
+ if (token_id, token_sub) not in batch_ext_table_map:
529
+ batch_ext_table_map[(token_id, token_sub)] = len(batch_ext_table_ids) + self.vocab_size
530
+ batch_ext_table_ids.append(token_id)
531
+ batch_ext_table_sub.append(token_sub)
532
+ rev_ext_table[batch_ext_table_map[(token_id, token_sub)]] = ext_table[fake_id]
533
+ else:
534
+ rev_ext_table[token_id] = ext_table[fake_id]
535
+
536
+ segment_rel_pack.append(np.array(outputs.pop("segment_rel")))
537
+ other_info.append(
538
+ {
539
+ "predict_segments": predict_segments,
540
+ "answer_placeholders": answer_placeholders,
541
+ "ext_table": rev_ext_table,
542
+ }
543
+ )
544
+
545
+ for key, value in outputs.items():
546
+ if key not in batch_outputs:
547
+ batch_outputs[key] = []
548
+ batch_outputs[key].append(value)
549
+
550
+ max_length = max([len(item) for item in batch_outputs[self.model_input_names[0]]])
551
+ batch_size = len(batch_outputs[self.model_input_names[0]])
552
+ for i in range(batch_size):
553
+ inputs = {k: v[i] for k, v in batch_outputs.items()}
554
+
555
+ for k, v in inputs.items():
556
+ required_input = v
557
+
558
+ needs_to_be_padded = len(required_input) != max_length and k != 'image_bound'
559
+
560
+ if needs_to_be_padded:
561
+ difference = max_length - len(required_input)
562
+ batch_outputs[k][i] = [self.pad_token_id] * difference + required_input
563
+
564
+ max_num_rels = 0
565
+ for rel in segment_rel_pack:
566
+ max_num_rels = max(max_num_rels, rel.shape[0])
567
+ padded_rels = np.zeros((len(segment_rel_pack), max_num_rels), dtype=np.int32)
568
+ for i, rel in enumerate(segment_rel_pack):
569
+ padded_rels[i, : rel.shape[0]] = rel
570
+ batch_outputs["segment_rel"] = padded_rels
571
+ batch_outputs["batch_ext_table_ids"] = np.array(batch_ext_table_ids, dtype=np.int32)
572
+ batch_outputs["batch_ext_table_sub"] = np.array(batch_ext_table_sub, dtype=np.int32)
573
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
574
+ if return_tensors == "pt":
575
+ batch_outputs = batch_outputs.to(device=device)
576
+ batch_outputs["other_info"] = other_info
577
+
578
+ return batch_outputs
579
+
580
+ def convert_data_to_id(
581
+ self,
582
+ data: Any,
583
+ prev_ext_states: Optional[_PrevExtTableStates] = None,
584
+ shuffle_answer: bool = True,
585
+ max_depth: int = 8,
586
+ ):
587
+ """
588
+ Parse a dict to data ids. Exclusive for CPMBee. It will
589
+ 1. parse the dict to segments and get segment_rel, which for calculating of position_bias.
590
+ 2. tokenize every segment.
591
+ """
592
+ root: _DictTree = {
593
+ "value": "<root>",
594
+ "children": [],
595
+ "depth": 0,
596
+ "segment_id": 0,
597
+ "need_predict": False,
598
+ "is_image": False
599
+ }
600
+
601
+ segments = [root]
602
+
603
+ def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
604
+ if isinstance(data, dict):
605
+ ret_list: List[_DictTree] = []
606
+ curr_items = list(data.items())
607
+ if need_predict and shuffle_answer:
608
+ access_idx = np.arange(len(curr_items))
609
+ np.random.shuffle(access_idx)
610
+ curr_items = [curr_items[idx] for idx in access_idx]
611
+ for k, v in curr_items:
612
+ child_info: _DictTree = {
613
+ "value": k,
614
+ "children": [],
615
+ "depth": depth,
616
+ "segment_id": len(segments),
617
+ "need_predict": False, # only leaves are contexts
618
+ "is_image": False,
619
+ }
620
+ segments.append(child_info)
621
+ child_info["children"] = _build_dict_tree(
622
+ v, depth + 1,
623
+ need_predict=need_predict or (depth == 1 and k == "<ans>"),
624
+ is_image=is_image or (depth == 1 and k == "image")
625
+ ) # elements in <root>.<ans>
626
+
627
+ ret_list.append(child_info)
628
+ return ret_list
629
+ else:
630
+ assert isinstance(data, str), "Invalid data {}".format(data)
631
+ ret: _DictTree = {
632
+ "value": data,
633
+ "children": [],
634
+ "depth": depth,
635
+ "segment_id": len(segments),
636
+ "need_predict": need_predict,
637
+ "is_image": is_image,
638
+ }
639
+ segments.append(ret)
640
+ return [ret]
641
+
642
+ root["children"] = _build_dict_tree(data, 1, False, False)
643
+
644
+ num_segments = len(segments)
645
+ segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
646
+
647
+ def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
648
+ ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
649
+ for child in node["children"]:
650
+ sub = _build_segment_rel(child)
651
+ for seg_id_1, depth_1 in sub:
652
+ for seg_id_2, depth_2 in ret:
653
+ n_up = min(depth_1 - node["depth"], max_depth - 1)
654
+ n_down = min(depth_2 - node["depth"], max_depth - 1)
655
+ segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
656
+ n_up, n_down, max_depth=max_depth
657
+ )
658
+ segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
659
+ n_down, n_up, max_depth=max_depth
660
+ )
661
+ ret.extend(sub)
662
+ return ret
663
+
664
+ _build_segment_rel(root)
665
+
666
+ input_ids: List[int] = []
667
+ input_id_subs: List[int] = []
668
+ segment_bound: List[Tuple[int, int]] = []
669
+ image_bound: List[Tuple[int, int]] = []
670
+
671
+
672
+ if prev_ext_states is not None:
673
+ self.ext_table = prev_ext_states["ext_table"]
674
+ self.token_id_table = prev_ext_states["token_id_table"]
675
+
676
+ for seg in segments:
677
+ # tokenize
678
+ tokens = self.convert_tokens_to_ids(self.tokenize(seg["value"], for_cpmbee=True))
679
+
680
+ token_id_subs = []
681
+ reid_token_ids = []
682
+ for idx in tokens:
683
+ if idx in self.ext_table:
684
+ # unk or special token
685
+ token = self.ext_table[idx]
686
+ if token.startswith("<") and token.endswith(">"):
687
+ # special token
688
+ if "_" in token:
689
+ token_name = token[1:-1].split("_", maxsplit=1)[0]
690
+ else:
691
+ token_name = token[1:-1]
692
+ token_name = "<{}>".format(token_name)
693
+ else:
694
+ token_name = "<unk>"
695
+
696
+ if token_name not in self.token_id_table:
697
+ self.token_id_table[token_name] = {}
698
+ if idx not in self.token_id_table[token_name]:
699
+ self.token_id_table[token_name][idx] = len(self.token_id_table[token_name])
700
+ if token_name not in self.encoder:
701
+ raise ValueError("Invalid token {}".format(token))
702
+ reid_token_ids.append(self.encoder[token_name])
703
+ token_id_subs.append(self.token_id_table[token_name][idx])
704
+ else:
705
+ reid_token_ids.append(idx)
706
+ token_id_subs.append(0)
707
+ tokens = [self.bos_token_id] + reid_token_ids
708
+ token_id_subs = [0] + token_id_subs
709
+ # eos_id 表示 no need_predict
710
+ if not seg["need_predict"]: # eos
711
+ tokens = tokens + [self.eos_token_id]
712
+ token_id_subs = token_id_subs + [0]
713
+ else:
714
+ # no eos
715
+ pass
716
+ begin = len(input_ids)
717
+ input_ids.extend(tokens)
718
+ input_id_subs.extend(token_id_subs)
719
+ end = len(input_ids)
720
+ segment_bound.append((begin, end))
721
+
722
+ ids = np.array(input_ids, dtype=np.int32)
723
+ id_subs = np.array(input_id_subs, dtype=np.int32)
724
+ segs = np.zeros((ids.shape[0],), dtype=np.int32) # 按segment_bound对seg编号
725
+ context = np.zeros((ids.shape[0],), dtype=np.int8)
726
+ for i, (begin, end) in enumerate(segment_bound):
727
+ if not segments[i]["need_predict"]:
728
+ context[begin:end] = 1
729
+ if segments[i]["is_image"]:
730
+ image_bound.append((begin + 1, end - 1))
731
+ segs[begin:end] = i
732
+
733
+ curr_ext_table_states: _PrevExtTableStates = {
734
+ "ext_table": self.ext_table,
735
+ "token_id_table": self.token_id_table,
736
+ }
737
+ image_bound = np.array(image_bound, dtype=np.int32)
738
+ return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
739
+
740
+ def prepare_for_model(
741
+ self,
742
+ ids: List[int],
743
+ pair_ids: Optional[List[int]] = None,
744
+ add_special_tokens: bool = True,
745
+ padding: Union[bool, str, PaddingStrategy] = False,
746
+ truncation: Union[bool, str, TruncationStrategy] = None,
747
+ max_length: Optional[int] = None,
748
+ stride: int = 0,
749
+ pad_to_multiple_of: Optional[int] = None,
750
+ return_tensors: Optional[Union[str, TensorType]] = None,
751
+ return_token_type_ids: Optional[bool] = None,
752
+ return_attention_mask: Optional[bool] = None,
753
+ return_overflowing_tokens: bool = False,
754
+ return_special_tokens_mask: bool = False,
755
+ return_length: bool = False,
756
+ verbose: bool = True,
757
+ prepend_batch_axis: bool = False,
758
+ **kwargs,
759
+ ) -> BatchEncoding:
760
+ """
761
+ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
762
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
763
+ manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
764
+ different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
765
+ overflowing tokens. Such a combination of arguments will raise an error.
766
+
767
+ Args:
768
+ ids (`List[int]`):
769
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
770
+ `convert_tokens_to_ids` methods.
771
+ pair_ids (`List[int]`, *optional*):
772
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
773
+ and `convert_tokens_to_ids` methods.
774
+ """
775
+
776
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
777
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
778
+ padding=padding,
779
+ truncation=truncation,
780
+ max_length=max_length,
781
+ pad_to_multiple_of=pad_to_multiple_of,
782
+ verbose=verbose,
783
+ **kwargs,
784
+ )
785
+
786
+ pair = bool(pair_ids is not None)
787
+ len_ids = len(ids)
788
+ len_pair_ids = len(pair_ids) if pair else 0
789
+
790
+ if return_token_type_ids and not add_special_tokens:
791
+ raise ValueError(
792
+ "Asking to return token_type_ids while setting add_special_tokens to False "
793
+ "results in an undefined behavior. Please set add_special_tokens to True or "
794
+ "set return_token_type_ids to None."
795
+ )
796
+
797
+ if (
798
+ return_overflowing_tokens
799
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
800
+ and pair_ids is not None
801
+ ):
802
+ raise ValueError(
803
+ "Not possible to return overflowing tokens for pair of sequences with the "
804
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
805
+ "for instance `only_second` or `only_first`."
806
+ )
807
+
808
+ # Load from model defaults
809
+ if return_token_type_ids is None:
810
+ return_token_type_ids = "token_type_ids" in self.model_input_names
811
+ if return_attention_mask is None:
812
+ return_attention_mask = "attention_mask" in self.model_input_names
813
+
814
+ encoded_inputs = {}
815
+
816
+ # Compute the total size of the returned encodings
817
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
818
+
819
+ # Truncation: Handle max sequence length
820
+ overflowing_tokens = []
821
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
822
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(
823
+ ids,
824
+ pair_ids=pair_ids,
825
+ num_tokens_to_remove=total_len - max_length,
826
+ truncation_strategy=truncation_strategy,
827
+ stride=stride,
828
+ )
829
+
830
+ if return_overflowing_tokens:
831
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
832
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
833
+
834
+ # Add special tokens
835
+ if add_special_tokens:
836
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
837
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
838
+ else:
839
+ sequence = ids + pair_ids if pair else ids
840
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
841
+
842
+ # Build output dictionary
843
+ encoded_inputs["input_ids"] = sequence
844
+ if return_token_type_ids:
845
+ encoded_inputs["token_type_ids"] = token_type_ids
846
+ if return_special_tokens_mask:
847
+ if add_special_tokens:
848
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
849
+ else:
850
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
851
+
852
+ # Check lengths
853
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
854
+
855
+ # Padding
856
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
857
+ encoded_inputs = self.pad(
858
+ encoded_inputs,
859
+ max_length=max_length,
860
+ padding=padding_strategy.value,
861
+ pad_to_multiple_of=pad_to_multiple_of,
862
+ return_attention_mask=return_attention_mask,
863
+ )
864
+
865
+ if return_length:
866
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
867
+
868
+ # for CPMBee, encode all the model arguments
869
+ for arg in self.ext_args_for_model:
870
+ v = kwargs.get(arg, None)
871
+ if v is not None:
872
+ encoded_inputs[arg] = v
873
+
874
+ batch_outputs = BatchEncoding(
875
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
876
+ )
877
+
878
+ return batch_outputs
879
+
880
+ def prepare_for_finetune(
881
+ self,
882
+ data_list: List[Dict],
883
+ max_length: int = 2048
884
+ ):
885
+ _inputs: List[NDArray[np.int32]] = []
886
+ _inputs_sub: List[NDArray[np.int32]] = []
887
+ _context: List[NDArray[np.int8]] = []
888
+ _sample_ids: List[NDArray[np.int32]] = []
889
+ _segments: List[NDArray[np.int32]] = []
890
+ _num_segments: List[NDArray[np.int32]] = []
891
+ _segment_rel_offset: List[NDArray[np.int32]] = []
892
+ _segment_rel: List[NDArray[np.int32]] = []
893
+ _spans: List[List[int]] = []
894
+ _raw_data: List[List[Any]] = []
895
+
896
+ raw_data = {}
897
+ for data in data_list:
898
+ (
899
+ input_ids,
900
+ input_id_subs,
901
+ context,
902
+ segment_ids,
903
+ segment_rel,
904
+ n_segments,
905
+ _
906
+ ) = self.convert_data_to_id(data)
907
+
908
+ input_ids = input_ids[: max_length]
909
+ context = context[: max_length]
910
+ segment_ids = segment_ids[: max_length]
911
+ raw_data["input"] = data
912
+ raw_data["samples"] = []
913
+
914
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
915
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
916
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
917
+
918
+ _inputs.append(input_ids)
919
+ _inputs_sub.append(input_id_subs)
920
+ _context.append(context)
921
+ _sample_ids.append(sample_ids)
922
+ _segments.append(segment_ids)
923
+ _num_segments.append(num_segments)
924
+ _segment_rel_offset.append(segment_rel_offset)
925
+ _segment_rel.append(segment_rel)
926
+ _spans.append([input_ids.shape[0]])
927
+ _raw_data.append([raw_data])
928
+
929
+ batch_size = len(_inputs)
930
+ inputs = np.zeros((batch_size, max_length), dtype=np.int32)
931
+ inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
932
+ context = np.zeros((batch_size, max_length), dtype=np.int8)
933
+ sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
934
+ segments = np.zeros((batch_size, max_length), dtype=np.int32)
935
+ num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
936
+ segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
937
+ tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
938
+
939
+ max_rel = 0
940
+ for i in range(batch_size):
941
+ max_rel = max(max_rel, _segment_rel[i].shape[0])
942
+ segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
943
+ spans = np.zeros((batch_size, max_length), dtype=np.int32)
944
+ length = np.zeros((batch_size,), dtype=np.int32)
945
+
946
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
947
+ batch_ext_table_ids: List[int] = []
948
+ batch_ext_table_sub: List[int] = []
949
+ raw_data_list: List[Any] = []
950
+
951
+ for i in range(batch_size):
952
+ instance_length = _inputs[i].shape[0]
953
+ rel_size = _segment_rel[i].shape[0]
954
+ inputs[i, :instance_length] = _inputs[i]
955
+ inputs_sub[i, :instance_length] = _inputs_sub[i]
956
+ context[i, :instance_length] = _context[i]
957
+ sample_ids[i, :instance_length] = _sample_ids[i]
958
+ segments[i, :instance_length] = _segments[i]
959
+ num_segments[i, :instance_length] = _num_segments[i]
960
+ segment_rel_offset[i, :instance_length] = _segment_rel_offset[i]
961
+ segment_rel[i, :rel_size] = _segment_rel[i]
962
+
963
+ span_begin = 0
964
+ for span_id, span_end in enumerate(_spans[i]):
965
+ spans[i, span_begin:span_end] = span_id
966
+ span_begin = span_end
967
+ length[i] = instance_length
968
+ raw_data_list.extend(_raw_data[i])
969
+
970
+ for j in range(instance_length):
971
+ idx, idx_sub = _inputs[i][j], _inputs_sub[i][j]
972
+ tgt_idx = idx
973
+ if idx_sub > 0:
974
+ # need to be in ext table
975
+ if (idx, idx_sub) not in batch_ext_table_map:
976
+ batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
977
+ batch_ext_table_ids.append(idx)
978
+ batch_ext_table_sub.append(idx_sub)
979
+ tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.vocab_size
980
+ if j > 1 and context[i, j - 1] == 0:
981
+ if idx != self.bos_token_id:
982
+ tgt[i, j - 1] = tgt_idx
983
+ else:
984
+ tgt[i, j - 1] = self.eos_token_id
985
+ if context[i, instance_length - 1] == 0:
986
+ tgt[i, instance_length - 1] = self.eos_token_id
987
+
988
+ if len(batch_ext_table_map) == 0:
989
+ # placeholder
990
+ batch_ext_table_ids.append(0)
991
+ batch_ext_table_sub.append(1)
992
+
993
+ return BatchEncoding({
994
+ "input_ids": inputs,
995
+ "input_id_sub": inputs_sub,
996
+ "length": length,
997
+ "context": context > 0,
998
+ "sample_ids": sample_ids,
999
+ "num_segments": num_segments,
1000
+ "segment": segments,
1001
+ "segment_rel_offset": segment_rel_offset,
1002
+ "segment_rel": segment_rel,
1003
+ "span": spans,
1004
+ "labels": tgt,
1005
+ "ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
1006
+ "ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
1007
+ }, tensor_type="pt")
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name_or_path": "openbmb/viscpmchat-bee-10b",
3
+ "tokenizer_class": "VisCpmChatBeeTokenizer",
4
+ "auto_map": {
5
+ "AutoTokenizer": [
6
+ "tokenization_viscpmchatbee.VisCpmChatBeeTokenizer",
7
+ null
8
+ ]
9
+ }
10
+ }
utils.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
5
+ from timm.data.transforms import RandomResizedCropAndInterpolation
6
+ from torchvision import transforms
7
+ import urllib
8
+ from tqdm import tqdm
9
+ from cpm_live.tokenizers import CPMBeeTokenizer
10
+ from torch.utils.data import default_collate
11
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
+ from typing_extensions import TypedDict
13
+ from numpy.typing import NDArray
14
+ import importlib.machinery
15
+ import importlib.util
16
+ import types
17
+ import random
18
+
19
+
20
+ CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
21
+
22
+
23
+ def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
24
+ items = []
25
+ if isinstance(orig_items[0][key], list):
26
+ assert isinstance(orig_items[0][key][0], torch.Tensor)
27
+ for it in orig_items:
28
+ for tr in it[key]:
29
+ items.append({key: tr})
30
+ else:
31
+ assert isinstance(orig_items[0][key], torch.Tensor)
32
+ items = orig_items
33
+
34
+ batch_size = len(items)
35
+ shape = items[0][key].shape
36
+ dim = len(shape)
37
+ assert dim <= 3
38
+ if max_length is None:
39
+ max_length = 0
40
+ max_length = max(max_length, max(item[key].shape[-1] for item in items))
41
+ min_length = min(item[key].shape[-1] for item in items)
42
+ dtype = items[0][key].dtype
43
+
44
+ if dim == 1:
45
+ return torch.cat([item[key] for item in items], dim=0)
46
+ elif dim == 2:
47
+ if max_length == min_length:
48
+ return torch.cat([item[key] for item in items], dim=0)
49
+ tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
50
+ else:
51
+ tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
52
+
53
+ for i, item in enumerate(items):
54
+ if dim == 2:
55
+ if padding_side == "left":
56
+ tensor[i, -len(item[key][0]):] = item[key][0].clone()
57
+ else:
58
+ tensor[i, : len(item[key][0])] = item[key][0].clone()
59
+ elif dim == 3:
60
+ if padding_side == "left":
61
+ tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
62
+ else:
63
+ tensor[i, : len(item[key][0]), :] = item[key][0].clone()
64
+
65
+ return tensor
66
+
67
+
68
+ class CPMBeeCollater:
69
+ """
70
+ 针对 cpmbee 输入数据 collate, 对应 cpm-live 的 _MixedDatasetBatchPacker
71
+ 目前利用 torch 的原生 Dataloader 不太适合改造 in-context-learning
72
+ 并且原来实现为了最大化提高有效 token 比比例, 会有一个 best_fit 操作, 这个目前也不支持
73
+ todo: @wangchongyi 重写一下 Dataloader or BatchPacker
74
+ """
75
+
76
+ def __init__(self, tokenizer: CPMBeeTokenizer, max_len):
77
+ self.tokenizer = tokenizer
78
+ self._max_length = max_len
79
+ self.pad_keys = ['input_ids', 'input_id_subs', 'context', 'segment_ids', 'segment_rel_offset',
80
+ 'segment_rel', 'sample_ids', 'num_segments']
81
+
82
+ def __call__(self, batch):
83
+ batch_size = len(batch)
84
+
85
+ tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
86
+ # 目前没有 best_fit, span 为全 0
87
+ span = np.zeros((batch_size, self._max_length), dtype=np.int32)
88
+ length = np.zeros((batch_size,), dtype=np.int32)
89
+
90
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
91
+ batch_ext_table_ids: List[int] = []
92
+ batch_ext_table_sub: List[int] = []
93
+ raw_data_list: List[Any] = []
94
+
95
+ for i in range(batch_size):
96
+ instance_length = batch[i]['input_ids'][0].shape[0]
97
+ length[i] = instance_length
98
+ raw_data_list.extend(batch[i]['raw_data'])
99
+
100
+ for j in range(instance_length):
101
+ idx, idx_sub = batch[i]['input_ids'][0, j], batch[i]['input_id_subs'][0, j]
102
+ tgt_idx = idx
103
+ if idx_sub > 0:
104
+ # need to be in ext table
105
+ if (idx, idx_sub) not in batch_ext_table_map:
106
+ batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
107
+ batch_ext_table_ids.append(idx)
108
+ batch_ext_table_sub.append(idx_sub)
109
+ tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.tokenizer.vocab_size
110
+ if j > 1 and batch[i]['context'][0, j - 1] == 0:
111
+ if idx != self.tokenizer.bos_id:
112
+ tgt[i, j - 1] = tgt_idx
113
+ else:
114
+ tgt[i, j - 1] = self.tokenizer.eos_id
115
+ if batch[i]['context'][0, instance_length - 1] == 0:
116
+ tgt[i, instance_length - 1] = self.tokenizer.eos_id
117
+
118
+ if len(batch_ext_table_map) == 0:
119
+ # placeholder
120
+ batch_ext_table_ids.append(0)
121
+ batch_ext_table_sub.append(1)
122
+
123
+ # image
124
+ if 'pixel_values' in batch[0]:
125
+ data = {'pixel_values': default_collate([i['pixel_values'] for i in batch])}
126
+ else:
127
+ data = {}
128
+
129
+ # image_bound
130
+ if 'image_bound' in batch[0]:
131
+ data['image_bound'] = default_collate([i['image_bound'] for i in batch])
132
+
133
+ # bee inp
134
+ for key in self.pad_keys:
135
+ data[key] = pad(batch, key, max_length=self._max_length, padding_value=0, padding_side='right')
136
+
137
+ data['context'] = data['context'] > 0
138
+ data['length'] = torch.from_numpy(length)
139
+ data['span'] = torch.from_numpy(span)
140
+ data['target'] = torch.from_numpy(tgt)
141
+ data['ext_table_ids'] = torch.from_numpy(np.array(batch_ext_table_ids))
142
+ data['ext_table_sub'] = torch.from_numpy(np.array(batch_ext_table_sub))
143
+ data['raw_data'] = raw_data_list
144
+
145
+ return data
146
+
147
+
148
+ class _DictTree(TypedDict):
149
+ value: str
150
+ children: List["_DictTree"]
151
+ depth: int
152
+ segment_id: int
153
+ need_predict: bool
154
+ is_image: bool
155
+
156
+
157
+ class _PrevExtTableStates(TypedDict):
158
+ ext_table: Dict[int, str]
159
+ token_id_table: Dict[str, Dict[int, int]]
160
+
161
+
162
+ class _TransformFuncDict(TypedDict):
163
+ loader: importlib.machinery.SourceFileLoader
164
+ module: types.ModuleType
165
+ last_m: float
166
+
167
+
168
+ _TransformFunction = Callable[[CPMBeeInputType, int, random.Random], CPMBeeInputType]
169
+
170
+
171
+ class CPMBeeBatch(TypedDict):
172
+ inputs: NDArray[np.int32]
173
+ inputs_sub: NDArray[np.int32]
174
+ length: NDArray[np.int32]
175
+ context: NDArray[np.bool_]
176
+ sample_ids: NDArray[np.int32]
177
+ num_segments: NDArray[np.int32]
178
+ segment_ids: NDArray[np.int32]
179
+ segment_rel_offset: NDArray[np.int32]
180
+ segment_rel: NDArray[np.int32]
181
+ spans: NDArray[np.int32]
182
+ target: NDArray[np.int32]
183
+ ext_ids: NDArray[np.int32]
184
+ ext_sub: NDArray[np.int32]
185
+ task_ids: NDArray[np.int32]
186
+ task_names: List[str]
187
+ raw_data: List[Any]
188
+
189
+
190
+ def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
191
+ ret = n_up * max_depth + n_down
192
+ if ret == 0:
193
+ return ret
194
+ else:
195
+ # bucket 1 is reserved for incontext samples
196
+ return ret + 1
197
+
198
+
199
+ def convert_data_to_id(
200
+ tokenizer: CPMBeeTokenizer,
201
+ data: Any,
202
+ prev_ext_states: Optional[_PrevExtTableStates] = None,
203
+ shuffle_answer: bool = True,
204
+ max_depth: int = 8
205
+ ):
206
+ root: _DictTree = {
207
+ "value": "<root>",
208
+ "children": [],
209
+ "depth": 0,
210
+ "segment_id": 0,
211
+ "need_predict": False,
212
+ "is_image": False
213
+ }
214
+
215
+ segments = [root]
216
+
217
+ def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
218
+ if isinstance(data, dict):
219
+ ret_list: List[_DictTree] = []
220
+ curr_items = list(data.items())
221
+ if need_predict and shuffle_answer:
222
+ access_idx = np.arange(len(curr_items))
223
+ np.random.shuffle(access_idx)
224
+ curr_items = [curr_items[idx] for idx in access_idx]
225
+ for k, v in curr_items:
226
+ child_info: _DictTree = {
227
+ "value": k,
228
+ "children": [],
229
+ "depth": depth,
230
+ "segment_id": len(segments),
231
+ "need_predict": False, # only leaves are contexts
232
+ "is_image": False,
233
+ }
234
+ segments.append(child_info)
235
+ child_info["children"] = _build_dict_tree(
236
+ v, depth + 1,
237
+ need_predict=need_predict or (depth == 1 and k == "<ans>"),
238
+ is_image=is_image or (depth == 1 and k == "image")
239
+ ) # elements in <root>.<ans>
240
+
241
+ ret_list.append(child_info)
242
+ return ret_list
243
+ else:
244
+ assert isinstance(data, str), "Invalid data {}".format(data)
245
+ ret: _DictTree = {
246
+ "value": data,
247
+ "children": [],
248
+ "depth": depth,
249
+ "segment_id": len(segments),
250
+ "need_predict": need_predict,
251
+ "is_image": is_image,
252
+ }
253
+ segments.append(ret)
254
+ return [ret]
255
+
256
+ root["children"] = _build_dict_tree(data, 1, False, False)
257
+
258
+ num_segments = len(segments)
259
+ segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
260
+
261
+ def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
262
+ ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
263
+ for child in node["children"]:
264
+ sub = _build_segment_rel(child)
265
+ for seg_id_1, depth_1 in sub:
266
+ for seg_id_2, depth_2 in ret:
267
+ n_up = min(depth_1 - node["depth"], max_depth - 1)
268
+ n_down = min(depth_2 - node["depth"], max_depth - 1)
269
+ segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
270
+ n_up, n_down, max_depth=max_depth
271
+ )
272
+ segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
273
+ n_down, n_up, max_depth=max_depth
274
+ )
275
+ ret.extend(sub)
276
+ return ret
277
+
278
+ _build_segment_rel(root)
279
+
280
+ input_ids: List[int] = []
281
+ input_id_subs: List[int] = []
282
+ segment_bound: List[Tuple[int, int]] = []
283
+ image_bound: List[Tuple[int, int]] = []
284
+
285
+ ext_table: Dict[int, str] = {}
286
+ token_id_table: Dict[str, Dict[int, int]] = {}
287
+
288
+ if prev_ext_states is not None:
289
+ ext_table = prev_ext_states["ext_table"]
290
+ token_id_table = prev_ext_states["token_id_table"]
291
+
292
+ for seg in segments:
293
+ tokens, ext_table = tokenizer.encode(seg["value"], ext_table)
294
+
295
+ token_id_subs = []
296
+ reid_token_ids = []
297
+ for idx in tokens:
298
+ if idx in ext_table:
299
+ # unk or special token
300
+ token = ext_table[idx]
301
+ if token.startswith("<") and token.endswith(">"):
302
+ # special token
303
+ if "_" in token:
304
+ token_name = token[1:-1].split("_", maxsplit=1)[0]
305
+ else:
306
+ token_name = token[1:-1]
307
+ token_name = "<{}>".format(token_name)
308
+ else:
309
+ token_name = "<unk>"
310
+
311
+ if token_name not in token_id_table:
312
+ token_id_table[token_name] = {}
313
+ if idx not in token_id_table[token_name]:
314
+ token_id_table[token_name][idx] = len(token_id_table[token_name])
315
+ if token_name not in tokenizer.encoder:
316
+ raise ValueError("Invalid token {}".format(token))
317
+ reid_token_ids.append(tokenizer.encoder[token_name])
318
+ token_id_subs.append(token_id_table[token_name][idx])
319
+ else:
320
+ reid_token_ids.append(idx)
321
+ token_id_subs.append(0)
322
+ tokens = [tokenizer.bos_id] + reid_token_ids
323
+ token_id_subs = [0] + token_id_subs
324
+ if not seg["need_predict"]:
325
+ tokens = tokens + [tokenizer.eos_id]
326
+ token_id_subs = token_id_subs + [0]
327
+ else:
328
+ # no eos
329
+ pass
330
+ begin = len(input_ids)
331
+ input_ids.extend(tokens)
332
+ input_id_subs.extend(token_id_subs)
333
+ end = len(input_ids)
334
+ segment_bound.append((begin, end))
335
+
336
+ ids = np.array(input_ids, dtype=np.int32)
337
+ id_subs = np.array(input_id_subs, dtype=np.int32)
338
+ segs = np.zeros((ids.shape[0],), dtype=np.int32)
339
+ context = np.zeros((ids.shape[0],), dtype=np.int8)
340
+ for i, (begin, end) in enumerate(segment_bound):
341
+ if not segments[i]["need_predict"]:
342
+ context[begin:end] = 1
343
+ if segments[i]["is_image"]:
344
+ image_bound.append((begin+1, end-1))
345
+ segs[begin:end] = i
346
+
347
+ curr_ext_table_states: _PrevExtTableStates = {
348
+ "ext_table": ext_table,
349
+ "token_id_table": token_id_table,
350
+ }
351
+ image_bound = np.array(image_bound, dtype=np.int32)
352
+ return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
353
+
354
+
355
+ # aug functions
356
+ def identity_func(img):
357
+ return img
358
+
359
+
360
+ def autocontrast_func(img, cutoff=0):
361
+ '''
362
+ same output as PIL.ImageOps.autocontrast
363
+ '''
364
+ n_bins = 256
365
+
366
+ def tune_channel(ch):
367
+ n = ch.size
368
+ cut = cutoff * n // 100
369
+ if cut == 0:
370
+ high, low = ch.max(), ch.min()
371
+ else:
372
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
373
+ low = np.argwhere(np.cumsum(hist) > cut)
374
+ low = 0 if low.shape[0] == 0 else low[0]
375
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
376
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
377
+ if high <= low:
378
+ table = np.arange(n_bins)
379
+ else:
380
+ scale = (n_bins - 1) / (high - low)
381
+ table = np.arange(n_bins) * scale - low * scale
382
+ table[table < 0] = 0
383
+ table[table > n_bins - 1] = n_bins - 1
384
+ table = table.clip(0, 255).astype(np.uint8)
385
+ return table[ch]
386
+
387
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
388
+ out = cv2.merge(channels)
389
+ return out
390
+
391
+
392
+ def equalize_func(img):
393
+ '''
394
+ same output as PIL.ImageOps.equalize
395
+ PIL's implementation is different from cv2.equalize
396
+ '''
397
+ n_bins = 256
398
+
399
+ def tune_channel(ch):
400
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
401
+ non_zero_hist = hist[hist != 0].reshape(-1)
402
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
403
+ if step == 0:
404
+ return ch
405
+ n = np.empty_like(hist)
406
+ n[0] = step // 2
407
+ n[1:] = hist[:-1]
408
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
409
+ return table[ch]
410
+
411
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
412
+ out = cv2.merge(channels)
413
+ return out
414
+
415
+
416
+ def rotate_func(img, degree, fill=(0, 0, 0)):
417
+ '''
418
+ like PIL, rotate by degree, not radians
419
+ '''
420
+ H, W = img.shape[0], img.shape[1]
421
+ center = W / 2, H / 2
422
+ M = cv2.getRotationMatrix2D(center, degree, 1)
423
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
424
+ return out
425
+
426
+
427
+ def solarize_func(img, thresh=128):
428
+ '''
429
+ same output as PIL.ImageOps.posterize
430
+ '''
431
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
432
+ table = table.clip(0, 255).astype(np.uint8)
433
+ out = table[img]
434
+ return out
435
+
436
+
437
+ def color_func(img, factor):
438
+ '''
439
+ same output as PIL.ImageEnhance.Color
440
+ '''
441
+ # implementation according to PIL definition, quite slow
442
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
443
+ # out = blend(degenerate, img, factor)
444
+ # M = (
445
+ # np.eye(3) * factor
446
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
447
+ # )[np.newaxis, np.newaxis, :]
448
+ M = (
449
+ np.float32([
450
+ [0.886, -0.114, -0.114],
451
+ [-0.587, 0.413, -0.587],
452
+ [-0.299, -0.299, 0.701]]) * factor
453
+ + np.float32([[0.114], [0.587], [0.299]])
454
+ )
455
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
456
+ return out
457
+
458
+
459
+ def contrast_func(img, factor):
460
+ """
461
+ same output as PIL.ImageEnhance.Contrast
462
+ """
463
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
464
+ table = np.array([(
465
+ el - mean) * factor + mean
466
+ for el in range(256)
467
+ ]).clip(0, 255).astype(np.uint8)
468
+ out = table[img]
469
+ return out
470
+
471
+
472
+ def brightness_func(img, factor):
473
+ '''
474
+ same output as PIL.ImageEnhance.Contrast
475
+ '''
476
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
477
+ out = table[img]
478
+ return out
479
+
480
+
481
+ def sharpness_func(img, factor):
482
+ '''
483
+ The differences the this result and PIL are all on the 4 boundaries, the center
484
+ areas are same
485
+ '''
486
+ kernel = np.ones((3, 3), dtype=np.float32)
487
+ kernel[1][1] = 5
488
+ kernel /= 13
489
+ degenerate = cv2.filter2D(img, -1, kernel)
490
+ if factor == 0.0:
491
+ out = degenerate
492
+ elif factor == 1.0:
493
+ out = img
494
+ else:
495
+ out = img.astype(np.float32)
496
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
497
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
498
+ out = out.astype(np.uint8)
499
+ return out
500
+
501
+
502
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
503
+ H, W = img.shape[0], img.shape[1]
504
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
505
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
506
+ return out
507
+
508
+
509
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
510
+ '''
511
+ same output as PIL.Image.transform
512
+ '''
513
+ H, W = img.shape[0], img.shape[1]
514
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
515
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
516
+ return out
517
+
518
+
519
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
520
+ '''
521
+ same output as PIL.Image.transform
522
+ '''
523
+ H, W = img.shape[0], img.shape[1]
524
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
525
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
526
+ return out
527
+
528
+
529
+ def posterize_func(img, bits):
530
+ '''
531
+ same output as PIL.ImageOps.posterize
532
+ '''
533
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
534
+ return out
535
+
536
+
537
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
538
+ H, W = img.shape[0], img.shape[1]
539
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
540
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
541
+ return out
542
+
543
+
544
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
545
+ replace = np.array(replace, dtype=np.uint8)
546
+ H, W = img.shape[0], img.shape[1]
547
+ rh, rw = np.random.random(2)
548
+ pad_size = pad_size // 2
549
+ ch, cw = int(rh * H), int(rw * W)
550
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
551
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
552
+ out = img.copy()
553
+ out[x1:x2, y1:y2, :] = replace
554
+ return out
555
+
556
+
557
+ # level to args
558
+ def enhance_level_to_args(MAX_LEVEL):
559
+ def level_to_args(level):
560
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
561
+ return level_to_args
562
+
563
+
564
+ def shear_level_to_args(MAX_LEVEL, replace_value):
565
+ def level_to_args(level):
566
+ level = (level / MAX_LEVEL) * 0.3
567
+ if np.random.random() > 0.5:
568
+ level = -level
569
+ return (level, replace_value)
570
+
571
+ return level_to_args
572
+
573
+
574
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
575
+ def level_to_args(level):
576
+ level = (level / MAX_LEVEL) * float(translate_const)
577
+ if np.random.random() > 0.5:
578
+ level = -level
579
+ return (level, replace_value)
580
+
581
+ return level_to_args
582
+
583
+
584
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
585
+ def level_to_args(level):
586
+ level = int((level / MAX_LEVEL) * cutout_const)
587
+ return (level, replace_value)
588
+
589
+ return level_to_args
590
+
591
+
592
+ def solarize_level_to_args(MAX_LEVEL):
593
+ def level_to_args(level):
594
+ level = int((level / MAX_LEVEL) * 256)
595
+ return (level, )
596
+ return level_to_args
597
+
598
+
599
+ def none_level_to_args(level):
600
+ return ()
601
+
602
+
603
+ def posterize_level_to_args(MAX_LEVEL):
604
+ def level_to_args(level):
605
+ level = int((level / MAX_LEVEL) * 4)
606
+ return (level, )
607
+ return level_to_args
608
+
609
+
610
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
611
+ def level_to_args(level):
612
+ level = (level / MAX_LEVEL) * 30
613
+ if np.random.random() < 0.5:
614
+ level = -level
615
+ return (level, replace_value)
616
+
617
+ return level_to_args
618
+
619
+
620
+ func_dict = {
621
+ 'Identity': identity_func,
622
+ 'AutoContrast': autocontrast_func,
623
+ 'Equalize': equalize_func,
624
+ 'Rotate': rotate_func,
625
+ 'Solarize': solarize_func,
626
+ 'Color': color_func,
627
+ 'Contrast': contrast_func,
628
+ 'Brightness': brightness_func,
629
+ 'Sharpness': sharpness_func,
630
+ 'ShearX': shear_x_func,
631
+ 'TranslateX': translate_x_func,
632
+ 'TranslateY': translate_y_func,
633
+ 'Posterize': posterize_func,
634
+ 'ShearY': shear_y_func,
635
+ }
636
+
637
+ translate_const = 10
638
+ MAX_LEVEL = 10
639
+ replace_value = (128, 128, 128)
640
+ arg_dict = {
641
+ 'Identity': none_level_to_args,
642
+ 'AutoContrast': none_level_to_args,
643
+ 'Equalize': none_level_to_args,
644
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
645
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
646
+ 'Color': enhance_level_to_args(MAX_LEVEL),
647
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
648
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
649
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
650
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
651
+ 'TranslateX': translate_level_to_args(
652
+ translate_const, MAX_LEVEL, replace_value
653
+ ),
654
+ 'TranslateY': translate_level_to_args(
655
+ translate_const, MAX_LEVEL, replace_value
656
+ ),
657
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
658
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
659
+ }
660
+
661
+
662
+ class RandomAugment(object):
663
+
664
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
665
+ self.N = N
666
+ self.M = M
667
+ self.isPIL = isPIL
668
+ if augs:
669
+ self.augs = augs
670
+ else:
671
+ self.augs = list(arg_dict.keys())
672
+
673
+ def get_random_ops(self):
674
+ sampled_ops = np.random.choice(self.augs, self.N)
675
+ return [(op, 0.5, self.M) for op in sampled_ops]
676
+
677
+ def __call__(self, img):
678
+ if self.isPIL:
679
+ img = np.array(img)
680
+ ops = self.get_random_ops()
681
+ for name, prob, level in ops:
682
+ if np.random.random() > prob:
683
+ continue
684
+ args = arg_dict[name](level)
685
+ img = func_dict[name](img, *args)
686
+ return img
687
+
688
+
689
+ def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
690
+ if is_train:
691
+ t = [
692
+ RandomResizedCropAndInterpolation(
693
+ input_size, scale=(0.5, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
694
+ transforms.RandomHorizontalFlip(),
695
+ ]
696
+ if randaug:
697
+ t.append(
698
+ RandomAugment(
699
+ 2, 7, isPIL=True,
700
+ augs=[
701
+ 'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
702
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
703
+ ]))
704
+ t += [
705
+ transforms.ToTensor(),
706
+ transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
707
+ ]
708
+ t = transforms.Compose(t)
709
+ else:
710
+ t = transforms.Compose([
711
+ transforms.Resize((input_size, input_size),
712
+ interpolation=transforms.InterpolationMode.BICUBIC),
713
+ transforms.ToTensor(),
714
+ transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
715
+ ])
716
+
717
+ return t
718
+
719
+
720
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
721
+ with open(filename, "wb") as fh:
722
+ with urllib.request.urlopen(
723
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
724
+ ) as response:
725
+ with tqdm(total=response.length) as pbar:
726
+ for chunk in iter(lambda: response.read(chunk_size), ""):
727
+ if not chunk:
728
+ break
729
+ pbar.update(chunk_size)
730
+ fh.write(chunk)
vocab.txt ADDED
The diff for this file is too large to render. See raw diff