kunyi
commited on
Commit
•
f76d30f
1
Parent(s):
330aea7
Upload 30 files
Browse files- README.md +279 -3
- README_CN.md +281 -0
- clip/__init__.py +5 -0
- clip/bert_tokenizer.py +436 -0
- clip/configuration_bert.py +86 -0
- clip/model.py +914 -0
- clip/model_configs/RBT3-chinese.json +13 -0
- clip/model_configs/RN50.json +7 -0
- clip/model_configs/RoBERTa-wwm-ext-base-chinese.json +13 -0
- clip/model_configs/RoBERTa-wwm-ext-large-chinese.json +13 -0
- clip/model_configs/ViT-B-16.json +7 -0
- clip/model_configs/ViT-B-32.json +7 -0
- clip/model_configs/ViT-H-14.json +8 -0
- clip/model_configs/ViT-L-14-336.json +7 -0
- clip/model_configs/ViT-L-14.json +7 -0
- clip/modeling_bert.py +484 -0
- clip/utils.py +184 -0
- clip/vocab.txt +0 -0
- eval/cvinw_zeroshot_templates.py +474 -0
- eval/data.py +164 -0
- eval/evaluation.py +157 -0
- eval/evaluation_tr.py +157 -0
- eval/extract_features.py +212 -0
- eval/make_topk_predictions.py +88 -0
- eval/make_topk_predictions_tr.py +88 -0
- eval/transform_ir_annotation_to_tr.py +36 -0
- eval/zeroshot_evaluation.py +267 -0
- examples/pokemon.jpeg +0 -0
- requirements.txt +10 -0
- scripts/zeroshot_eval.sh +34 -0
README.md
CHANGED
@@ -1,3 +1,279 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[**中文说明**](README_CN.md) | [**English**](README.md)
|
2 |
+
# Introduction
|
3 |
+
<br><br>
|
4 |
+
This project aims to provide a better Chinese CLIP model. The training data used in this project consists of publicly accessible image URLs and related Chinese text descriptions, totaling 400 million. After screening, we ultimately used 100 million data for training.
|
5 |
+
This project is produced by QQ-ARC Joint Lab, Tencent PCG.
|
6 |
+
<br><br>
|
7 |
+
|
8 |
+
# Models and Results
|
9 |
+
<span id="model_card"></span>
|
10 |
+
## Model Card
|
11 |
+
QA-CLIP currently has three different open-source models of different sizes, and their model information and download links are shown in the table below:
|
12 |
+
<table border="1" width="100%">
|
13 |
+
<tr align="center">
|
14 |
+
<th>Model</th><th>Ckp</th><th>Params</th><th>Vision</th><th>Params of Vision</th><th>Text</th><th>Params of Text</th><th>Resolution</th>
|
15 |
+
</tr>
|
16 |
+
<tr align="center">
|
17 |
+
<td>QA-CLIP<sub>RN50</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt">Download</a></td><td>77M</td><td>ResNet50</td><td>38M</td><td>RBT3</td><td>39M</td><td>224</td>
|
18 |
+
</tr>
|
19 |
+
<tr align="center">
|
20 |
+
<td>QA-CLIP<sub>ViT-B/16</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt">Download</a></td><td>188M</td><td>ViT-B/16</td><td>86M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
|
21 |
+
</tr>
|
22 |
+
<tr align="center">
|
23 |
+
<td>QA-CLIP<sub>ViT-L/14</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt">Download</a></td><td>406M</td><td>ViT-L/14</td><td>304M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
|
24 |
+
</tr>
|
25 |
+
</table>
|
26 |
+
<br>
|
27 |
+
|
28 |
+
## Results
|
29 |
+
We conducted zero-shot tests on [MUGE Retrieval](https://tianchi.aliyun.com/muge), [Flickr30K-CN](https://github.com/li-xirong/cross-lingual-cap), and [COCO-CN](https://github.com/li-xirong/coco-cn) datasets for image-text retrieval tasks. For the image zero-shot classification task, we tested on the ImageNet dataset. The test results are shown in the table below:
|
30 |
+
|
31 |
+
**Flickr30K-CN Zero-shot Retrieval (Official Test Set)**:
|
32 |
+
<table border="1" width="120%">
|
33 |
+
<tr align="center">
|
34 |
+
<th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
|
35 |
+
</tr>
|
36 |
+
<tr align="center">
|
37 |
+
<td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
|
38 |
+
</tr>
|
39 |
+
<tr align="center">
|
40 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.8</td><td>76.0</td><td>84.6</td><td>60.0</td><td>85.9</td><td>92.0</td>
|
41 |
+
</tr>
|
42 |
+
<tr align="center">
|
43 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.5</b></td><td><b>77.4</b></td><td><b>86.1</b></td><td><b>67.1</b></td><td><b>87.9</b></td><td><b>93.2</b></td>
|
44 |
+
</tr>
|
45 |
+
<tr align="center">
|
46 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.7</td><td>86.9</td><td>92.8</td><td>74.6</td><td>93.5</td><td>97.1</td>
|
47 |
+
</tr>
|
48 |
+
<tr align="center">
|
49 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>63.8</b></td><td><b>88.0</b></td><td><b>93.2</b></td><td><b>78.4</b></td><td><b>96.1</b></td><td><b>98.5</b></td>
|
50 |
+
</tr>
|
51 |
+
<tr align="center">
|
52 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>68.0</td><td>89.7</td><td>94.4</td><td>80.2</td><td>96.6</td><td>98.2</td>
|
53 |
+
</tr>
|
54 |
+
<tr align="center">
|
55 |
+
<td width="120%">AltClip<sub>ViT-L/14</sub></td><td><b>69.7</b></td><td>90.1</td><td>94.8</td><td>84.8</td><td>97.7</td><td>99.1</td>
|
56 |
+
</tr>
|
57 |
+
<tr align="center">
|
58 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>69.3</td><td><b>90.3</b></td><td><b>94.7</b></td><td><b>85.3</b></td><td><b>97.9</b></td><td><b>99.2</b></td>
|
59 |
+
</tr>
|
60 |
+
</table>
|
61 |
+
<br>
|
62 |
+
|
63 |
+
**MUGE Zero-shot Retrieval (Official Validation Set)**:
|
64 |
+
<table border="1" width="120%">
|
65 |
+
<tr align="center">
|
66 |
+
<th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
|
67 |
+
</tr>
|
68 |
+
<tr align="center">
|
69 |
+
<td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
|
70 |
+
</tr>
|
71 |
+
<tr align="center">
|
72 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>42.6</td><td>68.5</td><td>78.0</td><td>30.0</td><td>56.2</td><td>66.9</td>
|
73 |
+
</tr>
|
74 |
+
<tr align="center">
|
75 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>44.0</b></td><td><b>69.9</b></td><td><b>79.5</b></td><td><b>32.4</b></td><td><b>59.5</b></td><td><b>70.3</b></td>
|
76 |
+
</tr>
|
77 |
+
<tr align="center">
|
78 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>52.1</td><td>76.7</td><td>84.4</td><td>38.7</td><td>65.6</td><td>75.1</td>
|
79 |
+
</tr>
|
80 |
+
<tr align="center">
|
81 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>53.2</b></td><td><b>77.7</b></td><td><b>85.1</b></td><td><b>40.7</b></td><td><b>68.2</b></td><td><b>77.2</b></td>
|
82 |
+
</tr>
|
83 |
+
<tr align="center">
|
84 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>56.4</td><td>79.8</td><td>86.2</td><td>42.6</td><td>69.8</td><td>78.6</td>
|
85 |
+
</tr>
|
86 |
+
<tr align="center">
|
87 |
+
<td width="120%">AltClip<sub>ViT-L/14</sub></td><td>29.6</td><td>49.9</td><td>58.8</td><td>21.4</td><td>42.0</td><td>51.9</td>
|
88 |
+
</tr>
|
89 |
+
<tr align="center">
|
90 |
+
<td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>57.4</b></td><td><b>81.0</b></td><td><b>87.7</b></td><td><b>45.5</b></td><td><b>73.0</b></td><td><b>81.4</b></td>
|
91 |
+
</tr>
|
92 |
+
</table>
|
93 |
+
<br>
|
94 |
+
|
95 |
+
**COCO-CN Zero-shot Retrieval (Official Test Set)**:
|
96 |
+
<table border="1" width="120%">
|
97 |
+
<tr align="center">
|
98 |
+
<th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
|
99 |
+
</tr>
|
100 |
+
<tr align="center">
|
101 |
+
<td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
|
102 |
+
</tr>
|
103 |
+
<tr align="center">
|
104 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.1</td><td>81.3</td><td>90.5</td><td>50.9</td><td>81.1</td><td>90.5</td>
|
105 |
+
</tr>
|
106 |
+
<tr align="center">
|
107 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.1</b></td><td><b>82.5</b></td><td><b>91.7</b></td><td><b>56.7</b></td><td><b>85.2</b></td><td><b>92.9</b></td>
|
108 |
+
</tr>
|
109 |
+
<tr align="center">
|
110 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.2</td><td>87.1</td><td>94.9</td><td>56.3</td><td>84.0</td><td>93.3</td>
|
111 |
+
</tr>
|
112 |
+
<tr align="center">
|
113 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>62.9</b></td><td><b>87.7</b></td><td><b>94.7</b></td><td><b>61.5</b></td><td><b>87.6</b></td><td><b>94.8</b></td>
|
114 |
+
</tr>
|
115 |
+
<tr align="center">
|
116 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>64.9</td><td>88.8</td><td>94.2</td><td>60.6</td><td>84.4</td><td>93.1</td>
|
117 |
+
</tr>
|
118 |
+
<tr align="center">
|
119 |
+
<td width="120%">AltClip<sub>ViT-L/14</sub></td><td>63.5</td><td>87.6</td><td>93.5</td><td>62.6</td><td><b>88.5</b></td><td><b>95.9</b></td>
|
120 |
+
</tr>
|
121 |
+
<tr align="center">
|
122 |
+
<td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>65.7</b></td><td><b>90.2</b></td><td><b>95.0</b></td><td><b>64.5</b></td><td>88.3</td><td>95.1</td>
|
123 |
+
</tr>
|
124 |
+
</table>
|
125 |
+
<br>
|
126 |
+
|
127 |
+
**Zero-shot Image Classification on ImageNet**:
|
128 |
+
<table border="1" width="120%">
|
129 |
+
<tr align="center">
|
130 |
+
<th>Task</th><th colspan="1">ImageNet</th>
|
131 |
+
</tr>
|
132 |
+
<tr align="center">
|
133 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>33.5</td>
|
134 |
+
</tr>
|
135 |
+
<tr align="center">
|
136 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>35.5</b></td>
|
137 |
+
</tr>
|
138 |
+
<tr align="center">
|
139 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>48.4</td>
|
140 |
+
</tr>
|
141 |
+
<tr align="center">
|
142 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>49.7</b></td>
|
143 |
+
</tr>
|
144 |
+
<tr align="center">
|
145 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>54.7</td>
|
146 |
+
</tr>
|
147 |
+
<tr align="center">
|
148 |
+
<td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>55.8</b></td>
|
149 |
+
</tr>
|
150 |
+
</table>
|
151 |
+
<br>
|
152 |
+
|
153 |
+
<br><br>
|
154 |
+
|
155 |
+
|
156 |
+
# Getting Started
|
157 |
+
## Installation Requirements
|
158 |
+
Environment configuration requirements:
|
159 |
+
|
160 |
+
* python >= 3.6.4
|
161 |
+
* pytorch >= 1.8.0 (with torchvision >= 0.9.0)
|
162 |
+
* CUDA Version >= 10.2
|
163 |
+
|
164 |
+
Install required packages:
|
165 |
+
```bash
|
166 |
+
cd /yourpath/QA-CLIP-main
|
167 |
+
pip install -r requirements.txt
|
168 |
+
```
|
169 |
+
|
170 |
+
## Inference Code
|
171 |
+
```bash
|
172 |
+
export PYTHONPATH=/yourpath/QA-CLIP-main
|
173 |
+
```
|
174 |
+
Inference code example:
|
175 |
+
```python
|
176 |
+
import torch
|
177 |
+
from PIL import Image
|
178 |
+
|
179 |
+
import clip as clip
|
180 |
+
from clip import load_from_name, available_models
|
181 |
+
print("Available models:", available_models())
|
182 |
+
# Available models: ['ViT-B-16', 'ViT-L-14', 'RN50']
|
183 |
+
|
184 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
185 |
+
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
|
186 |
+
model.eval()
|
187 |
+
image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
|
188 |
+
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
|
189 |
+
|
190 |
+
with torch.no_grad():
|
191 |
+
image_features = model.encode_image(image)
|
192 |
+
text_features = model.encode_text(text)
|
193 |
+
# Normalize the features. Please use the normalized features for downstream tasks.
|
194 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
195 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
196 |
+
|
197 |
+
logits_per_image, logits_per_text = model.get_similarity(image, text)
|
198 |
+
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
199 |
+
|
200 |
+
print("Label probs:", probs)
|
201 |
+
```
|
202 |
+
<br><br>
|
203 |
+
|
204 |
+
## Prediction and Evaluation
|
205 |
+
|
206 |
+
### Download Image-text Retrieval Test Dataset
|
207 |
+
In Project <b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>, the test set has already been preprocessed. Here is the download link they provided:
|
208 |
+
|
209 |
+
MUGE dataset:[download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)
|
210 |
+
|
211 |
+
Flickr30K-CN dataset:[download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)
|
212 |
+
|
213 |
+
Additionally, obtaining the [COCO-CN](https://github.com/li-xirong/coco-cn) dataset requires applying to the original author.
|
214 |
+
|
215 |
+
### Download ImageNet Dataset
|
216 |
+
Please download the raw data yourself,[Chinese Label](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label_cn.txt) and [English Label](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label.txt) are provided by Project <b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>
|
217 |
+
### Image-text Retrieval Evaluation
|
218 |
+
The image-text retrieval evaluation code can be referred to as follows:
|
219 |
+
```bash
|
220 |
+
split=test # Designate the computation of features for the valid or test set
|
221 |
+
resume=your_ckp_path
|
222 |
+
DATAPATH=your_DATAPATH
|
223 |
+
dataset_name=Flickr30k-CN
|
224 |
+
# dataset_name=MUGE
|
225 |
+
|
226 |
+
python -u eval/extract_features.py \
|
227 |
+
--extract-image-feats \
|
228 |
+
--extract-text-feats \
|
229 |
+
--image-data="${DATAPATH}/datasets/${dataset_name}/lmdb/${split}/imgs" \
|
230 |
+
--text-data="${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl" \
|
231 |
+
--img-batch-size=32 \
|
232 |
+
--text-batch-size=32 \
|
233 |
+
--context-length=52 \
|
234 |
+
--resume=${resume} \
|
235 |
+
--vision-model=ViT-B-16 \
|
236 |
+
--text-model=RoBERTa-wwm-ext-base-chinese
|
237 |
+
|
238 |
+
python -u eval/make_topk_predictions.py \
|
239 |
+
--image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
|
240 |
+
--text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
|
241 |
+
--top-k=10 \
|
242 |
+
--eval-batch-size=32768 \
|
243 |
+
--output="${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl"
|
244 |
+
|
245 |
+
python -u eval/make_topk_predictions_tr.py \
|
246 |
+
--image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
|
247 |
+
--text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
|
248 |
+
--top-k=10 \
|
249 |
+
--eval-batch-size=32768 \
|
250 |
+
--output="${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl"
|
251 |
+
|
252 |
+
python eval/evaluation.py \
|
253 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl \
|
254 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl \
|
255 |
+
${DATAPATH}/datasets/${dataset_name}/output1.json
|
256 |
+
cat ${DATAPATH}/datasets/${dataset_name}/output1.json
|
257 |
+
|
258 |
+
python eval/transform_ir_annotation_to_tr.py \
|
259 |
+
--input ${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl
|
260 |
+
|
261 |
+
python eval/evaluation_tr.py \
|
262 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_texts.tr.jsonl \
|
263 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl \
|
264 |
+
${DATAPATH}/datasets/${dataset_name}/output2.json
|
265 |
+
cat ${DATAPATH}/datasets/${dataset_name}/output2.json
|
266 |
+
```
|
267 |
+
|
268 |
+
### ImageNet Zero-shot Classification
|
269 |
+
The ImageNet zero-shot classification code can be referred to as follows
|
270 |
+
```bash
|
271 |
+
bash scripts/zeroshot_eval.sh 0 \
|
272 |
+
${DATAPATH} imagenet \
|
273 |
+
ViT-B-16 RoBERTa-wwm-ext-base-chinese \
|
274 |
+
./pretrained_weights/QA-CLIP-base.pt
|
275 |
+
```
|
276 |
+
# Acknowledgments
|
277 |
+
<br><br>
|
278 |
+
The project code is based on implementation of <b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>, and we are very grateful for their outstanding open-source contributions.
|
279 |
+
<br><br>
|
README_CN.md
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[**中文说明**](README_CN.md) | [**English**](README.md)
|
2 |
+
# 项目介绍
|
3 |
+
<br><br>
|
4 |
+
本项目旨在提供更好的中文CLIP模型。该项目使用的训练数据均为公开可访问的图像URL及相关中文文本描述,总量达到400M。经过筛选后,我们最终使用了100M的数据进行训练。
|
5 |
+
本项目于QQ-ARC Joint Lab, Tencent PCG完成
|
6 |
+
<br><br>
|
7 |
+
|
8 |
+
# 模型及实验
|
9 |
+
<span id="model_card"></span>
|
10 |
+
## 模型规模 & 下载链接
|
11 |
+
QA-CLIP目前开源3个不同规模,其模型信息和下载方式见下表:
|
12 |
+
|
13 |
+
<table border="1" width="100%">
|
14 |
+
<tr align="center">
|
15 |
+
<th>模型规模</th><th>下载链接</th><th>参数量</th><th>视觉侧骨架</th><th>视觉侧参数量</th><th>文本侧骨架</th><th>文本侧参数量</th><th>分辨率</th>
|
16 |
+
</tr>
|
17 |
+
<tr align="center">
|
18 |
+
<td>QA-CLIP<sub>RN50</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt">Download</a></td><td>77M</td><td>ResNet50</td><td>38M</td><td>RBT3</td><td>39M</td><td>224</td>
|
19 |
+
</tr>
|
20 |
+
<tr align="center">
|
21 |
+
<td>QA-CLIP<sub>ViT-B/16</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt">Download</a></td><td>188M</td><td>ViT-B/16</td><td>86M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
|
22 |
+
</tr>
|
23 |
+
<tr align="center">
|
24 |
+
<td>QA-CLIP<sub>ViT-L/14</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt">Download</a></td><td>406M</td><td>ViT-L/14</td><td>304M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
|
25 |
+
</tr>
|
26 |
+
</table>
|
27 |
+
<br>
|
28 |
+
|
29 |
+
## 实验结果
|
30 |
+
针对图文检索任务,我们在[MUGE Retrieval](https://tianchi.aliyun.com/muge)、[Flickr30K-CN](https://github.com/li-xirong/cross-lingual-cap)和[COCO-CN](https://github.com/li-xirong/coco-cn)上进行了zero-shot测试。
|
31 |
+
针对图像零样本分类任务,我们在ImageNet数据集上进行了测试。测试结果见下表:
|
32 |
+
|
33 |
+
|
34 |
+
**Flickr30K-CN Zero-shot Retrieval (Official Test Set)**:
|
35 |
+
<table border="1" width="120%">
|
36 |
+
<tr align="center">
|
37 |
+
<th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
|
38 |
+
</tr>
|
39 |
+
<tr align="center">
|
40 |
+
<td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
|
41 |
+
</tr>
|
42 |
+
<tr align="center">
|
43 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.8</td><td>76.0</td><td>84.6</td><td>60.0</td><td>85.9</td><td>92.0</td>
|
44 |
+
</tr>
|
45 |
+
<tr align="center">
|
46 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.5</b></td><td><b>77.4</b></td><td><b>86.1</b></td><td><b>67.1</b></td><td><b>87.9</b></td><td><b>93.2</b></td>
|
47 |
+
</tr>
|
48 |
+
<tr align="center">
|
49 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.7</td><td>86.9</td><td>92.8</td><td>74.6</td><td>93.5</td><td>97.1</td>
|
50 |
+
</tr>
|
51 |
+
<tr align="center">
|
52 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>63.8</b></td><td><b>88.0</b></td><td><b>93.2</b></td><td><b>78.4</b></td><td><b>96.1</b></td><td><b>98.5</b></td>
|
53 |
+
</tr>
|
54 |
+
<tr align="center">
|
55 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>68.0</td><td>89.7</td><td>94.4</td><td>80.2</td><td>96.6</td><td>98.2</td>
|
56 |
+
</tr>
|
57 |
+
<tr align="center">
|
58 |
+
<td width="120%">AltClip<sub>ViT-L/14</sub></td><td><b>69.7</b></td><td>90.1</td><td>94.8</td><td>84.8</td><td>97.7</td><td>99.1</td>
|
59 |
+
</tr>
|
60 |
+
<tr align="center">
|
61 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>69.3</td><td><b>90.3</b></td><td><b>94.7</b></td><td><b>85.3</b></td><td><b>97.9</b></td><td><b>99.2</b></td>
|
62 |
+
</tr>
|
63 |
+
</table>
|
64 |
+
<br>
|
65 |
+
|
66 |
+
**MUGE Zero-shot Retrieval (Official Validation Set)**:
|
67 |
+
<table border="1" width="120%">
|
68 |
+
<tr align="center">
|
69 |
+
<th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
|
70 |
+
</tr>
|
71 |
+
<tr align="center">
|
72 |
+
<td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
|
73 |
+
</tr>
|
74 |
+
<tr align="center">
|
75 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>42.6</td><td>68.5</td><td>78.0</td><td>30.0</td><td>56.2</td><td>66.9</td>
|
76 |
+
</tr>
|
77 |
+
<tr align="center">
|
78 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>44.0</b></td><td><b>69.9</b></td><td><b>79.5</b></td><td><b>32.4</b></td><td><b>59.5</b></td><td><b>70.3</b></td>
|
79 |
+
</tr>
|
80 |
+
<tr align="center">
|
81 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>52.1</td><td>76.7</td><td>84.4</td><td>38.7</td><td>65.6</td><td>75.1</td>
|
82 |
+
</tr>
|
83 |
+
<tr align="center">
|
84 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>53.2</b></td><td><b>77.7</b></td><td><b>85.1</b></td><td><b>40.7</b></td><td><b>68.2</b></td><td><b>77.2</b></td>
|
85 |
+
</tr>
|
86 |
+
<tr align="center">
|
87 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>56.4</td><td>79.8</td><td>86.2</td><td>42.6</td><td>69.8</td><td>78.6</td>
|
88 |
+
</tr>
|
89 |
+
<tr align="center">
|
90 |
+
<td width="120%">AltClip<sub>ViT-L/14</sub></td><td>29.6</td><td>49.9</td><td>58.8</td><td>21.4</td><td>42.0</td><td>51.9</td>
|
91 |
+
</tr>
|
92 |
+
<tr align="center">
|
93 |
+
<td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>57.4</b></td><td><b>81.0</b></td><td><b>87.7</b></td><td><b>45.5</b></td><td><b>73.0</b></td><td><b>81.4</b></td>
|
94 |
+
</tr>
|
95 |
+
</table>
|
96 |
+
<br>
|
97 |
+
|
98 |
+
**COCO-CN Zero-shot Retrieval (Official Test Set)**:
|
99 |
+
<table border="1" width="120%">
|
100 |
+
<tr align="center">
|
101 |
+
<th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
|
102 |
+
</tr>
|
103 |
+
<tr align="center">
|
104 |
+
<td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
|
105 |
+
</tr>
|
106 |
+
<tr align="center">
|
107 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.1</td><td>81.3</td><td>90.5</td><td>50.9</td><td>81.1</td><td>90.5</td>
|
108 |
+
</tr>
|
109 |
+
<tr align="center">
|
110 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.1</b></td><td><b>82.5</b></td><td><b>91.7</b></td><td><b>56.7</b></td><td><b>85.2</b></td><td><b>92.9</b></td>
|
111 |
+
</tr>
|
112 |
+
<tr align="center">
|
113 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.2</td><td>87.1</td><td>94.9</td><td>56.3</td><td>84.0</td><td>93.3</td>
|
114 |
+
</tr>
|
115 |
+
<tr align="center">
|
116 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>62.9</b></td><td><b>87.7</b></td><td><b>94.7</b></td><td><b>61.5</b></td><td><b>87.6</b></td><td><b>94.8</b></td>
|
117 |
+
</tr>
|
118 |
+
<tr align="center">
|
119 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>64.9</td><td>88.8</td><td>94.2</td><td>60.6</td><td>84.4</td><td>93.1</td>
|
120 |
+
</tr>
|
121 |
+
<tr align="center">
|
122 |
+
<td width="120%">AltClip<sub>ViT-L/14</sub></td><td>63.5</td><td>87.6</td><td>93.5</td><td>62.6</td><td><b>88.5</b></td><td><b>95.9</b></td>
|
123 |
+
</tr>
|
124 |
+
<tr align="center">
|
125 |
+
<td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>65.7</b></td><td><b>90.2</b></td><td><b>95.0</b></td><td><b>64.5</b></td><td>88.3</td><td>95.1</td>
|
126 |
+
</tr>
|
127 |
+
</table>
|
128 |
+
<br>
|
129 |
+
|
130 |
+
**Zero-shot Image Classification on ImageNet**:
|
131 |
+
<table border="1" width="120%">
|
132 |
+
<tr align="center">
|
133 |
+
<th>Task</th><th colspan="1">ImageNet</th>
|
134 |
+
</tr>
|
135 |
+
<tr align="center">
|
136 |
+
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>33.5</td>
|
137 |
+
</tr>
|
138 |
+
<tr align="center">
|
139 |
+
<td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>35.5</b></td>
|
140 |
+
</tr>
|
141 |
+
<tr align="center">
|
142 |
+
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>48.4</td>
|
143 |
+
</tr>
|
144 |
+
<tr align="center">
|
145 |
+
<td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>49.7</b></td>
|
146 |
+
</tr>
|
147 |
+
<tr align="center">
|
148 |
+
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>54.7</td>
|
149 |
+
</tr>
|
150 |
+
<tr align="center">
|
151 |
+
<td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>55.8</b></td>
|
152 |
+
</tr>
|
153 |
+
</table>
|
154 |
+
<br>
|
155 |
+
|
156 |
+
<br><br>
|
157 |
+
|
158 |
+
|
159 |
+
# 使用教程
|
160 |
+
## 安装要求
|
161 |
+
环境配置要求:
|
162 |
+
|
163 |
+
* python >= 3.6.4
|
164 |
+
* pytorch >= 1.8.0 (with torchvision >= 0.9.0)
|
165 |
+
* CUDA Version >= 10.2
|
166 |
+
|
167 |
+
安装本项目所需库
|
168 |
+
```bash
|
169 |
+
cd /yourpath/QA-CLIP-main
|
170 |
+
pip install -r requirements.txt
|
171 |
+
```
|
172 |
+
|
173 |
+
## 推理代码
|
174 |
+
```bash
|
175 |
+
export PYTHONPATH=/yourpath/QA-CLIP-main
|
176 |
+
```
|
177 |
+
推理代码示例:
|
178 |
+
```python
|
179 |
+
import torch
|
180 |
+
from PIL import Image
|
181 |
+
|
182 |
+
import clip as clip
|
183 |
+
from clip import load_from_name, available_models
|
184 |
+
print("Available models:", available_models())
|
185 |
+
# Available models: ['ViT-B-16', 'ViT-L-14', 'RN50']
|
186 |
+
|
187 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
188 |
+
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
|
189 |
+
model.eval()
|
190 |
+
image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
|
191 |
+
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
|
192 |
+
|
193 |
+
with torch.no_grad():
|
194 |
+
image_features = model.encode_image(image)
|
195 |
+
text_features = model.encode_text(text)
|
196 |
+
# 对特征进行归一化,请使用归一化后的图文特征用于下游任务
|
197 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
198 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
199 |
+
|
200 |
+
logits_per_image, logits_per_text = model.get_similarity(image, text)
|
201 |
+
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
202 |
+
|
203 |
+
print("Label probs:", probs)
|
204 |
+
```
|
205 |
+
<br><br>
|
206 |
+
|
207 |
+
## 预测及评估
|
208 |
+
|
209 |
+
### 图文检索测试数据集下载
|
210 |
+
<b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>项目中已经预处理好测试集,这是他们提供的下载链接:
|
211 |
+
|
212 |
+
MUGE数据:[下载链接](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)
|
213 |
+
|
214 |
+
Flickr30K-CN数据:[下载链接](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)
|
215 |
+
|
216 |
+
另外[COCO-CN](https://github.com/li-xirong/coco-cn)数据的获取需要向原作者进行申请
|
217 |
+
### ImageNet数据集下载
|
218 |
+
原始数据请自行下载,[中文标签](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label_cn.txt)和[英文标签](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label.txt)同样由<b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>项目提供
|
219 |
+
### 图文检索评估
|
220 |
+
图文检索评估代码可以参考如下:
|
221 |
+
```bash
|
222 |
+
split=test # 指定计算valid或test集特征
|
223 |
+
resume=your_ckp_path
|
224 |
+
DATAPATH=your_DATAPATH
|
225 |
+
dataset_name=Flickr30k-CN
|
226 |
+
# dataset_name=MUGE
|
227 |
+
|
228 |
+
python -u eval/extract_features.py \
|
229 |
+
--extract-image-feats \
|
230 |
+
--extract-text-feats \
|
231 |
+
--image-data="${DATAPATH}/datasets/${dataset_name}/lmdb/${split}/imgs" \
|
232 |
+
--text-data="${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl" \
|
233 |
+
--img-batch-size=32 \
|
234 |
+
--text-batch-size=32 \
|
235 |
+
--context-length=52 \
|
236 |
+
--resume=${resume} \
|
237 |
+
--vision-model=ViT-B-16 \
|
238 |
+
--text-model=RoBERTa-wwm-ext-base-chinese
|
239 |
+
|
240 |
+
python -u eval/make_topk_predictions.py \
|
241 |
+
--image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
|
242 |
+
--text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
|
243 |
+
--top-k=10 \
|
244 |
+
--eval-batch-size=32768 \
|
245 |
+
--output="${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl"
|
246 |
+
|
247 |
+
python -u eval/make_topk_predictions_tr.py \
|
248 |
+
--image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
|
249 |
+
--text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
|
250 |
+
--top-k=10 \
|
251 |
+
--eval-batch-size=32768 \
|
252 |
+
--output="${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl"
|
253 |
+
|
254 |
+
python eval/evaluation.py \
|
255 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl \
|
256 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl \
|
257 |
+
${DATAPATH}/datasets/${dataset_name}/output1.json
|
258 |
+
cat ${DATAPATH}/datasets/${dataset_name}/output1.json
|
259 |
+
|
260 |
+
python eval/transform_ir_annotation_to_tr.py \
|
261 |
+
--input ${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl
|
262 |
+
|
263 |
+
python eval/evaluation_tr.py \
|
264 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_texts.tr.jsonl \
|
265 |
+
${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl \
|
266 |
+
${DATAPATH}/datasets/${dataset_name}/output2.json
|
267 |
+
cat ${DATAPATH}/datasets/${dataset_name}/output2.json
|
268 |
+
```
|
269 |
+
|
270 |
+
### ImageNet零样本分类
|
271 |
+
ImageNet零样本分类的代码参考如下
|
272 |
+
```bash
|
273 |
+
bash scripts/zeroshot_eval.sh 0 \
|
274 |
+
${DATAPATH} imagenet \
|
275 |
+
ViT-B-16 RoBERTa-wwm-ext-base-chinese \
|
276 |
+
./pretrained_weights/QA-CLIP-base.pt
|
277 |
+
```
|
278 |
+
# 致谢
|
279 |
+
<br><br>
|
280 |
+
项目代码基于<b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>实现,非常感谢他们优秀的开源工作。
|
281 |
+
<br><br>
|
clip/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .bert_tokenizer import FullTokenizer
|
2 |
+
|
3 |
+
_tokenizer = FullTokenizer()
|
4 |
+
from .model import convert_state_dict
|
5 |
+
from .utils import load_from_name, available_models, tokenize, image_transform, load
|
clip/bert_tokenizer.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Tokenization classes."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
import collections
|
23 |
+
import re
|
24 |
+
import unicodedata
|
25 |
+
import six
|
26 |
+
from functools import lru_cache
|
27 |
+
import os
|
28 |
+
|
29 |
+
@lru_cache()
|
30 |
+
def default_vocab():
|
31 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab.txt")
|
32 |
+
|
33 |
+
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
34 |
+
"""Checks whether the casing config is consistent with the checkpoint name."""
|
35 |
+
|
36 |
+
# The casing has to be passed in by the user and there is no explicit check
|
37 |
+
# as to whether it matches the checkpoint. The casing information probably
|
38 |
+
# should have been stored in the bert_config.json file, but it's not, so
|
39 |
+
# we have to heuristically detect it to validate.
|
40 |
+
|
41 |
+
if not init_checkpoint:
|
42 |
+
return
|
43 |
+
|
44 |
+
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
45 |
+
if m is None:
|
46 |
+
return
|
47 |
+
|
48 |
+
model_name = m.group(1)
|
49 |
+
|
50 |
+
lower_models = [
|
51 |
+
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
52 |
+
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
53 |
+
]
|
54 |
+
|
55 |
+
cased_models = [
|
56 |
+
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
57 |
+
"multi_cased_L-12_H-768_A-12"
|
58 |
+
]
|
59 |
+
|
60 |
+
is_bad_config = False
|
61 |
+
if model_name in lower_models and not do_lower_case:
|
62 |
+
is_bad_config = True
|
63 |
+
actual_flag = "False"
|
64 |
+
case_name = "lowercased"
|
65 |
+
opposite_flag = "True"
|
66 |
+
|
67 |
+
if model_name in cased_models and do_lower_case:
|
68 |
+
is_bad_config = True
|
69 |
+
actual_flag = "True"
|
70 |
+
case_name = "cased"
|
71 |
+
opposite_flag = "False"
|
72 |
+
|
73 |
+
if is_bad_config:
|
74 |
+
raise ValueError(
|
75 |
+
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
76 |
+
"However, `%s` seems to be a %s model, so you "
|
77 |
+
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
78 |
+
"how the model was pre-training. If this error is wrong, please "
|
79 |
+
"just comment out this check." % (actual_flag, init_checkpoint,
|
80 |
+
model_name, case_name, opposite_flag))
|
81 |
+
|
82 |
+
|
83 |
+
def convert_to_unicode(text):
|
84 |
+
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
85 |
+
if six.PY3:
|
86 |
+
if isinstance(text, str):
|
87 |
+
return text
|
88 |
+
elif isinstance(text, bytes):
|
89 |
+
return text.decode("utf-8", "ignore")
|
90 |
+
else:
|
91 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
92 |
+
elif six.PY2:
|
93 |
+
if isinstance(text, str):
|
94 |
+
return text.decode("utf-8", "ignore")
|
95 |
+
elif isinstance(text, unicode):
|
96 |
+
return text
|
97 |
+
else:
|
98 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
99 |
+
else:
|
100 |
+
raise ValueError("Not running on Python2 or Python 3?")
|
101 |
+
|
102 |
+
|
103 |
+
def printable_text(text):
|
104 |
+
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
105 |
+
|
106 |
+
# These functions want `str` for both Python2 and Python3, but in one case
|
107 |
+
# it's a Unicode string and in the other it's a byte string.
|
108 |
+
if six.PY3:
|
109 |
+
if isinstance(text, str):
|
110 |
+
return text
|
111 |
+
elif isinstance(text, bytes):
|
112 |
+
return text.decode("utf-8", "ignore")
|
113 |
+
else:
|
114 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
115 |
+
elif six.PY2:
|
116 |
+
if isinstance(text, str):
|
117 |
+
return text
|
118 |
+
elif isinstance(text, unicode):
|
119 |
+
return text.encode("utf-8")
|
120 |
+
else:
|
121 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
122 |
+
else:
|
123 |
+
raise ValueError("Not running on Python2 or Python 3?")
|
124 |
+
|
125 |
+
|
126 |
+
def load_vocab(vocab_file):
|
127 |
+
"""Loads a vocabulary file into a dictionary."""
|
128 |
+
vocab = collections.OrderedDict()
|
129 |
+
index = 0
|
130 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
131 |
+
while True:
|
132 |
+
token = convert_to_unicode(reader.readline())
|
133 |
+
if not token:
|
134 |
+
break
|
135 |
+
token = token.strip()
|
136 |
+
vocab[token] = index
|
137 |
+
index += 1
|
138 |
+
return vocab
|
139 |
+
|
140 |
+
|
141 |
+
def convert_by_vocab(vocab, items):
|
142 |
+
"""Converts a sequence of [tokens|ids] using the vocab."""
|
143 |
+
output = []
|
144 |
+
for item in items:
|
145 |
+
output.append(vocab[item])
|
146 |
+
return output
|
147 |
+
|
148 |
+
|
149 |
+
def convert_tokens_to_ids(vocab, tokens):
|
150 |
+
return convert_by_vocab(vocab, tokens)
|
151 |
+
|
152 |
+
|
153 |
+
def convert_ids_to_tokens(inv_vocab, ids):
|
154 |
+
return convert_by_vocab(inv_vocab, ids)
|
155 |
+
|
156 |
+
|
157 |
+
def whitespace_tokenize(text):
|
158 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
159 |
+
text = text.strip()
|
160 |
+
if not text:
|
161 |
+
return []
|
162 |
+
tokens = text.split()
|
163 |
+
return tokens
|
164 |
+
|
165 |
+
|
166 |
+
class FullTokenizer(object):
|
167 |
+
"""Runs end-to-end tokenziation."""
|
168 |
+
|
169 |
+
def __init__(self, vocab_file=default_vocab(), do_lower_case=True):
|
170 |
+
self.vocab = load_vocab(vocab_file)
|
171 |
+
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
172 |
+
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
173 |
+
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
174 |
+
|
175 |
+
def tokenize(self, text):
|
176 |
+
split_tokens = []
|
177 |
+
for token in self.basic_tokenizer.tokenize(text):
|
178 |
+
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
179 |
+
split_tokens.append(sub_token)
|
180 |
+
|
181 |
+
return split_tokens
|
182 |
+
|
183 |
+
def convert_tokens_to_ids(self, tokens):
|
184 |
+
return convert_by_vocab(self.vocab, tokens)
|
185 |
+
|
186 |
+
def convert_ids_to_tokens(self, ids):
|
187 |
+
return convert_by_vocab(self.inv_vocab, ids)
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
|
191 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
192 |
+
|
193 |
+
def clean_up_tokenization(out_string):
|
194 |
+
""" Clean up a list of simple English tokenization artifacts
|
195 |
+
like spaces before punctuations and abreviated forms.
|
196 |
+
"""
|
197 |
+
out_string = (
|
198 |
+
out_string.replace(" .", ".")
|
199 |
+
.replace(" ?", "?")
|
200 |
+
.replace(" !", "!")
|
201 |
+
.replace(" ,", ",")
|
202 |
+
.replace(" ' ", "'")
|
203 |
+
.replace(" n't", "n't")
|
204 |
+
.replace(" 'm", "'m")
|
205 |
+
.replace(" 's", "'s")
|
206 |
+
.replace(" 've", "'ve")
|
207 |
+
.replace(" 're", "'re")
|
208 |
+
)
|
209 |
+
return out_string
|
210 |
+
|
211 |
+
text = ' '.join(tokens).replace(' ##', '').strip()
|
212 |
+
if clean_up_tokenization_spaces:
|
213 |
+
clean_text = clean_up_tokenization(text)
|
214 |
+
return clean_text
|
215 |
+
else:
|
216 |
+
return text
|
217 |
+
|
218 |
+
def vocab_size(self):
|
219 |
+
return len(self.vocab)
|
220 |
+
|
221 |
+
|
222 |
+
class BasicTokenizer(object):
|
223 |
+
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
224 |
+
|
225 |
+
def __init__(self, do_lower_case=True):
|
226 |
+
"""Constructs a BasicTokenizer.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
do_lower_case: Whether to lower case the input.
|
230 |
+
"""
|
231 |
+
self.do_lower_case = do_lower_case
|
232 |
+
|
233 |
+
def tokenize(self, text):
|
234 |
+
"""Tokenizes a piece of text."""
|
235 |
+
text = convert_to_unicode(text)
|
236 |
+
text = self._clean_text(text)
|
237 |
+
|
238 |
+
# This was added on November 1st, 2018 for the multilingual and Chinese
|
239 |
+
# models. This is also applied to the English models now, but it doesn't
|
240 |
+
# matter since the English models were not trained on any Chinese data
|
241 |
+
# and generally don't have any Chinese data in them (there are Chinese
|
242 |
+
# characters in the vocabulary because Wikipedia does have some Chinese
|
243 |
+
# words in the English Wikipedia.).
|
244 |
+
text = self._tokenize_chinese_chars(text)
|
245 |
+
|
246 |
+
orig_tokens = whitespace_tokenize(text)
|
247 |
+
split_tokens = []
|
248 |
+
for token in orig_tokens:
|
249 |
+
if self.do_lower_case:
|
250 |
+
token = token.lower()
|
251 |
+
token = self._run_strip_accents(token)
|
252 |
+
split_tokens.extend(self._run_split_on_punc(token))
|
253 |
+
|
254 |
+
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
255 |
+
return output_tokens
|
256 |
+
|
257 |
+
def _run_strip_accents(self, text):
|
258 |
+
"""Strips accents from a piece of text."""
|
259 |
+
text = unicodedata.normalize("NFD", text)
|
260 |
+
output = []
|
261 |
+
for char in text:
|
262 |
+
cat = unicodedata.category(char)
|
263 |
+
if cat == "Mn":
|
264 |
+
continue
|
265 |
+
output.append(char)
|
266 |
+
return "".join(output)
|
267 |
+
|
268 |
+
def _run_split_on_punc(self, text):
|
269 |
+
"""Splits punctuation on a piece of text."""
|
270 |
+
chars = list(text)
|
271 |
+
i = 0
|
272 |
+
start_new_word = True
|
273 |
+
output = []
|
274 |
+
while i < len(chars):
|
275 |
+
char = chars[i]
|
276 |
+
if _is_punctuation(char):
|
277 |
+
output.append([char])
|
278 |
+
start_new_word = True
|
279 |
+
else:
|
280 |
+
if start_new_word:
|
281 |
+
output.append([])
|
282 |
+
start_new_word = False
|
283 |
+
output[-1].append(char)
|
284 |
+
i += 1
|
285 |
+
|
286 |
+
return ["".join(x) for x in output]
|
287 |
+
|
288 |
+
def _tokenize_chinese_chars(self, text):
|
289 |
+
"""Adds whitespace around any CJK character."""
|
290 |
+
output = []
|
291 |
+
for char in text:
|
292 |
+
cp = ord(char)
|
293 |
+
if self._is_chinese_char(cp):
|
294 |
+
output.append(" ")
|
295 |
+
output.append(char)
|
296 |
+
output.append(" ")
|
297 |
+
else:
|
298 |
+
output.append(char)
|
299 |
+
return "".join(output)
|
300 |
+
|
301 |
+
def _is_chinese_char(self, cp):
|
302 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
303 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
304 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
305 |
+
#
|
306 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
307 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
308 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
309 |
+
# space-separated words, so they are not treated specially and handled
|
310 |
+
# like the all of the other languages.
|
311 |
+
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
312 |
+
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
313 |
+
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
314 |
+
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
315 |
+
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
316 |
+
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
317 |
+
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
318 |
+
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
319 |
+
return True
|
320 |
+
|
321 |
+
return False
|
322 |
+
|
323 |
+
def _clean_text(self, text):
|
324 |
+
"""Performs invalid character removal and whitespace cleanup on text."""
|
325 |
+
output = []
|
326 |
+
for char in text:
|
327 |
+
cp = ord(char)
|
328 |
+
if cp == 0 or cp == 0xfffd or _is_control(char):
|
329 |
+
continue
|
330 |
+
if _is_whitespace(char):
|
331 |
+
output.append(" ")
|
332 |
+
else:
|
333 |
+
output.append(char)
|
334 |
+
return "".join(output)
|
335 |
+
|
336 |
+
|
337 |
+
class WordpieceTokenizer(object):
|
338 |
+
"""Runs WordPiece tokenziation."""
|
339 |
+
|
340 |
+
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
341 |
+
self.vocab = vocab
|
342 |
+
self.unk_token = unk_token
|
343 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
344 |
+
|
345 |
+
def tokenize(self, text):
|
346 |
+
"""Tokenizes a piece of text into its word pieces.
|
347 |
+
|
348 |
+
This uses a greedy longest-match-first algorithm to perform tokenization
|
349 |
+
using the given vocabulary.
|
350 |
+
|
351 |
+
For example:
|
352 |
+
input = "unaffable"
|
353 |
+
output = ["un", "##aff", "##able"]
|
354 |
+
|
355 |
+
Args:
|
356 |
+
text: A single token or whitespace separated tokens. This should have
|
357 |
+
already been passed through `BasicTokenizer.
|
358 |
+
|
359 |
+
Returns:
|
360 |
+
A list of wordpiece tokens.
|
361 |
+
"""
|
362 |
+
|
363 |
+
text = convert_to_unicode(text)
|
364 |
+
|
365 |
+
output_tokens = []
|
366 |
+
for token in whitespace_tokenize(text):
|
367 |
+
chars = list(token)
|
368 |
+
if len(chars) > self.max_input_chars_per_word:
|
369 |
+
output_tokens.append(self.unk_token)
|
370 |
+
continue
|
371 |
+
|
372 |
+
is_bad = False
|
373 |
+
start = 0
|
374 |
+
sub_tokens = []
|
375 |
+
while start < len(chars):
|
376 |
+
end = len(chars)
|
377 |
+
cur_substr = None
|
378 |
+
while start < end:
|
379 |
+
substr = "".join(chars[start:end])
|
380 |
+
if start > 0:
|
381 |
+
substr = "##" + substr
|
382 |
+
if substr in self.vocab:
|
383 |
+
cur_substr = substr
|
384 |
+
break
|
385 |
+
end -= 1
|
386 |
+
if cur_substr is None:
|
387 |
+
is_bad = True
|
388 |
+
break
|
389 |
+
sub_tokens.append(cur_substr)
|
390 |
+
start = end
|
391 |
+
|
392 |
+
if is_bad:
|
393 |
+
output_tokens.append(self.unk_token)
|
394 |
+
else:
|
395 |
+
output_tokens.extend(sub_tokens)
|
396 |
+
return output_tokens
|
397 |
+
|
398 |
+
|
399 |
+
def _is_whitespace(char):
|
400 |
+
"""Checks whether `chars` is a whitespace character."""
|
401 |
+
# \t, \n, and \r are technically contorl characters but we treat them
|
402 |
+
# as whitespace since they are generally considered as such.
|
403 |
+
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
404 |
+
return True
|
405 |
+
cat = unicodedata.category(char)
|
406 |
+
if cat == "Zs":
|
407 |
+
return True
|
408 |
+
return False
|
409 |
+
|
410 |
+
|
411 |
+
def _is_control(char):
|
412 |
+
"""Checks whether `chars` is a control character."""
|
413 |
+
# These are technically control characters but we count them as whitespace
|
414 |
+
# characters.
|
415 |
+
if char == "\t" or char == "\n" or char == "\r":
|
416 |
+
return False
|
417 |
+
cat = unicodedata.category(char)
|
418 |
+
if cat in ("Cc", "Cf"):
|
419 |
+
return True
|
420 |
+
return False
|
421 |
+
|
422 |
+
|
423 |
+
def _is_punctuation(char):
|
424 |
+
"""Checks whether `chars` is a punctuation character."""
|
425 |
+
cp = ord(char)
|
426 |
+
# We treat all non-letter/number ASCII as punctuation.
|
427 |
+
# Characters such as "^", "$", and "`" are not in the Unicode
|
428 |
+
# Punctuation class but we treat them as punctuation anyways, for
|
429 |
+
# consistency.
|
430 |
+
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
431 |
+
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
432 |
+
return True
|
433 |
+
cat = unicodedata.category(char)
|
434 |
+
if cat.startswith("P"):
|
435 |
+
return True
|
436 |
+
return False
|
clip/configuration_bert.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" BERT model configuration """
|
17 |
+
|
18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
19 |
+
|
20 |
+
import logging
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class BertConfig(object):
|
26 |
+
r"""
|
27 |
+
:class:`~transformers.BertConfig` is the configuration class to store the configuration of a
|
28 |
+
`BertModel`.
|
29 |
+
|
30 |
+
|
31 |
+
Arguments:
|
32 |
+
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
|
33 |
+
hidden_size: Size of the encoder layers and the pooler layer.
|
34 |
+
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
35 |
+
num_attention_heads: Number of attention heads for each attention layer in
|
36 |
+
the Transformer encoder.
|
37 |
+
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
38 |
+
layer in the Transformer encoder.
|
39 |
+
hidden_act: The non-linear activation function (function or string) in the
|
40 |
+
encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
|
41 |
+
hidden_dropout_prob: The dropout probabilitiy for all fully connected
|
42 |
+
layers in the embeddings, encoder, and pooler.
|
43 |
+
attention_probs_dropout_prob: The dropout ratio for the attention
|
44 |
+
probabilities.
|
45 |
+
max_position_embeddings: The maximum sequence length that this model might
|
46 |
+
ever be used with. Typically set this to something large just in case
|
47 |
+
(e.g., 512 or 1024 or 2048).
|
48 |
+
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
49 |
+
`BertModel`.
|
50 |
+
initializer_range: The sttdev of the truncated_normal_initializer for
|
51 |
+
initializing all weight matrices.
|
52 |
+
layer_norm_eps: The epsilon used by LayerNorm.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self,
|
56 |
+
vocab_size_or_config_json_file=30522,
|
57 |
+
hidden_size=768,
|
58 |
+
num_hidden_layers=12,
|
59 |
+
num_attention_heads=12,
|
60 |
+
intermediate_size=3072,
|
61 |
+
hidden_act="gelu",
|
62 |
+
hidden_dropout_prob=0.1,
|
63 |
+
attention_probs_dropout_prob=0.1,
|
64 |
+
max_position_embeddings=512,
|
65 |
+
type_vocab_size=2,
|
66 |
+
initializer_range=0.02,
|
67 |
+
layer_norm_eps=1e-12,
|
68 |
+
output_attentions=False,
|
69 |
+
output_hidden_states=False,
|
70 |
+
use_flash_attention=False
|
71 |
+
):
|
72 |
+
self.vocab_size = vocab_size_or_config_json_file
|
73 |
+
self.hidden_size = hidden_size
|
74 |
+
self.num_hidden_layers = num_hidden_layers
|
75 |
+
self.num_attention_heads = num_attention_heads
|
76 |
+
self.hidden_act = hidden_act
|
77 |
+
self.intermediate_size = intermediate_size
|
78 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
79 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
80 |
+
self.max_position_embeddings = max_position_embeddings
|
81 |
+
self.type_vocab_size = type_vocab_size
|
82 |
+
self.initializer_range = initializer_range
|
83 |
+
self.layer_norm_eps = layer_norm_eps
|
84 |
+
self.output_attentions = output_attentions
|
85 |
+
self.output_hidden_states = output_hidden_states
|
86 |
+
self.use_flash_attention = use_flash_attention
|
clip/model.py
ADDED
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
from itertools import repeat
|
4 |
+
import collections.abc
|
5 |
+
|
6 |
+
import math
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
from torch.utils.checkpoint import checkpoint
|
13 |
+
|
14 |
+
import importlib.util
|
15 |
+
if importlib.util.find_spec('flash_attn'):
|
16 |
+
FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA
|
17 |
+
|
18 |
+
from clip import _tokenizer
|
19 |
+
from clip.configuration_bert import BertConfig
|
20 |
+
from clip.modeling_bert import BertModel
|
21 |
+
|
22 |
+
try:
|
23 |
+
from transformers import CLIPTextModelWithProjection
|
24 |
+
except:
|
25 |
+
pass
|
26 |
+
|
27 |
+
class Bottleneck(nn.Module):
|
28 |
+
expansion = 4
|
29 |
+
|
30 |
+
def __init__(self, inplanes, planes, stride=1):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
34 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
35 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
36 |
+
|
37 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
38 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
39 |
+
|
40 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
41 |
+
|
42 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
43 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
44 |
+
|
45 |
+
self.relu = nn.ReLU(inplace=True)
|
46 |
+
self.downsample = None
|
47 |
+
self.stride = stride
|
48 |
+
|
49 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
50 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
51 |
+
self.downsample = nn.Sequential(OrderedDict([
|
52 |
+
("-1", nn.AvgPool2d(stride)),
|
53 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
54 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
55 |
+
]))
|
56 |
+
|
57 |
+
def forward(self, x: torch.Tensor):
|
58 |
+
identity = x
|
59 |
+
|
60 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
61 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
62 |
+
out = self.avgpool(out)
|
63 |
+
out = self.bn3(self.conv3(out))
|
64 |
+
|
65 |
+
if self.downsample is not None:
|
66 |
+
identity = self.downsample(x)
|
67 |
+
|
68 |
+
out += identity
|
69 |
+
out = self.relu(out)
|
70 |
+
return out
|
71 |
+
|
72 |
+
|
73 |
+
class AttentionPool2d(nn.Module):
|
74 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
75 |
+
super().__init__()
|
76 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
77 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
78 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
79 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
80 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
81 |
+
self.num_heads = num_heads
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
85 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
86 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
87 |
+
x, _ = F.multi_head_attention_forward(
|
88 |
+
query=x, key=x, value=x,
|
89 |
+
embed_dim_to_check=x.shape[-1],
|
90 |
+
num_heads=self.num_heads,
|
91 |
+
q_proj_weight=self.q_proj.weight,
|
92 |
+
k_proj_weight=self.k_proj.weight,
|
93 |
+
v_proj_weight=self.v_proj.weight,
|
94 |
+
in_proj_weight=None,
|
95 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
96 |
+
bias_k=None,
|
97 |
+
bias_v=None,
|
98 |
+
add_zero_attn=False,
|
99 |
+
dropout_p=0,
|
100 |
+
out_proj_weight=self.c_proj.weight,
|
101 |
+
out_proj_bias=self.c_proj.bias,
|
102 |
+
use_separate_proj_weight=True,
|
103 |
+
training=self.training,
|
104 |
+
need_weights=False
|
105 |
+
)
|
106 |
+
|
107 |
+
return x[0]
|
108 |
+
|
109 |
+
|
110 |
+
class ModifiedResNet(nn.Module):
|
111 |
+
"""
|
112 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
113 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
114 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
115 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
116 |
+
"""
|
117 |
+
|
118 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
119 |
+
super().__init__()
|
120 |
+
self.output_dim = output_dim
|
121 |
+
self.input_resolution = input_resolution
|
122 |
+
|
123 |
+
# the 3-layer stem
|
124 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
125 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
126 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
127 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
128 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
129 |
+
self.bn3 = nn.BatchNorm2d(width)
|
130 |
+
self.avgpool = nn.AvgPool2d(2)
|
131 |
+
self.relu = nn.ReLU(inplace=True)
|
132 |
+
|
133 |
+
# residual layers
|
134 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
135 |
+
self.layer1 = self._make_layer(width, layers[0])
|
136 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
137 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
138 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
139 |
+
|
140 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
141 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
142 |
+
|
143 |
+
def _make_layer(self, planes, blocks, stride=1):
|
144 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
145 |
+
|
146 |
+
self._inplanes = planes * Bottleneck.expansion
|
147 |
+
for _ in range(1, blocks):
|
148 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
149 |
+
|
150 |
+
return nn.Sequential(*layers)
|
151 |
+
|
152 |
+
@torch.jit.ignore
|
153 |
+
def set_grad_checkpointing(self, enable=True):
|
154 |
+
# FIXME support for non-transformer
|
155 |
+
pass
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
def stem(x):
|
159 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
160 |
+
x = self.relu(bn(conv(x)))
|
161 |
+
x = self.avgpool(x)
|
162 |
+
return x
|
163 |
+
|
164 |
+
x = x.type(self.conv1.weight.dtype)
|
165 |
+
x = stem(x)
|
166 |
+
x = self.layer1(x)
|
167 |
+
x = self.layer2(x)
|
168 |
+
x = self.layer3(x)
|
169 |
+
x = self.layer4(x)
|
170 |
+
x = self.attnpool(x)
|
171 |
+
|
172 |
+
return x
|
173 |
+
|
174 |
+
|
175 |
+
class LayerNorm(nn.LayerNorm):
|
176 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
177 |
+
|
178 |
+
def forward(self, x: torch.Tensor):
|
179 |
+
orig_type = x.dtype
|
180 |
+
ret = super().forward(x.type(torch.float32))
|
181 |
+
return ret.type(orig_type)
|
182 |
+
|
183 |
+
|
184 |
+
class QuickGELU(nn.Module):
|
185 |
+
def forward(self, x: torch.Tensor):
|
186 |
+
return x * torch.sigmoid(1.702 * x)
|
187 |
+
|
188 |
+
|
189 |
+
class ResidualAttentionBlock(nn.Module):
|
190 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_flash_attention: bool = False):
|
191 |
+
super().__init__()
|
192 |
+
|
193 |
+
self.attn = nn.MultiheadAttention(d_model, n_head) if not use_flash_attention else FlashMHA(d_model, n_head)
|
194 |
+
self.ln_1 = LayerNorm(d_model)
|
195 |
+
self.mlp = nn.Sequential(OrderedDict([
|
196 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
197 |
+
("gelu", QuickGELU()),
|
198 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
199 |
+
]))
|
200 |
+
self.ln_2 = LayerNorm(d_model)
|
201 |
+
self.attn_mask = attn_mask
|
202 |
+
self.use_flash_attention = use_flash_attention
|
203 |
+
|
204 |
+
def attention(self, x: torch.Tensor):
|
205 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
206 |
+
if self.use_flash_attention:
|
207 |
+
# Batch first is needed for FlashAttention. See https://github.com/HazyResearch/flash-attention/issues/84 for more information.
|
208 |
+
return self.attn(x.transpose(1, 0))[0].transpose(1, 0)
|
209 |
+
else:
|
210 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
211 |
+
|
212 |
+
def forward(self, x: torch.Tensor):
|
213 |
+
x = x + self.attention(self.ln_1(x))
|
214 |
+
x = x + self.mlp(self.ln_2(x))
|
215 |
+
return x
|
216 |
+
|
217 |
+
|
218 |
+
class Transformer(nn.Module):
|
219 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_flash_attention: bool = False):
|
220 |
+
super().__init__()
|
221 |
+
self.width = width
|
222 |
+
self.layers = layers
|
223 |
+
self.grad_checkpointing = False
|
224 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_flash_attention) for _ in range(layers)])
|
225 |
+
|
226 |
+
def forward(self, x: torch.Tensor):
|
227 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
228 |
+
for r in self.resblocks:
|
229 |
+
x = checkpoint(r, x)
|
230 |
+
return x
|
231 |
+
return self.resblocks(x)
|
232 |
+
|
233 |
+
|
234 |
+
class VisualTransformer(nn.Module):
|
235 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, use_flash_attention: bool = False):
|
236 |
+
super().__init__()
|
237 |
+
self.input_resolution = input_resolution
|
238 |
+
self.grid_size = (self.input_resolution // patch_size, self.input_resolution // patch_size)
|
239 |
+
self.output_dim = output_dim
|
240 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
241 |
+
|
242 |
+
scale = width ** -0.5
|
243 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
244 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
245 |
+
self.ln_pre = LayerNorm(width)
|
246 |
+
|
247 |
+
self.transformer = Transformer(width, layers, heads, use_flash_attention=use_flash_attention)
|
248 |
+
|
249 |
+
self.ln_post = LayerNorm(width)
|
250 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
251 |
+
|
252 |
+
@torch.jit.ignore
|
253 |
+
def set_grad_checkpointing(self, enable=True):
|
254 |
+
self.transformer.grad_checkpointing = enable
|
255 |
+
|
256 |
+
def random_masking(self, x, mask_ratio):
|
257 |
+
N, L, D = x.shape # batch, length, dim
|
258 |
+
len_keep = int((L - 1) * (1 - mask_ratio))
|
259 |
+
|
260 |
+
noise = torch.rand(N, L - 1, device=x.device)
|
261 |
+
ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device,
|
262 |
+
dtype=int)
|
263 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
264 |
+
|
265 |
+
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
266 |
+
|
267 |
+
x0 = x[:, 0, :]
|
268 |
+
x0 = x0.reshape(N, 1, D)
|
269 |
+
x_masked_add = torch.cat([x0, x_masked], axis=1)
|
270 |
+
return x_masked_add
|
271 |
+
|
272 |
+
def forward(self, x: torch.Tensor, mask_ratio: float = 0.0):
|
273 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
274 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
275 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
276 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
277 |
+
x = x + self.positional_embedding.to(x.dtype)
|
278 |
+
if mask_ratio != 0:
|
279 |
+
x = self.random_masking(x, mask_ratio)
|
280 |
+
x = self.ln_pre(x)
|
281 |
+
|
282 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
283 |
+
x = self.transformer(x)
|
284 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
285 |
+
|
286 |
+
x = self.ln_post(x[:, 0, :])
|
287 |
+
|
288 |
+
if self.proj is not None:
|
289 |
+
x = x @ self.proj
|
290 |
+
|
291 |
+
return x
|
292 |
+
|
293 |
+
|
294 |
+
class CLIP(nn.Module):
|
295 |
+
def __init__(self,
|
296 |
+
embed_dim: int,
|
297 |
+
# vision
|
298 |
+
image_resolution: int,
|
299 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
300 |
+
vision_width: int,
|
301 |
+
vision_patch_size: int,
|
302 |
+
# text
|
303 |
+
vocab_size: int,
|
304 |
+
text_attention_probs_dropout_prob: float,
|
305 |
+
text_hidden_act: str,
|
306 |
+
text_hidden_dropout_prob: float,
|
307 |
+
text_hidden_size: int,
|
308 |
+
text_initializer_range: float,
|
309 |
+
text_intermediate_size: int,
|
310 |
+
text_max_position_embeddings: int,
|
311 |
+
text_num_attention_heads: int,
|
312 |
+
text_num_hidden_layers: int,
|
313 |
+
text_type_vocab_size: int,
|
314 |
+
tokenizer = _tokenizer,
|
315 |
+
# vision head width, added this param for ViT-H
|
316 |
+
vision_head_width: int = 64,
|
317 |
+
use_flash_attention: bool = False,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
|
321 |
+
if isinstance(vision_layers, (tuple, list)):
|
322 |
+
vision_heads = vision_width * 32 // vision_head_width
|
323 |
+
self.visual = ModifiedResNet(
|
324 |
+
layers=vision_layers,
|
325 |
+
output_dim=embed_dim,
|
326 |
+
heads=vision_heads,
|
327 |
+
input_resolution=image_resolution,
|
328 |
+
width=vision_width
|
329 |
+
)
|
330 |
+
else:
|
331 |
+
vision_heads = vision_width // vision_head_width
|
332 |
+
self.visual = VisualTransformer(
|
333 |
+
input_resolution=image_resolution,
|
334 |
+
patch_size=vision_patch_size,
|
335 |
+
width=vision_width,
|
336 |
+
layers=vision_layers,
|
337 |
+
heads=vision_heads,
|
338 |
+
output_dim=embed_dim,
|
339 |
+
use_flash_attention=use_flash_attention
|
340 |
+
)
|
341 |
+
|
342 |
+
self.bert_config = BertConfig(
|
343 |
+
vocab_size_or_config_json_file=vocab_size,
|
344 |
+
hidden_size=text_hidden_size,
|
345 |
+
num_hidden_layers=text_num_hidden_layers,
|
346 |
+
num_attention_heads=text_num_attention_heads,
|
347 |
+
intermediate_size=text_intermediate_size,
|
348 |
+
hidden_act=text_hidden_act,
|
349 |
+
hidden_dropout_prob=text_hidden_dropout_prob,
|
350 |
+
attention_probs_dropout_prob=text_attention_probs_dropout_prob,
|
351 |
+
max_position_embeddings=text_max_position_embeddings,
|
352 |
+
type_vocab_size=text_type_vocab_size,
|
353 |
+
initializer_range=text_initializer_range,
|
354 |
+
layer_norm_eps=1e-12,
|
355 |
+
use_flash_attention=use_flash_attention
|
356 |
+
)
|
357 |
+
self.bert = BertModel(self.bert_config)
|
358 |
+
|
359 |
+
self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
|
360 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
361 |
+
|
362 |
+
self.tokenizer = tokenizer
|
363 |
+
|
364 |
+
self.initialize_parameters()
|
365 |
+
|
366 |
+
def initialize_parameters(self):
|
367 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
368 |
+
|
369 |
+
if isinstance(self.visual, ModifiedResNet):
|
370 |
+
if self.visual.attnpool is not None:
|
371 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
372 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
373 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
374 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
375 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
376 |
+
|
377 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
378 |
+
for name, param in resnet_block.named_parameters():
|
379 |
+
if name.endswith("bn3.weight"):
|
380 |
+
nn.init.zeros_(param)
|
381 |
+
|
382 |
+
if self.text_projection is not None:
|
383 |
+
nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
|
384 |
+
|
385 |
+
@torch.jit.ignore
|
386 |
+
def set_grad_checkpointing(self, enable=True):
|
387 |
+
self.visual.set_grad_checkpointing(enable)
|
388 |
+
self.bert.set_grad_checkpointing(enable)
|
389 |
+
|
390 |
+
@property
|
391 |
+
def dtype(self):
|
392 |
+
return self.visual.conv1.weight.dtype
|
393 |
+
|
394 |
+
def encode_image(self, image, mask_ratio=0):
|
395 |
+
if isinstance(self.visual, ModifiedResNet):
|
396 |
+
# mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
|
397 |
+
return self.visual(image.type(self.dtype))
|
398 |
+
return self.visual(image.type(self.dtype), mask_ratio)
|
399 |
+
|
400 |
+
def encode_text(self, text):
|
401 |
+
pad_index = self.tokenizer.vocab['[PAD]']
|
402 |
+
attn_mask = text.ne(pad_index).type(self.dtype)
|
403 |
+
x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
|
404 |
+
return x[:, 0, :] @ self.text_projection
|
405 |
+
|
406 |
+
def forward(self, image, text, mask_ratio=0):
|
407 |
+
assert image is not None or text is not None, "text and image cannot both be None!"
|
408 |
+
|
409 |
+
if image is None:
|
410 |
+
return self.encode_text(text)
|
411 |
+
elif text is None:
|
412 |
+
return self.encode_image(image, mask_ratio)
|
413 |
+
image_features = self.encode_image(image, mask_ratio)
|
414 |
+
text_features = self.encode_text(text)
|
415 |
+
|
416 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
417 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
418 |
+
|
419 |
+
return image_features, text_features, self.logit_scale.exp()
|
420 |
+
|
421 |
+
def get_similarity(self, image, text):
|
422 |
+
image_features = self.encode_image(image)
|
423 |
+
text_features = self.encode_text(text)
|
424 |
+
|
425 |
+
# normalized features
|
426 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
427 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
428 |
+
|
429 |
+
# cosine similarity as logits
|
430 |
+
logit_scale = self.logit_scale.exp()
|
431 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
432 |
+
logits_per_text = logits_per_image.t()
|
433 |
+
|
434 |
+
# shape = [global_batch_size, global_batch_size]
|
435 |
+
return logits_per_image, logits_per_text
|
436 |
+
|
437 |
+
class CLIPWithTwoTextEncoder(nn.Module):
|
438 |
+
def __init__(self,
|
439 |
+
embed_dim: int,
|
440 |
+
# vision
|
441 |
+
image_resolution: int,
|
442 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
443 |
+
vision_width: int,
|
444 |
+
vision_patch_size: int,
|
445 |
+
# text
|
446 |
+
vocab_size: int,
|
447 |
+
text_attention_probs_dropout_prob: float,
|
448 |
+
text_hidden_act: str,
|
449 |
+
text_hidden_dropout_prob: float,
|
450 |
+
text_hidden_size: int,
|
451 |
+
text_initializer_range: float,
|
452 |
+
text_intermediate_size: int,
|
453 |
+
text_max_position_embeddings: int,
|
454 |
+
text_num_attention_heads: int,
|
455 |
+
text_num_hidden_layers: int,
|
456 |
+
text_type_vocab_size: int,
|
457 |
+
tokenizer = _tokenizer,
|
458 |
+
# vision head width, added this param for ViT-H
|
459 |
+
vision_head_width: int = 64,
|
460 |
+
use_flash_attention: bool = False,
|
461 |
+
openai_clip_path: str = "/group/30042/kunyi/CLIP/clip-vit-large-patch14/",
|
462 |
+
):
|
463 |
+
super().__init__()
|
464 |
+
|
465 |
+
if isinstance(vision_layers, (tuple, list)):
|
466 |
+
vision_heads = vision_width * 32 // vision_head_width
|
467 |
+
self.visual = ModifiedResNet(
|
468 |
+
layers=vision_layers,
|
469 |
+
output_dim=embed_dim,
|
470 |
+
heads=vision_heads,
|
471 |
+
input_resolution=image_resolution,
|
472 |
+
width=vision_width
|
473 |
+
)
|
474 |
+
else:
|
475 |
+
vision_heads = vision_width // vision_head_width
|
476 |
+
self.visual = VisualTransformer(
|
477 |
+
input_resolution=image_resolution,
|
478 |
+
patch_size=vision_patch_size,
|
479 |
+
width=vision_width,
|
480 |
+
layers=vision_layers,
|
481 |
+
heads=vision_heads,
|
482 |
+
output_dim=embed_dim,
|
483 |
+
use_flash_attention=use_flash_attention
|
484 |
+
)
|
485 |
+
|
486 |
+
self.bert_config = BertConfig(
|
487 |
+
vocab_size_or_config_json_file=vocab_size,
|
488 |
+
hidden_size=text_hidden_size,
|
489 |
+
num_hidden_layers=text_num_hidden_layers,
|
490 |
+
num_attention_heads=text_num_attention_heads,
|
491 |
+
intermediate_size=text_intermediate_size,
|
492 |
+
hidden_act=text_hidden_act,
|
493 |
+
hidden_dropout_prob=text_hidden_dropout_prob,
|
494 |
+
attention_probs_dropout_prob=text_attention_probs_dropout_prob,
|
495 |
+
max_position_embeddings=text_max_position_embeddings,
|
496 |
+
type_vocab_size=text_type_vocab_size,
|
497 |
+
initializer_range=text_initializer_range,
|
498 |
+
layer_norm_eps=1e-12,
|
499 |
+
use_flash_attention=use_flash_attention
|
500 |
+
)
|
501 |
+
self.bert = BertModel(self.bert_config)
|
502 |
+
|
503 |
+
self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
|
504 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
505 |
+
|
506 |
+
self.tokenizer = tokenizer
|
507 |
+
|
508 |
+
print('loading openai clip text encoder')
|
509 |
+
self.openai_clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(openai_clip_path)
|
510 |
+
|
511 |
+
self.initialize_parameters()
|
512 |
+
|
513 |
+
|
514 |
+
def initialize_parameters(self):
|
515 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
516 |
+
|
517 |
+
if isinstance(self.visual, ModifiedResNet):
|
518 |
+
if self.visual.attnpool is not None:
|
519 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
520 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
521 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
522 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
523 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
524 |
+
|
525 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
526 |
+
for name, param in resnet_block.named_parameters():
|
527 |
+
if name.endswith("bn3.weight"):
|
528 |
+
nn.init.zeros_(param)
|
529 |
+
|
530 |
+
if self.text_projection is not None:
|
531 |
+
nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
|
532 |
+
|
533 |
+
@torch.jit.ignore
|
534 |
+
def set_grad_checkpointing(self, enable=True):
|
535 |
+
self.visual.set_grad_checkpointing(enable)
|
536 |
+
self.bert.set_grad_checkpointing(enable)
|
537 |
+
|
538 |
+
@property
|
539 |
+
def dtype(self):
|
540 |
+
return self.visual.conv1.weight.dtype
|
541 |
+
|
542 |
+
def encode_image(self, image, mask_ratio=0):
|
543 |
+
if isinstance(self.visual, ModifiedResNet):
|
544 |
+
# mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
|
545 |
+
return self.visual(image.type(self.dtype))
|
546 |
+
return self.visual(image.type(self.dtype), mask_ratio)
|
547 |
+
|
548 |
+
def encode_text(self, text):
|
549 |
+
pad_index = self.tokenizer.vocab['[PAD]']
|
550 |
+
attn_mask = text.ne(pad_index).type(self.dtype)
|
551 |
+
x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
|
552 |
+
return x[:, 0, :] @ self.text_projection
|
553 |
+
|
554 |
+
def encode_text_ENG(self, text):
|
555 |
+
text_emb = self.openai_clip_text_encoder(text).text_embeds
|
556 |
+
return text_emb
|
557 |
+
|
558 |
+
def forward(self, image, text, is_ENG=False, mask_ratio=0):
|
559 |
+
assert image is not None or text is not None, "text and image cannot both be None!"
|
560 |
+
|
561 |
+
if image is None:
|
562 |
+
if not is_ENG:
|
563 |
+
return self.encode_text(text)
|
564 |
+
else:
|
565 |
+
return self.encode_text_ENG(text)
|
566 |
+
elif text is None:
|
567 |
+
return self.encode_image(image, mask_ratio)
|
568 |
+
image_features = self.encode_image(image, mask_ratio)
|
569 |
+
|
570 |
+
if not is_ENG:
|
571 |
+
text_features = self.encode_text(text)
|
572 |
+
else:
|
573 |
+
text_features = self.encode_text_ENG(text)
|
574 |
+
|
575 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
576 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
577 |
+
|
578 |
+
return image_features, text_features, self.logit_scale.exp()
|
579 |
+
|
580 |
+
def get_similarity(self, image, text):
|
581 |
+
image_features = self.encode_image(image)
|
582 |
+
text_features = self.encode_text(text)
|
583 |
+
|
584 |
+
# normalized features
|
585 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
586 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
587 |
+
|
588 |
+
# cosine similarity as logits
|
589 |
+
logit_scale = self.logit_scale.exp()
|
590 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
591 |
+
logits_per_text = logits_per_image.t()
|
592 |
+
|
593 |
+
# shape = [global_batch_size, global_batch_size]
|
594 |
+
return logits_per_image, logits_per_text
|
595 |
+
|
596 |
+
class CLIP4SD(nn.Module):
|
597 |
+
def __init__(self,
|
598 |
+
embed_dim: int,
|
599 |
+
# vision
|
600 |
+
image_resolution: int,
|
601 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
602 |
+
vision_width: int,
|
603 |
+
vision_patch_size: int,
|
604 |
+
# text
|
605 |
+
vocab_size: int,
|
606 |
+
text_attention_probs_dropout_prob: float,
|
607 |
+
text_hidden_act: str,
|
608 |
+
text_hidden_dropout_prob: float,
|
609 |
+
text_hidden_size: int,
|
610 |
+
text_initializer_range: float,
|
611 |
+
text_intermediate_size: int,
|
612 |
+
text_max_position_embeddings: int,
|
613 |
+
text_num_attention_heads: int,
|
614 |
+
text_num_hidden_layers: int,
|
615 |
+
text_type_vocab_size: int,
|
616 |
+
tokenizer = _tokenizer,
|
617 |
+
# vision head width, added this param for ViT-H
|
618 |
+
vision_head_width: int = 64,
|
619 |
+
use_flash_attention: bool = False,
|
620 |
+
):
|
621 |
+
super().__init__()
|
622 |
+
|
623 |
+
if isinstance(vision_layers, (tuple, list)):
|
624 |
+
vision_heads = vision_width * 32 // vision_head_width
|
625 |
+
self.visual = ModifiedResNet(
|
626 |
+
layers=vision_layers,
|
627 |
+
output_dim=embed_dim,
|
628 |
+
heads=vision_heads,
|
629 |
+
input_resolution=image_resolution,
|
630 |
+
width=vision_width
|
631 |
+
)
|
632 |
+
else:
|
633 |
+
vision_heads = vision_width // vision_head_width
|
634 |
+
self.visual = VisualTransformer(
|
635 |
+
input_resolution=image_resolution,
|
636 |
+
patch_size=vision_patch_size,
|
637 |
+
width=vision_width,
|
638 |
+
layers=vision_layers,
|
639 |
+
heads=vision_heads,
|
640 |
+
output_dim=embed_dim,
|
641 |
+
use_flash_attention=use_flash_attention
|
642 |
+
)
|
643 |
+
|
644 |
+
self.bert_config = BertConfig(
|
645 |
+
vocab_size_or_config_json_file=vocab_size,
|
646 |
+
hidden_size=text_hidden_size,
|
647 |
+
num_hidden_layers=text_num_hidden_layers,
|
648 |
+
num_attention_heads=text_num_attention_heads,
|
649 |
+
intermediate_size=text_intermediate_size,
|
650 |
+
hidden_act=text_hidden_act,
|
651 |
+
hidden_dropout_prob=text_hidden_dropout_prob,
|
652 |
+
attention_probs_dropout_prob=text_attention_probs_dropout_prob,
|
653 |
+
max_position_embeddings=text_max_position_embeddings,
|
654 |
+
type_vocab_size=text_type_vocab_size,
|
655 |
+
initializer_range=text_initializer_range,
|
656 |
+
layer_norm_eps=1e-12,
|
657 |
+
use_flash_attention=use_flash_attention
|
658 |
+
)
|
659 |
+
self.bert = BertModel(self.bert_config)
|
660 |
+
|
661 |
+
self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
|
662 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
663 |
+
|
664 |
+
self.tokenizer = tokenizer
|
665 |
+
self.ln_final = LayerNorm(text_hidden_size)
|
666 |
+
|
667 |
+
self.initialize_parameters()
|
668 |
+
|
669 |
+
def initialize_parameters(self):
|
670 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
671 |
+
|
672 |
+
if isinstance(self.visual, ModifiedResNet):
|
673 |
+
if self.visual.attnpool is not None:
|
674 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
675 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
676 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
677 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
678 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
679 |
+
|
680 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
681 |
+
for name, param in resnet_block.named_parameters():
|
682 |
+
if name.endswith("bn3.weight"):
|
683 |
+
nn.init.zeros_(param)
|
684 |
+
|
685 |
+
if self.text_projection is not None:
|
686 |
+
nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
|
687 |
+
|
688 |
+
@torch.jit.ignore
|
689 |
+
def set_grad_checkpointing(self, enable=True):
|
690 |
+
self.visual.set_grad_checkpointing(enable)
|
691 |
+
self.bert.set_grad_checkpointing(enable)
|
692 |
+
|
693 |
+
@property
|
694 |
+
def dtype(self):
|
695 |
+
return self.visual.conv1.weight.dtype
|
696 |
+
|
697 |
+
def encode_image(self, image, mask_ratio=0):
|
698 |
+
if isinstance(self.visual, ModifiedResNet):
|
699 |
+
# mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
|
700 |
+
return self.visual(image.type(self.dtype))
|
701 |
+
return self.visual(image.type(self.dtype), mask_ratio)
|
702 |
+
|
703 |
+
# def encode_text(self, text):
|
704 |
+
# pad_index = self.tokenizer.vocab['[PAD]']
|
705 |
+
# attn_mask = text.ne(pad_index).type(self.dtype)
|
706 |
+
# x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
|
707 |
+
# return x[:, 0, :] @ self.text_projection
|
708 |
+
def encode_text(self, text):
|
709 |
+
pad_index = self.tokenizer.vocab['[PAD]']
|
710 |
+
attn_mask = text.ne(pad_index).type(self.dtype)
|
711 |
+
x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
|
712 |
+
x = self.ln_final(x).type(self.dtype)
|
713 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
714 |
+
return x
|
715 |
+
|
716 |
+
def forward(self, image, text, mask_ratio=0):
|
717 |
+
assert image is not None or text is not None, "text and image cannot both be None!"
|
718 |
+
|
719 |
+
if image is None:
|
720 |
+
return self.encode_text(text)
|
721 |
+
elif text is None:
|
722 |
+
return self.encode_image(image)
|
723 |
+
image_features = self.encode_image(image, mask_ratio)
|
724 |
+
text_features = self.encode_text(text)
|
725 |
+
|
726 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
727 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
728 |
+
|
729 |
+
return image_features, text_features, self.logit_scale.exp()
|
730 |
+
|
731 |
+
def get_similarity(self, image, text):
|
732 |
+
image_features = self.encode_image(image)
|
733 |
+
text_features = self.encode_text(text)
|
734 |
+
|
735 |
+
# normalized features
|
736 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
737 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
738 |
+
|
739 |
+
# cosine similarity as logits
|
740 |
+
logit_scale = self.logit_scale.exp()
|
741 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
742 |
+
logits_per_text = logits_per_image.t()
|
743 |
+
|
744 |
+
# shape = [global_batch_size, global_batch_size]
|
745 |
+
return logits_per_image, logits_per_text
|
746 |
+
|
747 |
+
def convert_models_to_fp32(model):
|
748 |
+
for p in model.parameters():
|
749 |
+
p.data = p.data.float()
|
750 |
+
if p.grad:
|
751 |
+
p.grad.data = p.grad.data.float()
|
752 |
+
|
753 |
+
|
754 |
+
def convert_weights(model: nn.Module):
|
755 |
+
"""Convert applicable model parameters to fp16"""
|
756 |
+
|
757 |
+
def _convert_weights_to_fp16(l):
|
758 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
759 |
+
l.weight.data = l.weight.data.half()
|
760 |
+
if l.bias is not None:
|
761 |
+
l.bias.data = l.bias.data.half()
|
762 |
+
|
763 |
+
if isinstance(l, nn.MultiheadAttention):
|
764 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
765 |
+
tensor = getattr(l, attr)
|
766 |
+
if tensor is not None:
|
767 |
+
tensor.data = tensor.data.half()
|
768 |
+
|
769 |
+
if isinstance(l, BertModel):
|
770 |
+
l.to(torch.half)
|
771 |
+
|
772 |
+
for name in ["text_projection", "proj"]:
|
773 |
+
try:
|
774 |
+
if hasattr(l, name):
|
775 |
+
attr = getattr(l, name)
|
776 |
+
if attr is not None:
|
777 |
+
attr.data = attr.data.half()
|
778 |
+
except:
|
779 |
+
print('name', name)
|
780 |
+
|
781 |
+
model.apply(_convert_weights_to_fp16)
|
782 |
+
|
783 |
+
|
784 |
+
def restore_model(model, clip_state_dict: dict, bert_state_dict: dict, use_flash_attention: bool):
|
785 |
+
merged_state_dict = {}
|
786 |
+
|
787 |
+
# use clip_state_dict to initialize the image encoder & logit scale
|
788 |
+
if clip_state_dict is not None:
|
789 |
+
for k, v in clip_state_dict.items():
|
790 |
+
if k.startswith("visual") or k == "logit_scale":
|
791 |
+
merged_state_dict[k] = v
|
792 |
+
|
793 |
+
# use bert_state_dict to initialize the text encoder
|
794 |
+
if bert_state_dict is not None:
|
795 |
+
for k, v in bert_state_dict.items():
|
796 |
+
if k.startswith("bert") and "bert.pooler" not in k:
|
797 |
+
merged_state_dict[k] = v
|
798 |
+
|
799 |
+
# adapt flash attention
|
800 |
+
if use_flash_attention:
|
801 |
+
merged_state_dict = convert_state_dict(merged_state_dict)
|
802 |
+
|
803 |
+
convert_weights(model)
|
804 |
+
resize_pos_embed(merged_state_dict, model)
|
805 |
+
model.load_state_dict(merged_state_dict, strict=False)
|
806 |
+
return model.eval()
|
807 |
+
|
808 |
+
|
809 |
+
def convert_state_dict(state_dict):
|
810 |
+
"""Adapt to Flash Attention"""
|
811 |
+
if not state_dict:
|
812 |
+
return state_dict
|
813 |
+
|
814 |
+
prefix = 'module.' if list(state_dict.keys())[0].startswith('module') else ''
|
815 |
+
|
816 |
+
if f'{prefix}visual.transformer.resblocks.0.attn.in_proj_weight' in state_dict:
|
817 |
+
for k in list(state_dict.keys()):
|
818 |
+
if 'attn.in_proj_weight' in k:
|
819 |
+
state_dict[k.replace('attn.in_proj_weight', 'attn.Wqkv.weight')] = state_dict.pop(k)
|
820 |
+
elif 'attn.in_proj_bias' in k:
|
821 |
+
state_dict[k.replace('attn.in_proj_bias', 'attn.Wqkv.bias')] = state_dict.pop(k)
|
822 |
+
elif f'{prefix}visual.transformer.resblocks.0.attn.Wqkv.weight' in state_dict:
|
823 |
+
for k in list(state_dict.keys()):
|
824 |
+
if 'attn.Wqkv.weight' in k:
|
825 |
+
state_dict[k.replace('attn.Wqkv.weight', 'attn.in_proj_weight')] = state_dict.pop(k)
|
826 |
+
elif 'attn.Wqkv.bias' in k:
|
827 |
+
state_dict[k.replace('attn.Wqkv.bias', 'attn.in_proj_bias')] = state_dict.pop(k)
|
828 |
+
|
829 |
+
if f'{prefix}bert.encoder.layer.0.attention.self.query.weight' in state_dict:
|
830 |
+
i = 0
|
831 |
+
while f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight' in state_dict:
|
832 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight'] = torch.cat(
|
833 |
+
(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight'),
|
834 |
+
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.key.weight'),
|
835 |
+
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.value.weight'))
|
836 |
+
)
|
837 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias'] = torch.cat(
|
838 |
+
(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.query.bias'),
|
839 |
+
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.key.bias'),
|
840 |
+
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.value.bias'))
|
841 |
+
)
|
842 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight'] = \
|
843 |
+
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.output.dense.weight')
|
844 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias'] = \
|
845 |
+
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.output.dense.bias')
|
846 |
+
i += 1
|
847 |
+
elif f'{prefix}bert.encoder.layer.0.attention.self.Wqkv.weight' in state_dict:
|
848 |
+
i = 0
|
849 |
+
while f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight' in state_dict:
|
850 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight'], \
|
851 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.key.weight'], \
|
852 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.value.weight'] = \
|
853 |
+
torch.chunk(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight'), chunks=3)
|
854 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.query.bias'], \
|
855 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.key.bias'], \
|
856 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.value.bias'] = \
|
857 |
+
torch.chunk(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias'), chunks=3)
|
858 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.output.dense.weight'] = \
|
859 |
+
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight')
|
860 |
+
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.output.dense.bias'] = \
|
861 |
+
state_dict.pop(f'module.bert.encoder.layer.{i}.attention.self.out_proj.bias')
|
862 |
+
i += 1
|
863 |
+
|
864 |
+
return state_dict
|
865 |
+
|
866 |
+
|
867 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1, prefix=""):
|
868 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
869 |
+
old_pos_embed = state_dict.get(prefix + 'visual.positional_embedding', None)
|
870 |
+
model = model.module if hasattr(model, 'module') else model
|
871 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
872 |
+
return
|
873 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
874 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
875 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
876 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
877 |
+
return
|
878 |
+
|
879 |
+
if extra_tokens:
|
880 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
881 |
+
else:
|
882 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
883 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
884 |
+
|
885 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
886 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
887 |
+
pos_emb_img = F.interpolate(
|
888 |
+
pos_emb_img,
|
889 |
+
size=grid_size,
|
890 |
+
mode=interpolation,
|
891 |
+
align_corners=True,
|
892 |
+
)
|
893 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
894 |
+
if pos_emb_tok is not None:
|
895 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
896 |
+
else:
|
897 |
+
new_pos_embed = pos_emb_img
|
898 |
+
state_dict[prefix + 'visual.positional_embedding'] = new_pos_embed
|
899 |
+
|
900 |
+
|
901 |
+
# From PyTorch internals
|
902 |
+
def _ntuple(n):
|
903 |
+
def parse(x):
|
904 |
+
if isinstance(x, collections.abc.Iterable):
|
905 |
+
return x
|
906 |
+
return tuple(repeat(x, n))
|
907 |
+
return parse
|
908 |
+
|
909 |
+
|
910 |
+
to_1tuple = _ntuple(1)
|
911 |
+
to_2tuple = _ntuple(2)
|
912 |
+
to_3tuple = _ntuple(3)
|
913 |
+
to_4tuple = _ntuple(4)
|
914 |
+
to_ntuple = lambda n, x: _ntuple(n)(x)
|
clip/model_configs/RBT3-chinese.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocab_size": 21128,
|
3 |
+
"text_attention_probs_dropout_prob": 0.1,
|
4 |
+
"text_hidden_act": "gelu",
|
5 |
+
"text_hidden_dropout_prob": 0.1,
|
6 |
+
"text_hidden_size": 768,
|
7 |
+
"text_initializer_range": 0.02,
|
8 |
+
"text_intermediate_size": 3072,
|
9 |
+
"text_max_position_embeddings": 512,
|
10 |
+
"text_num_attention_heads": 12,
|
11 |
+
"text_num_hidden_layers": 3,
|
12 |
+
"text_type_vocab_size": 2
|
13 |
+
}
|
clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": "[3,4,6,3]",
|
5 |
+
"vision_width": 64,
|
6 |
+
"vision_patch_size": null
|
7 |
+
}
|
clip/model_configs/RoBERTa-wwm-ext-base-chinese.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocab_size": 21128,
|
3 |
+
"text_attention_probs_dropout_prob": 0.1,
|
4 |
+
"text_hidden_act": "gelu",
|
5 |
+
"text_hidden_dropout_prob": 0.1,
|
6 |
+
"text_hidden_size": 768,
|
7 |
+
"text_initializer_range": 0.02,
|
8 |
+
"text_intermediate_size": 3072,
|
9 |
+
"text_max_position_embeddings": 512,
|
10 |
+
"text_num_attention_heads": 12,
|
11 |
+
"text_num_hidden_layers": 12,
|
12 |
+
"text_type_vocab_size": 2
|
13 |
+
}
|
clip/model_configs/RoBERTa-wwm-ext-large-chinese.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocab_size": 21128,
|
3 |
+
"text_attention_probs_dropout_prob": 0.1,
|
4 |
+
"text_hidden_act": "gelu",
|
5 |
+
"text_hidden_dropout_prob": 0.1,
|
6 |
+
"text_hidden_size": 1024,
|
7 |
+
"text_initializer_range": 0.02,
|
8 |
+
"text_intermediate_size": 4096,
|
9 |
+
"text_max_position_embeddings": 512,
|
10 |
+
"text_num_attention_heads": 16,
|
11 |
+
"text_num_hidden_layers": 24,
|
12 |
+
"text_type_vocab_size": 2
|
13 |
+
}
|
clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 12,
|
5 |
+
"vision_width": 768,
|
6 |
+
"vision_patch_size": 16
|
7 |
+
}
|
clip/model_configs/ViT-B-32.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 12,
|
5 |
+
"vision_width": 768,
|
6 |
+
"vision_patch_size": 32
|
7 |
+
}
|
clip/model_configs/ViT-H-14.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 32,
|
5 |
+
"vision_width": 1280,
|
6 |
+
"vision_head_width": 80,
|
7 |
+
"vision_patch_size": 14
|
8 |
+
}
|
clip/model_configs/ViT-L-14-336.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"image_resolution": 336,
|
4 |
+
"vision_layers": 24,
|
5 |
+
"vision_width": 1024,
|
6 |
+
"vision_patch_size": 14
|
7 |
+
}
|
clip/model_configs/ViT-L-14.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"image_resolution": 224,
|
4 |
+
"vision_layers": 24,
|
5 |
+
"vision_width": 1024,
|
6 |
+
"vision_patch_size": 14
|
7 |
+
}
|
clip/modeling_bert.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model. """
|
17 |
+
|
18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
19 |
+
|
20 |
+
import json
|
21 |
+
import logging
|
22 |
+
import math
|
23 |
+
import os
|
24 |
+
import sys
|
25 |
+
from io import open
|
26 |
+
|
27 |
+
import torch
|
28 |
+
from torch import nn
|
29 |
+
from torch.utils.checkpoint import checkpoint
|
30 |
+
|
31 |
+
import importlib.util
|
32 |
+
if importlib.util.find_spec('flash_attn'):
|
33 |
+
FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA
|
34 |
+
|
35 |
+
from .configuration_bert import BertConfig
|
36 |
+
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
def gelu(x):
|
40 |
+
""" Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
41 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
42 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
43 |
+
Also see https://arxiv.org/abs/1606.08415
|
44 |
+
"""
|
45 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
46 |
+
|
47 |
+
def gelu_new(x):
|
48 |
+
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
49 |
+
Also see https://arxiv.org/abs/1606.08415
|
50 |
+
"""
|
51 |
+
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
52 |
+
|
53 |
+
def swish(x):
|
54 |
+
return x * torch.sigmoid(x)
|
55 |
+
|
56 |
+
|
57 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
|
58 |
+
|
59 |
+
|
60 |
+
BertLayerNorm = torch.nn.LayerNorm
|
61 |
+
|
62 |
+
class BertEmbeddings(nn.Module):
|
63 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
64 |
+
"""
|
65 |
+
def __init__(self, config):
|
66 |
+
super(BertEmbeddings, self).__init__()
|
67 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
68 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
69 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
70 |
+
|
71 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
72 |
+
# any TensorFlow checkpoint file
|
73 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
74 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
75 |
+
|
76 |
+
def forward(self, input_ids, token_type_ids=None, position_ids=None):
|
77 |
+
seq_length = input_ids.size(1)
|
78 |
+
if position_ids is None:
|
79 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
80 |
+
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
81 |
+
if token_type_ids is None:
|
82 |
+
token_type_ids = torch.zeros_like(input_ids)
|
83 |
+
|
84 |
+
words_embeddings = self.word_embeddings(input_ids)
|
85 |
+
position_embeddings = self.position_embeddings(position_ids)
|
86 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
87 |
+
|
88 |
+
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
89 |
+
embeddings = self.LayerNorm(embeddings)
|
90 |
+
embeddings = self.dropout(embeddings)
|
91 |
+
return embeddings
|
92 |
+
|
93 |
+
|
94 |
+
class BertSelfAttention(nn.Module):
|
95 |
+
def __init__(self, config):
|
96 |
+
super(BertSelfAttention, self).__init__()
|
97 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
98 |
+
raise ValueError(
|
99 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
100 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
101 |
+
self.output_attentions = config.output_attentions
|
102 |
+
|
103 |
+
self.num_attention_heads = config.num_attention_heads
|
104 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
105 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
106 |
+
|
107 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
108 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
109 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
110 |
+
|
111 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
112 |
+
|
113 |
+
def transpose_for_scores(self, x):
|
114 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
115 |
+
x = x.view(*new_x_shape)
|
116 |
+
return x.permute(0, 2, 1, 3)
|
117 |
+
|
118 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
119 |
+
mixed_query_layer = self.query(hidden_states)
|
120 |
+
mixed_key_layer = self.key(hidden_states)
|
121 |
+
mixed_value_layer = self.value(hidden_states)
|
122 |
+
|
123 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
124 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
125 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
126 |
+
|
127 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
128 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
129 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
130 |
+
if attention_mask is not None:
|
131 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
132 |
+
attention_scores = attention_scores + attention_mask
|
133 |
+
|
134 |
+
# Normalize the attention scores to probabilities.
|
135 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
136 |
+
|
137 |
+
# This is actually dropping out entire tokens to attend to, which might
|
138 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
139 |
+
attention_probs = self.dropout(attention_probs)
|
140 |
+
|
141 |
+
# Mask heads if we want to
|
142 |
+
if head_mask is not None:
|
143 |
+
attention_probs = attention_probs * head_mask
|
144 |
+
|
145 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
146 |
+
|
147 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
148 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
149 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
150 |
+
|
151 |
+
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
|
152 |
+
return outputs
|
153 |
+
|
154 |
+
|
155 |
+
class BertSelfOutput(nn.Module):
|
156 |
+
def __init__(self, config):
|
157 |
+
super(BertSelfOutput, self).__init__()
|
158 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
159 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
160 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
161 |
+
|
162 |
+
def forward(self, hidden_states, input_tensor):
|
163 |
+
hidden_states = self.dense(hidden_states)
|
164 |
+
hidden_states = self.dropout(hidden_states)
|
165 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
166 |
+
return hidden_states
|
167 |
+
|
168 |
+
|
169 |
+
class BertAttention(nn.Module):
|
170 |
+
def __init__(self, config):
|
171 |
+
super(BertAttention, self).__init__()
|
172 |
+
self.self = BertSelfAttention(config) if not config.use_flash_attention else FlashMHA(config.hidden_size, config.num_attention_heads)
|
173 |
+
self.output = BertSelfOutput(config) if not config.use_flash_attention else BertSelfOutputForFlashAttention(config)
|
174 |
+
self.pruned_heads = set()
|
175 |
+
self.config = config
|
176 |
+
|
177 |
+
def forward(self, input_tensor, attention_mask=None, head_mask=None):
|
178 |
+
if not self.config.use_flash_attention:
|
179 |
+
self_outputs = self.self(input_tensor, attention_mask, head_mask)
|
180 |
+
else:
|
181 |
+
key_padding_mask = self.get_key_padding_mask(attention_mask)
|
182 |
+
self_outputs = self.self(input_tensor, key_padding_mask=key_padding_mask)
|
183 |
+
attention_output = self.output(self_outputs[0], input_tensor)
|
184 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
185 |
+
return outputs
|
186 |
+
|
187 |
+
def get_key_padding_mask(self, attention_mask):
|
188 |
+
# key_padding_mask: bool tensor of shape (batch, seqlen)
|
189 |
+
return attention_mask.squeeze(1).squeeze(1) == 0
|
190 |
+
|
191 |
+
|
192 |
+
class BertIntermediate(nn.Module):
|
193 |
+
def __init__(self, config):
|
194 |
+
super(BertIntermediate, self).__init__()
|
195 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
196 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
197 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
198 |
+
else:
|
199 |
+
self.intermediate_act_fn = config.hidden_act
|
200 |
+
|
201 |
+
def forward(self, hidden_states):
|
202 |
+
hidden_states = self.dense(hidden_states)
|
203 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
204 |
+
return hidden_states
|
205 |
+
|
206 |
+
|
207 |
+
class BertOutput(nn.Module):
|
208 |
+
def __init__(self, config):
|
209 |
+
super(BertOutput, self).__init__()
|
210 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
211 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
212 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
213 |
+
|
214 |
+
def forward(self, hidden_states, input_tensor):
|
215 |
+
hidden_states = self.dense(hidden_states)
|
216 |
+
hidden_states = self.dropout(hidden_states)
|
217 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
218 |
+
return hidden_states
|
219 |
+
|
220 |
+
|
221 |
+
class BertSelfOutputForFlashAttention(nn.Module): # remove linear layer
|
222 |
+
def __init__(self, config):
|
223 |
+
super(BertSelfOutputForFlashAttention, self).__init__()
|
224 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
225 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
226 |
+
|
227 |
+
def forward(self, hidden_states, input_tensor):
|
228 |
+
hidden_states = self.dropout(hidden_states)
|
229 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
230 |
+
return hidden_states
|
231 |
+
|
232 |
+
|
233 |
+
class BertLayer(nn.Module):
|
234 |
+
def __init__(self, config):
|
235 |
+
super(BertLayer, self).__init__()
|
236 |
+
self.attention = BertAttention(config)
|
237 |
+
self.intermediate = BertIntermediate(config)
|
238 |
+
self.output = BertOutput(config)
|
239 |
+
|
240 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
241 |
+
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
242 |
+
attention_output = attention_outputs[0]
|
243 |
+
intermediate_output = self.intermediate(attention_output)
|
244 |
+
layer_output = self.output(intermediate_output, attention_output)
|
245 |
+
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
246 |
+
if len(outputs) == 1:
|
247 |
+
return outputs[0]
|
248 |
+
return outputs
|
249 |
+
|
250 |
+
|
251 |
+
class BertEncoder(nn.Module):
|
252 |
+
def __init__(self, config):
|
253 |
+
super(BertEncoder, self).__init__()
|
254 |
+
self.output_attentions = config.output_attentions
|
255 |
+
self.output_hidden_states = config.output_hidden_states
|
256 |
+
self.grad_checkpointing = False
|
257 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
258 |
+
|
259 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
260 |
+
all_hidden_states = ()
|
261 |
+
all_attentions = ()
|
262 |
+
for i, layer_module in enumerate(self.layer):
|
263 |
+
if self.output_hidden_states:
|
264 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
265 |
+
|
266 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
267 |
+
layer_outputs = checkpoint(layer_module, hidden_states, attention_mask, head_mask[i])
|
268 |
+
else:
|
269 |
+
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
270 |
+
if not isinstance(layer_outputs, tuple):
|
271 |
+
layer_outputs = (layer_outputs, )
|
272 |
+
hidden_states = layer_outputs[0]
|
273 |
+
|
274 |
+
if self.output_attentions:
|
275 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
276 |
+
|
277 |
+
# Add last layer
|
278 |
+
if self.output_hidden_states:
|
279 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
280 |
+
|
281 |
+
outputs = (hidden_states,)
|
282 |
+
if self.output_hidden_states:
|
283 |
+
outputs = outputs + (all_hidden_states,)
|
284 |
+
if self.output_attentions:
|
285 |
+
outputs = outputs + (all_attentions,)
|
286 |
+
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
287 |
+
|
288 |
+
|
289 |
+
class BertPooler(nn.Module):
|
290 |
+
def __init__(self, config):
|
291 |
+
super(BertPooler, self).__init__()
|
292 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
293 |
+
self.activation = nn.Tanh()
|
294 |
+
|
295 |
+
def forward(self, hidden_states):
|
296 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
297 |
+
# to the first token.
|
298 |
+
first_token_tensor = hidden_states[:, 0]
|
299 |
+
pooled_output = self.dense(first_token_tensor)
|
300 |
+
pooled_output = self.activation(pooled_output)
|
301 |
+
return pooled_output
|
302 |
+
|
303 |
+
|
304 |
+
class BertPredictionHeadTransform(nn.Module):
|
305 |
+
def __init__(self, config):
|
306 |
+
super(BertPredictionHeadTransform, self).__init__()
|
307 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
308 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
309 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
310 |
+
else:
|
311 |
+
self.transform_act_fn = config.hidden_act
|
312 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
313 |
+
|
314 |
+
def forward(self, hidden_states):
|
315 |
+
hidden_states = self.dense(hidden_states)
|
316 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
317 |
+
hidden_states = self.LayerNorm(hidden_states)
|
318 |
+
return hidden_states
|
319 |
+
|
320 |
+
|
321 |
+
class BertLMPredictionHead(nn.Module):
|
322 |
+
def __init__(self, config):
|
323 |
+
super(BertLMPredictionHead, self).__init__()
|
324 |
+
self.transform = BertPredictionHeadTransform(config)
|
325 |
+
|
326 |
+
# The output weights are the same as the input embeddings, but there is
|
327 |
+
# an output-only bias for each token.
|
328 |
+
self.decoder = nn.Linear(config.hidden_size,
|
329 |
+
config.vocab_size,
|
330 |
+
bias=False)
|
331 |
+
|
332 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
333 |
+
|
334 |
+
def forward(self, hidden_states):
|
335 |
+
hidden_states = self.transform(hidden_states)
|
336 |
+
hidden_states = self.decoder(hidden_states) + self.bias
|
337 |
+
return hidden_states
|
338 |
+
|
339 |
+
|
340 |
+
class BertOnlyMLMHead(nn.Module):
|
341 |
+
def __init__(self, config):
|
342 |
+
super(BertOnlyMLMHead, self).__init__()
|
343 |
+
self.predictions = BertLMPredictionHead(config)
|
344 |
+
|
345 |
+
def forward(self, sequence_output):
|
346 |
+
prediction_scores = self.predictions(sequence_output)
|
347 |
+
return prediction_scores
|
348 |
+
|
349 |
+
|
350 |
+
class BertOnlyNSPHead(nn.Module):
|
351 |
+
def __init__(self, config):
|
352 |
+
super(BertOnlyNSPHead, self).__init__()
|
353 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
354 |
+
|
355 |
+
def forward(self, pooled_output):
|
356 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
357 |
+
return seq_relationship_score
|
358 |
+
|
359 |
+
|
360 |
+
class BertPreTrainingHeads(nn.Module):
|
361 |
+
def __init__(self, config):
|
362 |
+
super(BertPreTrainingHeads, self).__init__()
|
363 |
+
self.predictions = BertLMPredictionHead(config)
|
364 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
365 |
+
|
366 |
+
def forward(self, sequence_output, pooled_output):
|
367 |
+
prediction_scores = self.predictions(sequence_output)
|
368 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
369 |
+
return prediction_scores, seq_relationship_score
|
370 |
+
|
371 |
+
|
372 |
+
class BertPreTrainedModel(nn.Module):
|
373 |
+
config_class = BertConfig
|
374 |
+
base_model_prefix = "bert"
|
375 |
+
|
376 |
+
def __init__(self, config):
|
377 |
+
super(BertPreTrainedModel, self).__init__()
|
378 |
+
self.config = config
|
379 |
+
|
380 |
+
def _init_weights(self, module):
|
381 |
+
""" Initialize the weights """
|
382 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
383 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
384 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
385 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
386 |
+
elif isinstance(module, BertLayerNorm):
|
387 |
+
module.bias.data.zero_()
|
388 |
+
module.weight.data.fill_(1.0)
|
389 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
390 |
+
module.bias.data.zero_()
|
391 |
+
|
392 |
+
|
393 |
+
class BertModel(BertPreTrainedModel):
|
394 |
+
r"""
|
395 |
+
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
396 |
+
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
397 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
398 |
+
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
399 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
400 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
401 |
+
layer weights are trained from the next sentence prediction (classification)
|
402 |
+
objective during Bert pretraining. This output is usually *not* a good summary
|
403 |
+
of the semantic content of the input, you're often better with averaging or pooling
|
404 |
+
the sequence of hidden-states for the whole input sequence.
|
405 |
+
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
406 |
+
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
407 |
+
of shape ``(batch_size, sequence_length, hidden_size)``:
|
408 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
409 |
+
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
410 |
+
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
411 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
412 |
+
|
413 |
+
Examples::
|
414 |
+
|
415 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
416 |
+
model = BertModel.from_pretrained('bert-base-uncased')
|
417 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
418 |
+
outputs = model(input_ids)
|
419 |
+
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
420 |
+
|
421 |
+
"""
|
422 |
+
def __init__(self, config):
|
423 |
+
super(BertModel, self).__init__(config)
|
424 |
+
|
425 |
+
self.embeddings = BertEmbeddings(config)
|
426 |
+
self.encoder = BertEncoder(config)
|
427 |
+
# self.pooler = BertPooler(config)
|
428 |
+
|
429 |
+
self.apply(self._init_weights)
|
430 |
+
|
431 |
+
@torch.jit.ignore
|
432 |
+
def set_grad_checkpointing(self, enable=True):
|
433 |
+
if enable:
|
434 |
+
assert not self.config.output_attentions, \
|
435 |
+
"Grad checkpointing is currently conflict with output_attentions for BertEncoder, \
|
436 |
+
please set it to False in BertConfig"
|
437 |
+
self.encoder.grad_checkpointing = enable
|
438 |
+
|
439 |
+
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
440 |
+
if attention_mask is None:
|
441 |
+
attention_mask = torch.ones_like(input_ids)
|
442 |
+
if token_type_ids is None:
|
443 |
+
token_type_ids = torch.zeros_like(input_ids)
|
444 |
+
|
445 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
446 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
447 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
448 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
449 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
450 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
451 |
+
|
452 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
453 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
454 |
+
# positions we want to attend and -10000.0 for masked positions.
|
455 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
456 |
+
# effectively the same as removing these entirely.
|
457 |
+
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
458 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
459 |
+
|
460 |
+
# Prepare head mask if needed
|
461 |
+
# 1.0 in head_mask indicate we keep the head
|
462 |
+
# attention_probs has shape bsz x n_heads x N x N
|
463 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
464 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
465 |
+
if head_mask is not None:
|
466 |
+
if head_mask.dim() == 1:
|
467 |
+
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
468 |
+
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
469 |
+
elif head_mask.dim() == 2:
|
470 |
+
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
471 |
+
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
472 |
+
else:
|
473 |
+
head_mask = [None] * self.config.num_hidden_layers
|
474 |
+
|
475 |
+
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
476 |
+
encoder_outputs = self.encoder(embedding_output,
|
477 |
+
extended_attention_mask,
|
478 |
+
head_mask=head_mask)
|
479 |
+
sequence_output = encoder_outputs[0]
|
480 |
+
# pooled_output = self.pooler(sequence_output)
|
481 |
+
pooled_output = None
|
482 |
+
|
483 |
+
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
484 |
+
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
clip/utils.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code modified from https://github.com/openai/CLIP
|
2 |
+
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Union, List
|
7 |
+
import urllib
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, InterpolationMode
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from clip import _tokenizer
|
14 |
+
from clip.model import convert_weights, CLIP, restore_model
|
15 |
+
|
16 |
+
__all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]
|
17 |
+
|
18 |
+
_MODELS = {
|
19 |
+
"ViT-B-16": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt",
|
20 |
+
"ViT-L-14": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt",
|
21 |
+
"RN50": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt",
|
22 |
+
}
|
23 |
+
_MODEL_INFO = {
|
24 |
+
"ViT-B-16": {
|
25 |
+
"struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
|
26 |
+
"input_resolution": 224
|
27 |
+
},
|
28 |
+
"ViT-L-14": {
|
29 |
+
"struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
|
30 |
+
"input_resolution": 224
|
31 |
+
},
|
32 |
+
"RN50": {
|
33 |
+
"struct": "RN50@RBT3-chinese",
|
34 |
+
"input_resolution": 224
|
35 |
+
},
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
def _download(url: str, root: str):
|
40 |
+
os.makedirs(root, exist_ok=True)
|
41 |
+
filename = os.path.basename(url)
|
42 |
+
|
43 |
+
download_target = os.path.join(root, filename)
|
44 |
+
|
45 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
46 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
47 |
+
|
48 |
+
if os.path.isfile(download_target):
|
49 |
+
return download_target
|
50 |
+
|
51 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
52 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
|
53 |
+
unit_divisor=1024) as loop:
|
54 |
+
while True:
|
55 |
+
buffer = source.read(8192)
|
56 |
+
if not buffer:
|
57 |
+
break
|
58 |
+
|
59 |
+
output.write(buffer)
|
60 |
+
loop.update(len(buffer))
|
61 |
+
|
62 |
+
return download_target
|
63 |
+
|
64 |
+
|
65 |
+
def _convert_image_to_rgb(image):
|
66 |
+
return image.convert("RGB")
|
67 |
+
|
68 |
+
|
69 |
+
def available_models() -> List[str]:
|
70 |
+
"""Returns the names of available CLIP models"""
|
71 |
+
return list(_MODELS.keys())
|
72 |
+
|
73 |
+
|
74 |
+
def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
75 |
+
download_root: str = None, vision_model_name: str = None, text_model_name: str = None, input_resolution: int = None):
|
76 |
+
if name in _MODELS:
|
77 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
78 |
+
model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
|
79 |
+
elif os.path.isfile(name):
|
80 |
+
assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
|
81 |
+
model_path = name
|
82 |
+
model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
|
83 |
+
else:
|
84 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
85 |
+
|
86 |
+
with open(model_path, 'rb') as opened_file:
|
87 |
+
# loading saved checkpoint
|
88 |
+
checkpoint = torch.load(opened_file, map_location="cpu")
|
89 |
+
|
90 |
+
model = create_model(model_name, checkpoint)
|
91 |
+
if str(device) == "cpu":
|
92 |
+
model.float()
|
93 |
+
else:
|
94 |
+
model.to(device)
|
95 |
+
return model, image_transform(model_input_resolution)
|
96 |
+
|
97 |
+
|
98 |
+
def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
|
99 |
+
bert_path=None, use_flash_attention=False):
|
100 |
+
"""Load CLIP and BERT model weights
|
101 |
+
"""
|
102 |
+
|
103 |
+
bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
|
104 |
+
clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None
|
105 |
+
|
106 |
+
restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention).to(device)
|
107 |
+
|
108 |
+
if str(device) == "cpu":
|
109 |
+
model.float()
|
110 |
+
return model
|
111 |
+
|
112 |
+
|
113 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 52) -> torch.LongTensor:
|
114 |
+
"""
|
115 |
+
Returns the tokenized representation of given input string(s)
|
116 |
+
Parameters
|
117 |
+
----------
|
118 |
+
texts : Union[str, List[str]]
|
119 |
+
An input string or a list of input strings to tokenize
|
120 |
+
context_length : int
|
121 |
+
The context length to use; all baseline models use 52 as the context length
|
122 |
+
Returns
|
123 |
+
-------
|
124 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
125 |
+
"""
|
126 |
+
if isinstance(texts, str):
|
127 |
+
texts = [texts]
|
128 |
+
|
129 |
+
all_tokens = []
|
130 |
+
for text in texts:
|
131 |
+
all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[
|
132 |
+
:context_length - 2] + [_tokenizer.vocab['[SEP]']])
|
133 |
+
|
134 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
135 |
+
|
136 |
+
for i, tokens in enumerate(all_tokens):
|
137 |
+
assert len(tokens) <= context_length
|
138 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
139 |
+
|
140 |
+
return result
|
141 |
+
|
142 |
+
|
143 |
+
def _convert_to_rgb(image):
|
144 |
+
return image.convert('RGB')
|
145 |
+
|
146 |
+
|
147 |
+
def image_transform(image_size=224):
|
148 |
+
transform = Compose([
|
149 |
+
Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
150 |
+
_convert_to_rgb,
|
151 |
+
ToTensor(),
|
152 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
153 |
+
])
|
154 |
+
return transform
|
155 |
+
|
156 |
+
|
157 |
+
def create_model(model_name, checkpoint=None):
|
158 |
+
vision_model, text_model = model_name.split('@')
|
159 |
+
# Initialize the model.
|
160 |
+
vision_model_config_file = Path(
|
161 |
+
__file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
|
162 |
+
print('Loading vision model config from', vision_model_config_file)
|
163 |
+
assert os.path.exists(vision_model_config_file)
|
164 |
+
|
165 |
+
text_model_config_file = Path(
|
166 |
+
__file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
|
167 |
+
print('Loading text model config from', text_model_config_file)
|
168 |
+
assert os.path.exists(text_model_config_file)
|
169 |
+
|
170 |
+
with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
|
171 |
+
model_info = json.load(fv)
|
172 |
+
for k, v in json.load(ft).items():
|
173 |
+
model_info[k] = v
|
174 |
+
if isinstance(model_info['vision_layers'], str):
|
175 |
+
model_info['vision_layers'] = eval(model_info['vision_layers'])
|
176 |
+
print('Model info', model_info)
|
177 |
+
model = CLIP(**model_info)
|
178 |
+
convert_weights(model)
|
179 |
+
if checkpoint:
|
180 |
+
sd = checkpoint["state_dict"]
|
181 |
+
if next(iter(sd.items()))[0].startswith('module'):
|
182 |
+
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
|
183 |
+
model.load_state_dict(sd)
|
184 |
+
return model
|
clip/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
eval/cvinw_zeroshot_templates.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script provides templates for manual prompting for zero-shot image classification.
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
openai_templates = [
|
7 |
+
lambda c: f"{c}的照片",
|
8 |
+
lambda c: f"质量差的{c}的照片",
|
9 |
+
lambda c: f"许多{c}的照片",
|
10 |
+
lambda c: f"{c}的雕塑",
|
11 |
+
lambda c: f"难以看到{c}的照片",
|
12 |
+
lambda c: f"{c}的低分辨率照片",
|
13 |
+
lambda c: f"{c}的渲染",
|
14 |
+
lambda c: f"涂鸦{c}",
|
15 |
+
lambda c: f"{c}的糟糕照片",
|
16 |
+
lambda c: f"{c}的裁剪照片",
|
17 |
+
lambda c: f"{c}的纹身",
|
18 |
+
lambda c: f"{c}的刺绣照片",
|
19 |
+
lambda c: f"很难看到{c}的照片",
|
20 |
+
lambda c: f"{c}的明亮照片",
|
21 |
+
lambda c: f"一张干净的{c}的照片",
|
22 |
+
lambda c: f"一张包含{c}的照片",
|
23 |
+
lambda c: f"{c}的深色照片",
|
24 |
+
lambda c: f"{c}的手绘画",
|
25 |
+
lambda c: f"我的{c}的照片",
|
26 |
+
lambda c: f"不自然的{c}的照片",
|
27 |
+
lambda c: f"一张酷的{c}的照片",
|
28 |
+
lambda c: f"{c}的特写照片",
|
29 |
+
lambda c: f"{c}的黑白照片",
|
30 |
+
lambda c: f"一幅{c}的画",
|
31 |
+
lambda c: f"一幅{c}的绘画",
|
32 |
+
lambda c: f"一张{c}的像素照片",
|
33 |
+
lambda c: f"{c}的雕像",
|
34 |
+
lambda c: f"一张{c}的明亮照片",
|
35 |
+
lambda c: f"{c}的裁剪照片",
|
36 |
+
lambda c: f"人造的{c}的照片",
|
37 |
+
lambda c: f"一张关于{c}的照片",
|
38 |
+
lambda c: f"损坏的{c}的jpeg照片",
|
39 |
+
lambda c: f"{c}的模糊照片",
|
40 |
+
lambda c: f"{c}的相片",
|
41 |
+
lambda c: f"一张{c}的好照片",
|
42 |
+
lambda c: f"{c}的渲染照",
|
43 |
+
lambda c: f"视频游戏中的{c}",
|
44 |
+
lambda c: f"一张{c}的照片",
|
45 |
+
lambda c: f"{c}的涂鸦",
|
46 |
+
lambda c: f"{c}的近距离照片",
|
47 |
+
lambda c: f"{c}的折纸",
|
48 |
+
lambda c: f"{c}在视频游戏中",
|
49 |
+
lambda c: f"{c}的草图",
|
50 |
+
lambda c: f"{c}的涂鸦照",
|
51 |
+
lambda c: f"{c}的折纸形状",
|
52 |
+
lambda c: f"低分辨率的{c}的照片",
|
53 |
+
lambda c: f"玩具{c}",
|
54 |
+
lambda c: f"{c}的副本",
|
55 |
+
lambda c: f"{c}的干净的照片",
|
56 |
+
lambda c: f"一张大{c}的照片",
|
57 |
+
lambda c: f"{c}的重现",
|
58 |
+
lambda c: f"一张漂亮的{c}的照片",
|
59 |
+
lambda c: f"一张奇怪的{c}的照片",
|
60 |
+
lambda c: f"模糊的{c}的照片",
|
61 |
+
lambda c: f"卡通{c}",
|
62 |
+
lambda c: f"{c}的艺术作品",
|
63 |
+
lambda c: f"{c}的素描",
|
64 |
+
lambda c: f"刺绣{c}",
|
65 |
+
lambda c: f"{c}的像素照",
|
66 |
+
lambda c: f"{c}的拍照",
|
67 |
+
lambda c: f"{c}的损坏的照片",
|
68 |
+
lambda c: f"高质量的{c}的照片",
|
69 |
+
lambda c: f"毛绒玩具{c}",
|
70 |
+
lambda c: f"漂亮的{c}的照片",
|
71 |
+
lambda c: f"小{c}的照片",
|
72 |
+
lambda c: f"照片是奇怪的{c}",
|
73 |
+
lambda c: f"漫画{c}",
|
74 |
+
lambda c: f"{c}的艺术照",
|
75 |
+
lambda c: f"{c}的图形",
|
76 |
+
lambda c: f"大{c}的照片",
|
77 |
+
lambda c: f"黑白的{c}的照片",
|
78 |
+
lambda c: f"{c}毛绒玩具",
|
79 |
+
lambda c: f"一张{c}的深色照片",
|
80 |
+
lambda c: f"{c}的摄影图",
|
81 |
+
lambda c: f"{c}的涂鸦照",
|
82 |
+
lambda c: f"玩具形状的{c}",
|
83 |
+
lambda c: f"拍了{c}的照片",
|
84 |
+
lambda c: f"酷酷的{c}的照片",
|
85 |
+
lambda c: f"照片里的小{c}",
|
86 |
+
lambda c: f"{c}的刺青",
|
87 |
+
lambda c: f"{c}的可爱的照片",
|
88 |
+
lambda c: f"一张{c}可爱的照片",
|
89 |
+
lambda c: f"{c}可爱图片",
|
90 |
+
lambda c: f"{c}酷炫图片",
|
91 |
+
lambda c: f"一张{c}的酷炫的照片",
|
92 |
+
lambda c: f"一张{c}的酷炫图片",
|
93 |
+
lambda c: f"这是{c}",
|
94 |
+
lambda c: f"{c}的好看照片",
|
95 |
+
lambda c: f"一张{c}的好看的图片",
|
96 |
+
lambda c: f"{c}的好看图片",
|
97 |
+
lambda c: f"{c}的照片。",
|
98 |
+
lambda c: f"质量差的{c}的照片。",
|
99 |
+
lambda c: f"许多{c}的照片。",
|
100 |
+
lambda c: f"{c}的雕塑。",
|
101 |
+
lambda c: f"难以看到{c}的照片。",
|
102 |
+
lambda c: f"{c}的低分辨率照片。",
|
103 |
+
lambda c: f"{c}的渲染。",
|
104 |
+
lambda c: f"涂鸦{c}。",
|
105 |
+
lambda c: f"{c}的糟糕照片。",
|
106 |
+
lambda c: f"{c}的裁剪照片。",
|
107 |
+
lambda c: f"{c}的纹身。",
|
108 |
+
lambda c: f"{c}的刺绣照片。",
|
109 |
+
lambda c: f"很难看到{c}的照片。",
|
110 |
+
lambda c: f"{c}的明亮照片。",
|
111 |
+
lambda c: f"一张干净的{c}的照片。",
|
112 |
+
lambda c: f"一张包含{c}的照片。",
|
113 |
+
lambda c: f"{c}的深色照片。",
|
114 |
+
lambda c: f"{c}的手绘画。",
|
115 |
+
lambda c: f"我的{c}的照片。",
|
116 |
+
lambda c: f"不自然的{c}的照片。",
|
117 |
+
lambda c: f"一张酷的{c}的照片。",
|
118 |
+
lambda c: f"{c}的特写照片。",
|
119 |
+
lambda c: f"{c}的黑白照片。",
|
120 |
+
lambda c: f"一幅{c}的画。",
|
121 |
+
lambda c: f"一幅{c}的绘画。",
|
122 |
+
lambda c: f"一张{c}的像素照片。",
|
123 |
+
lambda c: f"{c}的雕像。",
|
124 |
+
lambda c: f"一张{c}的明亮照片。",
|
125 |
+
lambda c: f"{c}的裁剪照片。",
|
126 |
+
lambda c: f"人造的{c}的照片。",
|
127 |
+
lambda c: f"一张关于{c}的照片。",
|
128 |
+
lambda c: f"损坏的{c}的jpeg照片。",
|
129 |
+
lambda c: f"{c}的模糊照片。",
|
130 |
+
lambda c: f"{c}的相片。",
|
131 |
+
lambda c: f"一张{c}的好照片。",
|
132 |
+
lambda c: f"{c}的渲染照。",
|
133 |
+
lambda c: f"视频游戏中的{c}。",
|
134 |
+
lambda c: f"一张{c}的照片。",
|
135 |
+
lambda c: f"{c}的涂鸦。",
|
136 |
+
lambda c: f"{c}的近距离照片。",
|
137 |
+
lambda c: f"{c}的折纸。",
|
138 |
+
lambda c: f"{c}在视频游戏中。",
|
139 |
+
lambda c: f"{c}的草图。",
|
140 |
+
lambda c: f"{c}的涂鸦照。",
|
141 |
+
lambda c: f"{c}的折纸形状。",
|
142 |
+
lambda c: f"低分辨率的{c}的照片。",
|
143 |
+
lambda c: f"玩具{c}。",
|
144 |
+
lambda c: f"{c}的副本。",
|
145 |
+
lambda c: f"{c}的干净的照片。",
|
146 |
+
lambda c: f"一张大{c}的照片。",
|
147 |
+
lambda c: f"{c}的重现。",
|
148 |
+
lambda c: f"一张漂亮的{c}的照片。",
|
149 |
+
lambda c: f"一张奇怪的{c}的照片。",
|
150 |
+
lambda c: f"模糊的{c}的照片。",
|
151 |
+
lambda c: f"卡通{c}。",
|
152 |
+
lambda c: f"{c}的艺术作品。",
|
153 |
+
lambda c: f"{c}的素描。",
|
154 |
+
lambda c: f"刺绣{c}。",
|
155 |
+
lambda c: f"{c}的像素照。",
|
156 |
+
lambda c: f"{c}的拍照。",
|
157 |
+
lambda c: f"{c}的损坏的照片。",
|
158 |
+
lambda c: f"高质量的{c}的照片。",
|
159 |
+
lambda c: f"毛绒玩具{c}。",
|
160 |
+
lambda c: f"漂亮的{c}的照片。",
|
161 |
+
lambda c: f"小{c}的照片。",
|
162 |
+
lambda c: f"照片是奇怪的{c}。",
|
163 |
+
lambda c: f"漫画{c}。",
|
164 |
+
lambda c: f"{c}的艺术照。",
|
165 |
+
lambda c: f"{c}的图形。",
|
166 |
+
lambda c: f"大{c}的照片。",
|
167 |
+
lambda c: f"黑白的{c}的照片。",
|
168 |
+
lambda c: f"{c}毛绒玩具。",
|
169 |
+
lambda c: f"一张{c}的深色照片。",
|
170 |
+
lambda c: f"{c}的摄影图。",
|
171 |
+
lambda c: f"{c}的涂鸦照。",
|
172 |
+
lambda c: f"玩具形状的{c}。",
|
173 |
+
lambda c: f"拍了{c}的照片。",
|
174 |
+
lambda c: f"酷酷的{c}的照片。",
|
175 |
+
lambda c: f"照片里的小{c}。",
|
176 |
+
lambda c: f"{c}的刺青。",
|
177 |
+
lambda c: f"{c}的可爱的照片。",
|
178 |
+
lambda c: f"一张{c}可爱的照片。",
|
179 |
+
lambda c: f"{c}可爱图片。",
|
180 |
+
lambda c: f"{c}酷炫图片。",
|
181 |
+
lambda c: f"一张{c}的酷炫的照片。",
|
182 |
+
lambda c: f"一张{c}的酷炫图片。",
|
183 |
+
lambda c: f"这是{c}。",
|
184 |
+
lambda c: f"{c}的好看照片。",
|
185 |
+
lambda c: f"一张{c}的好看的图片。",
|
186 |
+
lambda c: f"{c}的好看图片。",
|
187 |
+
lambda c: f"一种叫{c}的花的照片",
|
188 |
+
lambda c: f"一种叫{c}的食物的照片",
|
189 |
+
lambda c: f"{c}的卫星照片"
|
190 |
+
]
|
191 |
+
|
192 |
+
normal_templates = [lambda c: f"{c}的图片"]
|
193 |
+
|
194 |
+
flower_templates = [
|
195 |
+
lambda c: f"一种叫{c}的花的照片",
|
196 |
+
lambda c: f"一种叫{c}的花卉的照片",
|
197 |
+
lambda c: f"一种叫{c}的花朵的照片",
|
198 |
+
lambda c: f"一种叫{c}的鲜花的照片",
|
199 |
+
lambda c: f"一种叫{c}的花的高清图",
|
200 |
+
lambda c: f"一种叫{c}的花卉的高清图",
|
201 |
+
lambda c: f"一种叫{c}的花朵的高清图",
|
202 |
+
lambda c: f"一种叫{c}的鲜花的高清图",
|
203 |
+
lambda c: f"一种叫{c}的花的模糊图片",
|
204 |
+
lambda c: f"一种叫{c}的花朵的模糊图片",
|
205 |
+
lambda c: f"一种叫{c}的花卉的模糊图片",
|
206 |
+
lambda c: f"一种叫{c}的鲜花的模糊图片",
|
207 |
+
lambda c: f"一种叫{c}的花的缩放图片",
|
208 |
+
lambda c: f"一种叫{c}的花朵的缩放图片",
|
209 |
+
lambda c: f"一种叫{c}的花卉的缩放图片",
|
210 |
+
lambda c: f"一种叫{c}的鲜花的缩放图片",
|
211 |
+
lambda c: f"一种叫{c}的花的摄影图",
|
212 |
+
lambda c: f"一种叫{c}的花卉的摄影图",
|
213 |
+
lambda c: f"一种叫{c}的花朵的摄影图",
|
214 |
+
lambda c: f"一种叫{c}的鲜花的摄影图",
|
215 |
+
lambda c: f"一种叫{c}的花的近距离照片",
|
216 |
+
lambda c: f"一种叫{c}的花朵的近距离照片",
|
217 |
+
lambda c: f"一种叫{c}的花卉的近距离照片",
|
218 |
+
lambda c: f"一种叫{c}的鲜花的近距离照片",
|
219 |
+
lambda c: f"一种叫{c}的花的裁剪照片",
|
220 |
+
lambda c: f"一种叫{c}的花朵的裁剪照片",
|
221 |
+
lambda c: f"一种叫{c}的花卉的裁剪照片",
|
222 |
+
lambda c: f"一种叫{c}的鲜花的裁剪照片",
|
223 |
+
lambda c: f"一种叫{c}的花的好看的图片",
|
224 |
+
lambda c: f"一种叫{c}的花朵的好看的图片",
|
225 |
+
lambda c: f"一种叫{c}的花卉的好看的图片",
|
226 |
+
lambda c: f"一种叫{c}的鲜花的好看的图片",
|
227 |
+
]
|
228 |
+
|
229 |
+
food_templates = [
|
230 |
+
lambda c: f"一种叫{c}的食物的照片",
|
231 |
+
lambda c: f"一种叫{c}的美食的照片",
|
232 |
+
lambda c: f"一种叫{c}的菜的照片",
|
233 |
+
lambda c: f"一种叫{c}的食物的高清图",
|
234 |
+
lambda c: f"一种叫{c}的美食的高清图",
|
235 |
+
lambda c: f"一种叫{c}的菜的高清图",
|
236 |
+
lambda c: f"一种叫{c}的食物的模糊图片",
|
237 |
+
lambda c: f"一种叫{c}的美食的模糊图片",
|
238 |
+
lambda c: f"一种叫{c}的菜的模糊图片",
|
239 |
+
lambda c: f"一种叫{c}的食物的缩放图片",
|
240 |
+
lambda c: f"一种叫{c}的美食的缩放图片",
|
241 |
+
lambda c: f"一种叫{c}的菜的缩放图片",
|
242 |
+
lambda c: f"一种叫{c}的食物的摄影图",
|
243 |
+
lambda c: f"一种叫{c}的美食的摄影图",
|
244 |
+
lambda c: f"一种叫{c}的菜的摄影图",
|
245 |
+
lambda c: f"一种叫{c}的食物的近距离照片",
|
246 |
+
lambda c: f"一种叫{c}的美食的近距离照片",
|
247 |
+
lambda c: f"一种叫{c}的菜的近距离照片",
|
248 |
+
lambda c: f"一种叫{c}的食物的���剪照片",
|
249 |
+
lambda c: f"一种叫{c}的美食的裁剪照片",
|
250 |
+
lambda c: f"一种叫{c}的菜的裁剪照片",
|
251 |
+
]
|
252 |
+
|
253 |
+
aircraft_templates = [
|
254 |
+
lambda c: f"{c},飞机的照片",
|
255 |
+
lambda c: f"{c},飞机的高清图",
|
256 |
+
lambda c: f"{c},飞机的模糊图片",
|
257 |
+
lambda c: f"{c},飞机的缩放图片",
|
258 |
+
lambda c: f"{c},飞机的摄影图",
|
259 |
+
lambda c: f"{c},战斗机的照片",
|
260 |
+
lambda c: f"{c},战斗机的高清图",
|
261 |
+
lambda c: f"{c},战斗机的模糊图片",
|
262 |
+
lambda c: f"{c},战斗机的缩放图片",
|
263 |
+
lambda c: f"{c},战斗机的摄影图",
|
264 |
+
lambda c: f"{c},老飞机的照片",
|
265 |
+
lambda c: f"{c},老飞机的高清图",
|
266 |
+
lambda c: f"{c},老飞机的模糊图片",
|
267 |
+
lambda c: f"{c},老飞机的缩放图片",
|
268 |
+
lambda c: f"{c},老飞机的摄影图",
|
269 |
+
lambda c: f"{c},大飞机的照片",
|
270 |
+
lambda c: f"{c},大飞机的高清图",
|
271 |
+
lambda c: f"{c},大飞机的模糊图片",
|
272 |
+
lambda c: f"{c},大飞机的缩放图片",
|
273 |
+
lambda c: f"{c},大飞机的摄影图",
|
274 |
+
lambda c: f"{c},小飞机的照片",
|
275 |
+
lambda c: f"{c},小飞机的高清图",
|
276 |
+
lambda c: f"{c},小飞机的模糊图片",
|
277 |
+
lambda c: f"{c},小飞机的缩放图片",
|
278 |
+
lambda c: f"{c},小飞机的摄影图",
|
279 |
+
lambda c: f"{c},军用飞机的照片",
|
280 |
+
lambda c: f"{c},军用飞机的高清图",
|
281 |
+
lambda c: f"{c},军用飞机的模糊图片",
|
282 |
+
lambda c: f"{c},军用飞机的缩放图片",
|
283 |
+
lambda c: f"{c},军用飞机的摄影图",
|
284 |
+
lambda c: f"{c},运输机的照片",
|
285 |
+
lambda c: f"{c},运输机的高清图",
|
286 |
+
lambda c: f"{c},运输机的模糊图片",
|
287 |
+
lambda c: f"{c},运输机的缩放图片",
|
288 |
+
lambda c: f"{c},运输机的摄影图",
|
289 |
+
lambda c: f"{c},公务机的照片",
|
290 |
+
lambda c: f"{c},公务机的高清图",
|
291 |
+
lambda c: f"{c},公务机的模糊图片",
|
292 |
+
lambda c: f"{c},公务机的缩放图片",
|
293 |
+
lambda c: f"{c},公务机的摄影图",
|
294 |
+
lambda c: f"{c},客机的照片",
|
295 |
+
lambda c: f"{c},客机的高清图",
|
296 |
+
lambda c: f"{c},客机的模糊图片",
|
297 |
+
lambda c: f"{c},客机的缩放图片",
|
298 |
+
lambda c: f"{c},客机的摄影图",
|
299 |
+
lambda c: f"{c},喷气机的照片",
|
300 |
+
lambda c: f"{c},喷气机的高清图",
|
301 |
+
lambda c: f"{c},喷气机的模糊图片",
|
302 |
+
lambda c: f"{c},喷气机的缩放图片",
|
303 |
+
lambda c: f"{c},喷气机的摄影图",
|
304 |
+
lambda c: f"一种叫{c}的飞机的照片",
|
305 |
+
lambda c: f"一种叫{c}的飞机的高清图",
|
306 |
+
lambda c: f"一种叫{c}的飞机的模糊图片",
|
307 |
+
lambda c: f"一种叫{c}的飞机的缩放图片",
|
308 |
+
lambda c: f"一种叫{c}的飞机的摄影图",
|
309 |
+
lambda c: f"一种叫{c}的战斗机的照片",
|
310 |
+
lambda c: f"一种叫{c}的战斗机的高清图",
|
311 |
+
lambda c: f"一种叫{c}的战斗机的模糊图片",
|
312 |
+
lambda c: f"一种叫{c}的战斗机的缩放图片",
|
313 |
+
lambda c: f"一种叫{c}的战斗机的摄影图",
|
314 |
+
lambda c: f"一种叫{c}的老飞机的照片",
|
315 |
+
lambda c: f"一种叫{c}的老飞机的高清图",
|
316 |
+
lambda c: f"一种叫{c}的老飞机的模糊图片",
|
317 |
+
lambda c: f"一种叫{c}的老飞机的缩放图片",
|
318 |
+
lambda c: f"一种叫{c}的老飞机的摄影图",
|
319 |
+
lambda c: f"一种叫{c}的大飞机的照片",
|
320 |
+
lambda c: f"一种叫{c}的大飞机的高清图",
|
321 |
+
lambda c: f"一种叫{c}的大飞机的模糊图片",
|
322 |
+
lambda c: f"一种叫{c}的大飞机的缩放图片",
|
323 |
+
lambda c: f"一种叫{c}的大飞机的摄影图",
|
324 |
+
lambda c: f"一种叫{c}的小飞机的照片",
|
325 |
+
lambda c: f"一种叫{c}的小飞机的高清图",
|
326 |
+
lambda c: f"一种叫{c}的小飞机的模糊图片",
|
327 |
+
lambda c: f"一种叫{c}的小飞机的缩放图片",
|
328 |
+
lambda c: f"一种叫{c}的小飞机的摄影图",
|
329 |
+
lambda c: f"一种叫{c}的军用飞机的照片",
|
330 |
+
lambda c: f"一种叫{c}的军用飞机的高清图",
|
331 |
+
lambda c: f"一种叫{c}的军用飞机的模糊图片",
|
332 |
+
lambda c: f"一种叫{c}的军用飞机的缩放图片",
|
333 |
+
lambda c: f"一种叫{c}的军用飞机的摄影图",
|
334 |
+
lambda c: f"一种叫{c}的运输机的照片",
|
335 |
+
lambda c: f"一种叫{c}的运输机的高清图",
|
336 |
+
lambda c: f"一种叫{c}的运输机的模糊图片",
|
337 |
+
lambda c: f"一种叫{c}的运输机的缩放图片",
|
338 |
+
lambda c: f"一种叫{c}的运输机的摄影图",
|
339 |
+
lambda c: f"一种叫{c}的公务机的照片",
|
340 |
+
lambda c: f"一种叫{c}的公务机的高清图",
|
341 |
+
lambda c: f"一种叫{c}的公务机的模糊图片",
|
342 |
+
lambda c: f"一种叫{c}的公务机的缩放图片",
|
343 |
+
lambda c: f"一种叫{c}的公务机的摄影图",
|
344 |
+
lambda c: f"一种叫{c}的客机的照片",
|
345 |
+
lambda c: f"一种叫{c}的客机的高清图",
|
346 |
+
lambda c: f"一种叫{c}的客机的模糊图片",
|
347 |
+
lambda c: f"一种叫{c}的客机的缩放图片",
|
348 |
+
lambda c: f"一种叫{c}的客机的摄影图",
|
349 |
+
lambda c: f"一种叫{c}的喷气机的照片",
|
350 |
+
lambda c: f"���种叫{c}的喷气机的高清图",
|
351 |
+
lambda c: f"一种叫{c}的喷气机的模糊图片",
|
352 |
+
lambda c: f"一种叫{c}的喷气机的缩放图片",
|
353 |
+
lambda c: f"一种叫{c}的喷气机的摄影图",
|
354 |
+
]
|
355 |
+
|
356 |
+
eurosat_templates = [
|
357 |
+
lambda c: f"一张{c}的卫星照片",
|
358 |
+
lambda c: f"{c}的卫星照片",
|
359 |
+
lambda c: f"一张{c}的高清卫星照片",
|
360 |
+
lambda c: f"{c}的高清卫星照片",
|
361 |
+
lambda c: f"一张{c}的清晰的卫星照片",
|
362 |
+
lambda c: f"{c}的清晰的卫星照片",
|
363 |
+
lambda c: f"一张{c}的高质量的卫星照片",
|
364 |
+
lambda c: f"{c}的高质量的卫星照片",
|
365 |
+
lambda c: f"一张{c}的卫星图",
|
366 |
+
lambda c: f"{c}的卫星图",
|
367 |
+
lambda c: f"一张{c}的高清卫星图",
|
368 |
+
lambda c: f"{c}的高清卫星图",
|
369 |
+
lambda c: f"一张{c}的清晰的卫星图",
|
370 |
+
lambda c: f"{c}的清晰的卫星图",
|
371 |
+
lambda c: f"一张{c}的高质量的卫星图",
|
372 |
+
lambda c: f"{c}的高质量的卫星图",
|
373 |
+
lambda c: f"一张{c}的卫星图片",
|
374 |
+
lambda c: f"{c}的卫星图片",
|
375 |
+
lambda c: f"一张{c}的高清卫星图片",
|
376 |
+
lambda c: f"{c}的高清卫星图片",
|
377 |
+
lambda c: f"一张{c}的清晰的卫星图片",
|
378 |
+
lambda c: f"{c}的清晰的卫星图片",
|
379 |
+
lambda c: f"一张{c}的高质量的卫星图片",
|
380 |
+
lambda c: f"{c}的高质量的卫星图片",
|
381 |
+
]
|
382 |
+
|
383 |
+
hatefulmemes_templates = [
|
384 |
+
lambda c: f"一个{c}",
|
385 |
+
lambda c: f"{c}",
|
386 |
+
]
|
387 |
+
|
388 |
+
kitti_templates = [
|
389 |
+
lambda c: f"照片里{c}",
|
390 |
+
lambda c: f"图片里{c}",
|
391 |
+
lambda c: f"{c}",
|
392 |
+
]
|
393 |
+
|
394 |
+
cars_templates = [
|
395 |
+
lambda c: f"一张{c}的照片",
|
396 |
+
lambda c: f"一张我的{c}的照片",
|
397 |
+
lambda c: f"我爱我的{c}",
|
398 |
+
lambda c: f"一张我肮脏的{c}的照片",
|
399 |
+
lambda c: f"一张我干净的{c}的照片",
|
400 |
+
lambda c: f"一张我新买的{c}的照片",
|
401 |
+
lambda c: f"一张我旧的{c}的照片",
|
402 |
+
]
|
403 |
+
|
404 |
+
dtd_templates = [
|
405 |
+
lambda c: f"一张{c}纹理的照片",
|
406 |
+
lambda c: f"一张{c}图案的照片",
|
407 |
+
lambda c: f"一张{c}物体的照片",
|
408 |
+
lambda c: f"一张{c}纹理的图片",
|
409 |
+
lambda c: f"一张{c}图案的图片",
|
410 |
+
lambda c: f"一张{c}物体的图片",
|
411 |
+
]
|
412 |
+
|
413 |
+
country211_templates = [
|
414 |
+
lambda c: f"一张在{c}拍的照片",
|
415 |
+
lambda c: f"一张在{c}旅行时拍的照片",
|
416 |
+
lambda c: f"一张我家乡{c}的照片",
|
417 |
+
lambda c: f"一张展示{c}风光的照片",
|
418 |
+
]
|
419 |
+
|
420 |
+
patch_templates = [
|
421 |
+
lambda c: f"一张{c}的医疗照片",
|
422 |
+
lambda c: f"一张{c}的ct照片",
|
423 |
+
lambda c: f"一张{c}的化验照片",
|
424 |
+
]
|
425 |
+
|
426 |
+
pet_templates = [
|
427 |
+
lambda c: f"一种叫{c}的宠物的照片",
|
428 |
+
lambda c: f"一种叫{c}的宠物的图片",
|
429 |
+
lambda c: f"一种叫{c}的宠物的可爱图片",
|
430 |
+
lambda c: f"一种叫{c}的宠物的高清图片",
|
431 |
+
lambda c: f"一种叫{c}的宠物的模糊图片",
|
432 |
+
lambda c: f"一种叫{c}的宠物的特写照片",
|
433 |
+
]
|
434 |
+
|
435 |
+
cifar100_templates = [
|
436 |
+
lambda c: f"一张{c}的照片",
|
437 |
+
lambda c: f"一张{c}的模糊照片",
|
438 |
+
lambda c: f"一张{c}",
|
439 |
+
lambda c: f"一张{c}的低对比度照片",
|
440 |
+
lambda c: f"一张{c}的高对比度照片",
|
441 |
+
lambda c: f"一张{c}的好照片",
|
442 |
+
lambda c: f"一张小{c}的照片",
|
443 |
+
lambda c: f"一张大{c}的照片",
|
444 |
+
lambda c: f"一张{c}的黑白照片",
|
445 |
+
lambda c: f"一张{c}的低对比度的照片",
|
446 |
+
lambda c: f"一张{c}的高对比度的照片",
|
447 |
+
]
|
448 |
+
|
449 |
+
caltech101_templates = [
|
450 |
+
lambda c: f"{c}的照片",
|
451 |
+
lambda c: f"{c}的绘画",
|
452 |
+
lambda c: f"{c}的塑料",
|
453 |
+
lambda c: f"{c}的雕像",
|
454 |
+
lambda c: f"{c}的草图",
|
455 |
+
lambda c: f"{c}的刺青",
|
456 |
+
lambda c: f"{c}的玩具",
|
457 |
+
lambda c: f"{c}的演绎",
|
458 |
+
lambda c: f"{c}的装饰",
|
459 |
+
lambda c: f"{c}的卡通画",
|
460 |
+
lambda c: f"{c}在游戏中",
|
461 |
+
lambda c: f"一个豪华的{c}.",
|
462 |
+
lambda c: f"{c}的折纸",
|
463 |
+
lambda c: f"{c}的艺术画",
|
464 |
+
lambda c: f"{c}的涂鸦画",
|
465 |
+
lambda c: f"{c}的画",
|
466 |
+
]
|
467 |
+
|
468 |
+
fer_templates = [
|
469 |
+
lambda c: f"一张表情{c}的照片",
|
470 |
+
lambda c: f"一张表达{c}情绪的照片",
|
471 |
+
lambda c: f"一张看起来很{c}的照片",
|
472 |
+
lambda c: f"他的脸看起来{c}",
|
473 |
+
lambda c: f"他们看起来很{c}",
|
474 |
+
]
|
eval/data.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import json
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from pathlib import Path
|
6 |
+
from PIL import Image
|
7 |
+
import base64
|
8 |
+
from io import BytesIO
|
9 |
+
import torch
|
10 |
+
import lmdb
|
11 |
+
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
|
12 |
+
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
|
13 |
+
from torch.utils.data.distributed import DistributedSampler
|
14 |
+
from torch.utils.data.sampler import SequentialSampler
|
15 |
+
import torchvision.datasets as datasets
|
16 |
+
from clip import tokenize
|
17 |
+
|
18 |
+
|
19 |
+
def _convert_to_rgb(image):
|
20 |
+
return image.convert('RGB')
|
21 |
+
|
22 |
+
|
23 |
+
def _preprocess_text(text):
|
24 |
+
# adapt the text to Chinese BERT vocab
|
25 |
+
text = text.lower().replace("“", "\"").replace("”", "\"")
|
26 |
+
return text
|
27 |
+
|
28 |
+
|
29 |
+
class EvalTxtDataset(Dataset):
|
30 |
+
def __init__(self, jsonl_filename, max_txt_length=24):
|
31 |
+
assert os.path.exists(jsonl_filename), "The annotation datafile {} not exists!".format(jsonl_filename)
|
32 |
+
|
33 |
+
logging.debug(f'Loading jsonl data from {jsonl_filename}.')
|
34 |
+
self.texts = []
|
35 |
+
with open(jsonl_filename, "r", encoding="utf-8") as fin:
|
36 |
+
for line in fin:
|
37 |
+
obj = json.loads(line.strip())
|
38 |
+
text_id = obj['text_id']
|
39 |
+
text = obj['text']
|
40 |
+
self.texts.append((text_id, text))
|
41 |
+
logging.debug(f'Finished loading jsonl data from {jsonl_filename}.')
|
42 |
+
|
43 |
+
self.max_txt_length = max_txt_length
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return len(self.texts)
|
47 |
+
|
48 |
+
def __getitem__(self, idx):
|
49 |
+
text_id, text = self.texts[idx]
|
50 |
+
text = tokenize([_preprocess_text(str(text))], context_length=self.max_txt_length)[0]
|
51 |
+
return text_id, text
|
52 |
+
|
53 |
+
|
54 |
+
class EvalImgDataset(Dataset):
|
55 |
+
def __init__(self, lmdb_imgs, resolution=224):
|
56 |
+
assert os.path.isdir(lmdb_imgs), "The image LMDB directory {} not exists!".format(lmdb_imgs)
|
57 |
+
|
58 |
+
logging.debug(f'Loading image LMDB from {lmdb_imgs}.')
|
59 |
+
|
60 |
+
self.env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
|
61 |
+
self.txn_imgs = self.env_imgs.begin(buffers=True)
|
62 |
+
self.cursor_imgs = self.txn_imgs.cursor()
|
63 |
+
self.iter_imgs = iter(self.cursor_imgs)
|
64 |
+
self.number_images = int(self.txn_imgs.get(key=b'num_images').tobytes().decode('utf-8'))
|
65 |
+
logging.info("The specified LMDB directory contains {} images.".format(self.number_images))
|
66 |
+
|
67 |
+
self.transform = self._build_transform(resolution)
|
68 |
+
|
69 |
+
def _build_transform(self, resolution):
|
70 |
+
normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
71 |
+
return Compose([
|
72 |
+
Resize((resolution, resolution), interpolation=InterpolationMode.BICUBIC),
|
73 |
+
_convert_to_rgb,
|
74 |
+
ToTensor(),
|
75 |
+
normalize,
|
76 |
+
])
|
77 |
+
|
78 |
+
def __len__(self):
|
79 |
+
return self.number_images
|
80 |
+
|
81 |
+
def __getitem__(self, idx):
|
82 |
+
img_id, image_b64 = next(self.iter_imgs)
|
83 |
+
if img_id == b"num_images":
|
84 |
+
img_id, image_b64 = next(self.iter_imgs)
|
85 |
+
|
86 |
+
img_id = img_id.tobytes()
|
87 |
+
image_b64 = image_b64.tobytes()
|
88 |
+
|
89 |
+
img_id = int(img_id.decode(encoding="utf8", errors="ignore"))
|
90 |
+
image_b64 = image_b64.decode(encoding="utf8", errors="ignore")
|
91 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64))) # already resized
|
92 |
+
image = self.transform(image)
|
93 |
+
|
94 |
+
return img_id, image
|
95 |
+
|
96 |
+
|
97 |
+
@dataclass
|
98 |
+
class DataInfo:
|
99 |
+
dataloader: DataLoader
|
100 |
+
sampler: DistributedSampler
|
101 |
+
|
102 |
+
|
103 |
+
def get_eval_txt_dataset(args, max_txt_length=24):
|
104 |
+
input_filename = args.text_data
|
105 |
+
dataset = EvalTxtDataset(
|
106 |
+
input_filename,
|
107 |
+
max_txt_length=max_txt_length)
|
108 |
+
num_samples = len(dataset)
|
109 |
+
sampler = SequentialSampler(dataset)
|
110 |
+
|
111 |
+
dataloader = DataLoader(
|
112 |
+
dataset,
|
113 |
+
batch_size=args.text_batch_size,
|
114 |
+
num_workers=0,
|
115 |
+
pin_memory=True,
|
116 |
+
sampler=sampler,
|
117 |
+
drop_last=False,
|
118 |
+
)
|
119 |
+
dataloader.num_samples = num_samples
|
120 |
+
dataloader.num_batches = len(dataloader)
|
121 |
+
|
122 |
+
return DataInfo(dataloader, sampler)
|
123 |
+
|
124 |
+
|
125 |
+
def fetch_resolution(vision_model):
|
126 |
+
# fetch the resolution from the vision model config
|
127 |
+
vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{vision_model.replace('/', '-')}.json"
|
128 |
+
with open(vision_model_config_file, 'r') as fv:
|
129 |
+
model_info = json.load(fv)
|
130 |
+
return model_info["image_resolution"]
|
131 |
+
|
132 |
+
|
133 |
+
def get_eval_img_dataset(args):
|
134 |
+
lmdb_imgs = args.image_data
|
135 |
+
dataset = EvalImgDataset(
|
136 |
+
lmdb_imgs, resolution=fetch_resolution(args.vision_model))
|
137 |
+
num_samples = len(dataset)
|
138 |
+
sampler = SequentialSampler(dataset)
|
139 |
+
|
140 |
+
dataloader = DataLoader(
|
141 |
+
dataset,
|
142 |
+
batch_size=args.img_batch_size,
|
143 |
+
num_workers=0,
|
144 |
+
pin_memory=True,
|
145 |
+
sampler=sampler,
|
146 |
+
drop_last=False,
|
147 |
+
)
|
148 |
+
dataloader.num_samples = num_samples
|
149 |
+
dataloader.num_batches = len(dataloader)
|
150 |
+
|
151 |
+
return DataInfo(dataloader, sampler)
|
152 |
+
|
153 |
+
|
154 |
+
def get_zeroshot_dataset(args, preprocess_fn):
|
155 |
+
dataset = datasets.ImageFolder(args.datapath, transform=preprocess_fn)
|
156 |
+
|
157 |
+
dataloader = torch.utils.data.DataLoader(
|
158 |
+
dataset,
|
159 |
+
batch_size=args.img_batch_size,
|
160 |
+
num_workers=args.num_workers,
|
161 |
+
sampler=None,
|
162 |
+
)
|
163 |
+
|
164 |
+
return DataInfo(dataloader, None)
|
eval/evaluation.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This script computes the recall scores given the ground-truth annotations and predictions.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import json
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import string
|
10 |
+
import numpy as np
|
11 |
+
import time
|
12 |
+
|
13 |
+
NUM_K = 10
|
14 |
+
|
15 |
+
def read_submission(submit_path, reference, k=5):
|
16 |
+
# check whether the path of submitted file exists
|
17 |
+
if not os.path.exists(submit_path):
|
18 |
+
raise Exception("The submission file is not found!")
|
19 |
+
|
20 |
+
submission_dict = {}
|
21 |
+
ref_qids = set(reference.keys())
|
22 |
+
|
23 |
+
with open(submit_path, encoding="utf-8") as fin:
|
24 |
+
for line in fin:
|
25 |
+
line = line.strip()
|
26 |
+
try:
|
27 |
+
pred_obj = json.loads(line)
|
28 |
+
except:
|
29 |
+
raise Exception('Cannot parse this line into json object: {}'.format(line))
|
30 |
+
if "text_id" not in pred_obj:
|
31 |
+
raise Exception('There exists one line not containing text_id: {}'.format(line))
|
32 |
+
if not isinstance(pred_obj['text_id'], int):
|
33 |
+
raise Exception('Found an invalid text_id {}, it should be an integer (not string), please check your schema'.format(qid))
|
34 |
+
qid = pred_obj["text_id"]
|
35 |
+
if "image_ids" not in pred_obj:
|
36 |
+
raise Exception('There exists one line not containing the predicted image_ids: {}'.format(line))
|
37 |
+
image_ids = pred_obj["image_ids"]
|
38 |
+
if not isinstance(image_ids, list):
|
39 |
+
raise Exception('The image_ids field of text_id {} is not a list, please check your schema'.format(qid))
|
40 |
+
# check whether there are K products for each text
|
41 |
+
if len(image_ids) != k:
|
42 |
+
raise Exception('Text_id {} has wrong number of predicted image_ids! Require {}, but {} founded.'.format(qid, k, len(image_ids)))
|
43 |
+
# check whether there exist an invalid prediction for any text
|
44 |
+
for rank, image_id in enumerate(image_ids):
|
45 |
+
if not isinstance(image_id, int):
|
46 |
+
raise Exception('Text_id {} has an invalid predicted image_id {} at rank {}, it should be an integer (not string), please check your schema'.format(qid, image_id, rank + 1))
|
47 |
+
# check whether there are duplicate predicted products for a single text
|
48 |
+
if len(set(image_ids)) != k:
|
49 |
+
raise Exception('Text_id {} has duplicate products in your prediction. Pleace check again!'.format(qid))
|
50 |
+
submission_dict[qid] = image_ids # here we save the list of product ids
|
51 |
+
|
52 |
+
# check if any text is missing in the submission
|
53 |
+
pred_qids = set(submission_dict.keys())
|
54 |
+
nopred_qids = ref_qids - pred_qids
|
55 |
+
if len(nopred_qids) != 0:
|
56 |
+
raise Exception('The following text_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_qids])))
|
57 |
+
|
58 |
+
return submission_dict
|
59 |
+
|
60 |
+
|
61 |
+
def dump_2_json(info, path):
|
62 |
+
with open(path, 'w', encoding="utf-8") as output_json_file:
|
63 |
+
json.dump(info, output_json_file)
|
64 |
+
|
65 |
+
|
66 |
+
def report_error_msg(detail, showMsg, out_p):
|
67 |
+
error_dict=dict()
|
68 |
+
error_dict['errorDetail']=detail
|
69 |
+
error_dict['errorMsg']=showMsg
|
70 |
+
error_dict['score']=0
|
71 |
+
error_dict['scoreJson']={}
|
72 |
+
error_dict['success']=False
|
73 |
+
dump_2_json(error_dict,out_p)
|
74 |
+
|
75 |
+
|
76 |
+
def report_score(r1, r5, r10, out_p):
|
77 |
+
result = dict()
|
78 |
+
result['success']=True
|
79 |
+
mean_recall = (r1 + r5 + r10) / 3.0
|
80 |
+
result['score'] = mean_recall * 100
|
81 |
+
result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
|
82 |
+
dump_2_json(result,out_p)
|
83 |
+
|
84 |
+
|
85 |
+
def read_reference(path):
|
86 |
+
fin = open(path, encoding="utf-8")
|
87 |
+
reference = dict()
|
88 |
+
for line in fin:
|
89 |
+
line = line.strip()
|
90 |
+
obj = json.loads(line)
|
91 |
+
reference[obj['text_id']] = obj['image_ids']
|
92 |
+
return reference
|
93 |
+
|
94 |
+
def compute_score(golden_file, predict_file):
|
95 |
+
# read ground-truth
|
96 |
+
reference = read_reference(golden_file)
|
97 |
+
|
98 |
+
# read predictions
|
99 |
+
k = 10
|
100 |
+
predictions = read_submission(predict_file, reference, k)
|
101 |
+
|
102 |
+
# compute score for each text
|
103 |
+
r1_stat, r5_stat, r10_stat = 0, 0, 0
|
104 |
+
for qid in reference.keys():
|
105 |
+
ground_truth_ids = set(reference[qid])
|
106 |
+
top10_pred_ids = predictions[qid]
|
107 |
+
if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
|
108 |
+
r1_stat += 1
|
109 |
+
if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
|
110 |
+
r5_stat += 1
|
111 |
+
if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
|
112 |
+
r10_stat += 1
|
113 |
+
# the higher score, the better
|
114 |
+
r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
|
115 |
+
mean_recall = (r1 + r5 + r10) / 3.0
|
116 |
+
result = [mean_recall, r1, r5, r10]
|
117 |
+
result = [score * 100 for score in result]
|
118 |
+
return result
|
119 |
+
|
120 |
+
|
121 |
+
if __name__=="__main__":
|
122 |
+
# the path of answer json file (eg. test_queries_answers.jsonl)
|
123 |
+
standard_path = sys.argv[1]
|
124 |
+
# the path of prediction file (eg. example_pred.jsonl)
|
125 |
+
submit_path = sys.argv[2]
|
126 |
+
# the score will be dumped into this output json file
|
127 |
+
out_path = sys.argv[3]
|
128 |
+
|
129 |
+
print("Read standard from %s" % standard_path)
|
130 |
+
print("Read user submit file from %s" % submit_path)
|
131 |
+
|
132 |
+
try:
|
133 |
+
# read ground-truth
|
134 |
+
reference = read_reference(standard_path)
|
135 |
+
|
136 |
+
# read predictions
|
137 |
+
k = 10
|
138 |
+
predictions = read_submission(submit_path, reference, k)
|
139 |
+
|
140 |
+
# compute score for each text
|
141 |
+
r1_stat, r5_stat, r10_stat = 0, 0, 0
|
142 |
+
for qid in reference.keys():
|
143 |
+
ground_truth_ids = set(reference[qid])
|
144 |
+
top10_pred_ids = predictions[qid]
|
145 |
+
if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
|
146 |
+
r1_stat += 1
|
147 |
+
if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
|
148 |
+
r5_stat += 1
|
149 |
+
if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
|
150 |
+
r10_stat += 1
|
151 |
+
# the higher score, the better
|
152 |
+
r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
|
153 |
+
report_score(r1, r5, r10, out_path)
|
154 |
+
print("The evaluation finished successfully.")
|
155 |
+
except Exception as e:
|
156 |
+
report_error_msg(e.args[0], e.args[0], out_path)
|
157 |
+
print("The evaluation failed: {}".format(e.args[0]))
|
eval/evaluation_tr.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This script computes the recall scores given the ground-truth annotations and predictions.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import json
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import string
|
10 |
+
import numpy as np
|
11 |
+
import time
|
12 |
+
|
13 |
+
NUM_K = 10
|
14 |
+
|
15 |
+
def read_submission(submit_path, reference, k=5):
|
16 |
+
# check whether the path of submitted file exists
|
17 |
+
if not os.path.exists(submit_path):
|
18 |
+
raise Exception("The submission file is not found!")
|
19 |
+
|
20 |
+
submission_dict = {}
|
21 |
+
ref_image_ids = set(reference.keys())
|
22 |
+
|
23 |
+
with open(submit_path, encoding="utf-8") as fin:
|
24 |
+
for line in fin:
|
25 |
+
line = line.strip()
|
26 |
+
try:
|
27 |
+
pred_obj = json.loads(line)
|
28 |
+
except:
|
29 |
+
raise Exception('Cannot parse this line into json object: {}'.format(line))
|
30 |
+
if "image_id" not in pred_obj:
|
31 |
+
raise Exception('There exists one line not containing image_id: {}'.format(line))
|
32 |
+
if not isinstance(pred_obj['image_id'], int):
|
33 |
+
raise Exception('Found an invalid image_id {}, it should be an integer (not string), please check your schema'.format(pred_obj['image_id']))
|
34 |
+
image_id = pred_obj['image_id']
|
35 |
+
if "text_ids" not in pred_obj:
|
36 |
+
raise Exception('There exists one line not containing the predicted text_ids: {}'.format(line))
|
37 |
+
text_ids = pred_obj["text_ids"]
|
38 |
+
if not isinstance(text_ids, list):
|
39 |
+
raise Exception('The text_ids field of image_id {} is not a list, please check your schema'.format(image_id))
|
40 |
+
# check whether there are K products for each text
|
41 |
+
if len(text_ids) != k:
|
42 |
+
raise Exception('Image_id {} has wrong number of predicted text_ids! Require {}, but {} founded.'.format(image_id, k, len(text_ids)))
|
43 |
+
# check whether there exist an invalid prediction for any text
|
44 |
+
for rank, text_id in enumerate(text_ids):
|
45 |
+
if not isinstance(text_id, int):
|
46 |
+
raise Exception('Image_id {} has an invalid predicted text_id {} at rank {}, it should be an integer (not string), please check your schema'.format(image_id, text_id, rank + 1))
|
47 |
+
# check whether there are duplicate predicted products for a single text
|
48 |
+
if len(set(text_ids)) != k:
|
49 |
+
raise Exception('Image_id {} has duplicate products in your prediction. Pleace check again!'.format(image_id))
|
50 |
+
submission_dict[image_id] = text_ids # here we save the list of product ids
|
51 |
+
|
52 |
+
# check if any text is missing in the submission
|
53 |
+
pred_image_ids = set(submission_dict.keys())
|
54 |
+
nopred_image_ids = ref_image_ids - pred_image_ids
|
55 |
+
if len(nopred_image_ids) != 0:
|
56 |
+
raise Exception('The following image_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_image_ids])))
|
57 |
+
|
58 |
+
return submission_dict
|
59 |
+
|
60 |
+
|
61 |
+
def dump_2_json(info, path):
|
62 |
+
with open(path, 'w', encoding="utf-8") as output_json_file:
|
63 |
+
json.dump(info, output_json_file)
|
64 |
+
|
65 |
+
|
66 |
+
def report_error_msg(detail, showMsg, out_p):
|
67 |
+
error_dict=dict()
|
68 |
+
error_dict['errorDetail']=detail
|
69 |
+
error_dict['errorMsg']=showMsg
|
70 |
+
error_dict['score']=0
|
71 |
+
error_dict['scoreJson']={}
|
72 |
+
error_dict['success']=False
|
73 |
+
dump_2_json(error_dict,out_p)
|
74 |
+
|
75 |
+
|
76 |
+
def report_score(r1, r5, r10, out_p):
|
77 |
+
result = dict()
|
78 |
+
result['success']=True
|
79 |
+
mean_recall = (r1 + r5 + r10) / 3.0
|
80 |
+
result['score'] = mean_recall * 100
|
81 |
+
result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
|
82 |
+
dump_2_json(result,out_p)
|
83 |
+
|
84 |
+
|
85 |
+
def read_reference(path):
|
86 |
+
fin = open(path, encoding="utf-8")
|
87 |
+
reference = dict()
|
88 |
+
for line in fin:
|
89 |
+
line = line.strip()
|
90 |
+
obj = json.loads(line)
|
91 |
+
reference[obj['image_id']] = obj['text_ids']
|
92 |
+
return reference
|
93 |
+
|
94 |
+
def compute_score(golden_file, predict_file):
|
95 |
+
# read ground-truth
|
96 |
+
reference = read_reference(golden_file)
|
97 |
+
|
98 |
+
# read predictions
|
99 |
+
k = 10
|
100 |
+
predictions = read_submission(predict_file, reference, k)
|
101 |
+
|
102 |
+
# compute score for each text
|
103 |
+
r1_stat, r5_stat, r10_stat = 0, 0, 0
|
104 |
+
for qid in reference.keys():
|
105 |
+
ground_truth_ids = set(reference[qid])
|
106 |
+
top10_pred_ids = predictions[qid]
|
107 |
+
if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
|
108 |
+
r1_stat += 1
|
109 |
+
if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
|
110 |
+
r5_stat += 1
|
111 |
+
if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
|
112 |
+
r10_stat += 1
|
113 |
+
# the higher score, the better
|
114 |
+
r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
|
115 |
+
mean_recall = (r1 + r5 + r10) / 3.0
|
116 |
+
result = [mean_recall, r1, r5, r10]
|
117 |
+
result = [score * 100 for score in result]
|
118 |
+
return result
|
119 |
+
|
120 |
+
|
121 |
+
if __name__=="__main__":
|
122 |
+
# the path of answer json file (eg. test_queries_answers.jsonl)
|
123 |
+
standard_path = sys.argv[1]
|
124 |
+
# the path of prediction file (eg. example_pred.jsonl)
|
125 |
+
submit_path = sys.argv[2]
|
126 |
+
# the score will be dumped into this output json file
|
127 |
+
out_path = sys.argv[3]
|
128 |
+
|
129 |
+
print("Read standard from %s" % standard_path)
|
130 |
+
print("Read user submit file from %s" % submit_path)
|
131 |
+
|
132 |
+
try:
|
133 |
+
# read ground-truth
|
134 |
+
reference = read_reference(standard_path)
|
135 |
+
|
136 |
+
# read predictions
|
137 |
+
k = 10
|
138 |
+
predictions = read_submission(submit_path, reference, k)
|
139 |
+
|
140 |
+
# compute score for each text
|
141 |
+
r1_stat, r5_stat, r10_stat = 0, 0, 0
|
142 |
+
for qid in reference.keys():
|
143 |
+
ground_truth_ids = set(reference[qid])
|
144 |
+
top10_pred_ids = predictions[qid]
|
145 |
+
if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
|
146 |
+
r1_stat += 1
|
147 |
+
if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
|
148 |
+
r5_stat += 1
|
149 |
+
if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
|
150 |
+
r10_stat += 1
|
151 |
+
# the higher score, the better
|
152 |
+
r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
|
153 |
+
report_score(r1, r5, r10, out_path)
|
154 |
+
print("The evaluation finished successfully.")
|
155 |
+
except Exception as e:
|
156 |
+
report_error_msg(e.args[0], e.args[0], out_path)
|
157 |
+
print("The evaluation failed: {}".format(e.args[0]))
|
eval/extract_features.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This script extracts image and text features for evaluation. (with single-GPU)
|
4 |
+
'''
|
5 |
+
|
6 |
+
import os
|
7 |
+
import argparse
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
import json
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from clip.model import convert_weights, CLIP
|
16 |
+
from eval.data import get_eval_img_dataset, get_eval_txt_dataset
|
17 |
+
|
18 |
+
def parse_args():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument(
|
21 |
+
'--extract-image-feats',
|
22 |
+
action="store_true",
|
23 |
+
default=False,
|
24 |
+
help="Whether to extract image features."
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
'--extract-text-feats',
|
28 |
+
action="store_true",
|
29 |
+
default=False,
|
30 |
+
help="Whether to extract text features."
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
'--image-data',
|
34 |
+
type=str,
|
35 |
+
default="../Multimodal_Retrieval/lmdb/test/imgs",
|
36 |
+
help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings."
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
'--text-data',
|
40 |
+
type=str,
|
41 |
+
default="../Multimodal_Retrieval/test_texts.jsonl",
|
42 |
+
help="If --extract-text-feats is True, specify the path of input text Jsonl file."
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
'--image-feat-output-path',
|
46 |
+
type=str,
|
47 |
+
default=None,
|
48 |
+
help="If --extract-image-feats is True, specify the path of output image features."
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
'--text-feat-output-path',
|
52 |
+
type=str,
|
53 |
+
default=None,
|
54 |
+
help="If --extract-image-feats is True, specify the path of output text features."
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--img-batch-size", type=int, default=64, help="Image batch size."
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--text-batch-size", type=int, default=64, help="Text batch size."
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--context-length", type=int, default=64, help="The maximum length of input text (include [CLS] & [SEP] tokens)."
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--resume",
|
67 |
+
default=None,
|
68 |
+
type=str,
|
69 |
+
help="path to latest checkpoint (default: none)",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--precision",
|
73 |
+
choices=["amp", "fp16", "fp32"],
|
74 |
+
default="amp",
|
75 |
+
help="Floating point precition."
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--vision-model",
|
79 |
+
choices=["ViT-B-16", "ViT-L-14", "RN50"],
|
80 |
+
default="ViT-B-16",
|
81 |
+
help="Name of the vision backbone to use.",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--text-model",
|
85 |
+
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
|
86 |
+
default="RoBERTa-wwm-ext-base-chinese",
|
87 |
+
help="Name of the text backbone to use.",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"--debug",
|
91 |
+
default=False,
|
92 |
+
action="store_true",
|
93 |
+
help="If true, more information is logged."
|
94 |
+
)
|
95 |
+
args = parser.parse_args()
|
96 |
+
|
97 |
+
return args
|
98 |
+
|
99 |
+
# Used by https://github.com/openai/CLIP/issues/83 but not below.
|
100 |
+
# Keeping it incase needed.
|
101 |
+
def convert_models_to_fp32(model):
|
102 |
+
for p in model.parameters():
|
103 |
+
p.data = p.data.float()
|
104 |
+
if p.grad:
|
105 |
+
p.grad.data = p.grad.data.float()
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
args = parse_args()
|
110 |
+
|
111 |
+
assert args.extract_image_feats or args.extract_text_feats, "--extract-image-feats and --extract-text-feats cannot both be False!"
|
112 |
+
|
113 |
+
# Log params.
|
114 |
+
print("Params:")
|
115 |
+
for name in sorted(vars(args)):
|
116 |
+
val = getattr(args, name)
|
117 |
+
print(f" {name}: {val}")
|
118 |
+
|
119 |
+
args.gpu = 0
|
120 |
+
torch.cuda.set_device(args.gpu)
|
121 |
+
|
122 |
+
# Initialize the model.
|
123 |
+
vision_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
|
124 |
+
print('Loading vision model config from', vision_model_config_file)
|
125 |
+
assert os.path.exists(vision_model_config_file)
|
126 |
+
|
127 |
+
text_model_config_file = Path(__file__).parent.parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
|
128 |
+
print('Loading text model config from', text_model_config_file)
|
129 |
+
assert os.path.exists(text_model_config_file)
|
130 |
+
|
131 |
+
with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
|
132 |
+
model_info = json.load(fv)
|
133 |
+
if isinstance(model_info['vision_layers'], str):
|
134 |
+
model_info['vision_layers'] = eval(model_info['vision_layers'])
|
135 |
+
for k, v in json.load(ft).items():
|
136 |
+
model_info[k] = v
|
137 |
+
|
138 |
+
model = CLIP(**model_info)
|
139 |
+
convert_weights(model)
|
140 |
+
|
141 |
+
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
|
142 |
+
if args.precision == "amp" or args.precision == "fp32":
|
143 |
+
convert_models_to_fp32(model)
|
144 |
+
model.cuda(args.gpu)
|
145 |
+
if args.precision == "fp16":
|
146 |
+
convert_weights(model)
|
147 |
+
|
148 |
+
# Get data.
|
149 |
+
if args.extract_image_feats:
|
150 |
+
print("Preparing image inference dataset.")
|
151 |
+
img_data = get_eval_img_dataset(args)
|
152 |
+
if args.extract_text_feats:
|
153 |
+
print("Preparing text inference dataset.")
|
154 |
+
text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length)
|
155 |
+
|
156 |
+
# Resume from a checkpoint.
|
157 |
+
print("Begin to load model checkpoint from {}.".format(args.resume))
|
158 |
+
assert os.path.exists(args.resume), "The checkpoint file {} not exists!".format(args.resume)
|
159 |
+
# Map model to be loaded to specified single gpu.
|
160 |
+
loc = "cuda:{}".format(args.gpu)
|
161 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
162 |
+
start_epoch = checkpoint["epoch"]
|
163 |
+
sd = checkpoint["state_dict"]
|
164 |
+
if next(iter(sd.items()))[0].startswith('module'):
|
165 |
+
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
|
166 |
+
model.load_state_dict(sd)
|
167 |
+
print(
|
168 |
+
f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
|
169 |
+
)
|
170 |
+
|
171 |
+
# Make inference for texts
|
172 |
+
if args.extract_text_feats:
|
173 |
+
print('Make inference for texts...')
|
174 |
+
if args.text_feat_output_path is None:
|
175 |
+
args.text_feat_output_path = "{}.txt_feat.jsonl".format(args.text_data[:-6])
|
176 |
+
write_cnt = 0
|
177 |
+
with open(args.text_feat_output_path, "w") as fout:
|
178 |
+
model.eval()
|
179 |
+
dataloader = text_data.dataloader
|
180 |
+
with torch.no_grad():
|
181 |
+
for batch in tqdm(dataloader):
|
182 |
+
text_ids, texts = batch
|
183 |
+
texts = texts.cuda(args.gpu, non_blocking=True)
|
184 |
+
text_features = model(None, texts)
|
185 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
186 |
+
for text_id, text_feature in zip(text_ids.tolist(), text_features.tolist()):
|
187 |
+
fout.write("{}\n".format(json.dumps({"text_id": text_id, "feature": text_feature})))
|
188 |
+
write_cnt += 1
|
189 |
+
print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path))
|
190 |
+
|
191 |
+
# Make inference for images
|
192 |
+
if args.extract_image_feats:
|
193 |
+
print('Make inference for images...')
|
194 |
+
if args.image_feat_output_path is None:
|
195 |
+
# by default, we store the image features under the same directory with the text features
|
196 |
+
args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs"))
|
197 |
+
write_cnt = 0
|
198 |
+
with open(args.image_feat_output_path, "w") as fout:
|
199 |
+
model.eval()
|
200 |
+
dataloader = img_data.dataloader
|
201 |
+
with torch.no_grad():
|
202 |
+
for batch in tqdm(dataloader):
|
203 |
+
image_ids, images = batch
|
204 |
+
images = images.cuda(args.gpu, non_blocking=True)
|
205 |
+
image_features = model(images, None)
|
206 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
207 |
+
for image_id, image_feature in zip(image_ids.tolist(), image_features.tolist()):
|
208 |
+
fout.write("{}\n".format(json.dumps({"image_id": image_id, "feature": image_feature})))
|
209 |
+
write_cnt += 1
|
210 |
+
print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path))
|
211 |
+
|
212 |
+
print("Done!")
|
eval/make_topk_predictions.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This scripts performs kNN search on inferenced image and text features (on single-GPU) and outputs text-to-image prediction file for evaluation.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import numpy
|
8 |
+
from tqdm import tqdm
|
9 |
+
import json
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument(
|
17 |
+
'--image-feats',
|
18 |
+
type=str,
|
19 |
+
required=True,
|
20 |
+
help="Specify the path of image features."
|
21 |
+
)
|
22 |
+
parser.add_argument(
|
23 |
+
'--text-feats',
|
24 |
+
type=str,
|
25 |
+
required=True,
|
26 |
+
help="Specify the path of text features."
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
'--top-k',
|
30 |
+
type=int,
|
31 |
+
default=10,
|
32 |
+
help="Specify the k value of top-k predictions."
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
'--eval-batch-size',
|
36 |
+
type=int,
|
37 |
+
default=32768,
|
38 |
+
help="Specify the image-side batch size when computing the inner products, default to 8192"
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
'--output',
|
42 |
+
type=str,
|
43 |
+
required=True,
|
44 |
+
help="Specify the output jsonl prediction filepath."
|
45 |
+
)
|
46 |
+
return parser.parse_args()
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
args = parse_args()
|
50 |
+
|
51 |
+
# Log params.
|
52 |
+
print("Params:")
|
53 |
+
for name in sorted(vars(args)):
|
54 |
+
val = getattr(args, name)
|
55 |
+
print(f" {name}: {val}")
|
56 |
+
|
57 |
+
print("Begin to load image features...")
|
58 |
+
image_ids = []
|
59 |
+
image_feats = []
|
60 |
+
with open(args.image_feats, "r") as fin:
|
61 |
+
for line in tqdm(fin):
|
62 |
+
obj = json.loads(line.strip())
|
63 |
+
image_ids.append(obj['image_id'])
|
64 |
+
image_feats.append(obj['feature'])
|
65 |
+
image_feats_array = np.array(image_feats, dtype=np.float32)
|
66 |
+
print("Finished loading image features.")
|
67 |
+
|
68 |
+
print("Begin to compute top-{} predictions for texts...".format(args.top_k))
|
69 |
+
with open(args.output, "w") as fout:
|
70 |
+
with open(args.text_feats, "r") as fin:
|
71 |
+
for line in tqdm(fin):
|
72 |
+
obj = json.loads(line.strip())
|
73 |
+
text_id = obj['text_id']
|
74 |
+
text_feat = obj['feature']
|
75 |
+
score_tuples = []
|
76 |
+
text_feat_tensor = torch.tensor([text_feat], dtype=torch.float).cuda() # [1, feature_dim]
|
77 |
+
idx = 0
|
78 |
+
while idx < len(image_ids):
|
79 |
+
img_feats_tensor = torch.from_numpy(image_feats_array[idx : min(idx + args.eval_batch_size, len(image_ids))]).cuda() # [batch_size, feature_dim]
|
80 |
+
batch_scores = text_feat_tensor @ img_feats_tensor.t() # [1, batch_size]
|
81 |
+
for image_id, score in zip(image_ids[idx : min(idx + args.eval_batch_size, len(image_ids))], batch_scores.squeeze(0).tolist()):
|
82 |
+
score_tuples.append((image_id, score))
|
83 |
+
idx += args.eval_batch_size
|
84 |
+
top_k_predictions = sorted(score_tuples, key=lambda x:x[1], reverse=True)[:args.top_k]
|
85 |
+
fout.write("{}\n".format(json.dumps({"text_id": text_id, "image_ids": [entry[0] for entry in top_k_predictions]})))
|
86 |
+
|
87 |
+
print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
|
88 |
+
print("Done!")
|
eval/make_topk_predictions_tr.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This scripts performs kNN search on inferenced image and text features (on single-GPU) and outputs image-to-text retrieval prediction file for evaluation.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import numpy
|
8 |
+
from tqdm import tqdm
|
9 |
+
import json
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument(
|
17 |
+
'--image-feats',
|
18 |
+
type=str,
|
19 |
+
required=True,
|
20 |
+
help="Specify the path of image features."
|
21 |
+
)
|
22 |
+
parser.add_argument(
|
23 |
+
'--text-feats',
|
24 |
+
type=str,
|
25 |
+
required=True,
|
26 |
+
help="Specify the path of text features."
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
'--top-k',
|
30 |
+
type=int,
|
31 |
+
default=10,
|
32 |
+
help="Specify the k value of top-k predictions."
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
'--eval-batch-size',
|
36 |
+
type=int,
|
37 |
+
default=32768,
|
38 |
+
help="Specify the image-side batch size when computing the inner products, default to 8192"
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
'--output',
|
42 |
+
type=str,
|
43 |
+
required=True,
|
44 |
+
help="Specify the output jsonl prediction filepath."
|
45 |
+
)
|
46 |
+
return parser.parse_args()
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
args = parse_args()
|
50 |
+
|
51 |
+
# Log params.
|
52 |
+
print("Params:")
|
53 |
+
for name in sorted(vars(args)):
|
54 |
+
val = getattr(args, name)
|
55 |
+
print(f" {name}: {val}")
|
56 |
+
|
57 |
+
print("Begin to load text features...")
|
58 |
+
text_ids = []
|
59 |
+
text_feats = []
|
60 |
+
with open(args.text_feats, "r") as fin:
|
61 |
+
for line in tqdm(fin):
|
62 |
+
obj = json.loads(line.strip())
|
63 |
+
text_ids.append(obj['text_id'])
|
64 |
+
text_feats.append(obj['feature'])
|
65 |
+
text_feats_array = np.array(text_feats, dtype=np.float32)
|
66 |
+
print("Finished loading text features.")
|
67 |
+
|
68 |
+
print("Begin to compute top-{} predictions for images...".format(args.top_k))
|
69 |
+
with open(args.output, "w") as fout:
|
70 |
+
with open(args.image_feats, "r") as fin:
|
71 |
+
for line in tqdm(fin):
|
72 |
+
obj = json.loads(line.strip())
|
73 |
+
image_id = obj['image_id']
|
74 |
+
image_feat = obj['feature']
|
75 |
+
score_tuples = []
|
76 |
+
image_feat_tensor = torch.tensor([image_feat], dtype=torch.float).cuda() # [1, feature_dim]
|
77 |
+
idx = 0
|
78 |
+
while idx < len(text_ids):
|
79 |
+
text_feats_tensor = torch.from_numpy(text_feats_array[idx : min(idx + args.eval_batch_size, len(text_ids))]).cuda() # [batch_size, feature_dim]
|
80 |
+
batch_scores = image_feat_tensor @ text_feats_tensor.t() # [1, batch_size]
|
81 |
+
for text_id, score in zip(text_ids[idx : min(idx + args.eval_batch_size, len(text_ids))], batch_scores.squeeze(0).tolist()):
|
82 |
+
score_tuples.append((text_id, score))
|
83 |
+
idx += args.eval_batch_size
|
84 |
+
top_k_predictions = sorted(score_tuples, key=lambda x:x[1], reverse=True)[:args.top_k]
|
85 |
+
fout.write("{}\n".format(json.dumps({"image_id": image_id, "text_ids": [entry[0] for entry in top_k_predictions]})))
|
86 |
+
|
87 |
+
print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
|
88 |
+
print("Done!")
|
eval/transform_ir_annotation_to_tr.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from tqdm import tqdm
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
|
6 |
+
def parse_args():
|
7 |
+
parser = argparse.ArgumentParser()
|
8 |
+
parser.add_argument(
|
9 |
+
'--input',
|
10 |
+
type=str,
|
11 |
+
required=True,
|
12 |
+
help="Input path of text-to-image Jsonl annotation file."
|
13 |
+
)
|
14 |
+
return parser.parse_args()
|
15 |
+
|
16 |
+
if __name__ == "__main__":
|
17 |
+
args = parse_args()
|
18 |
+
|
19 |
+
t2i_record = dict()
|
20 |
+
|
21 |
+
with open(args.input, "r", encoding="utf-8") as fin:
|
22 |
+
for line in tqdm(fin):
|
23 |
+
obj = json.loads(line.strip())
|
24 |
+
text_id = obj['text_id']
|
25 |
+
image_ids = obj['image_ids']
|
26 |
+
for image_id in image_ids:
|
27 |
+
if image_id not in t2i_record:
|
28 |
+
t2i_record[image_id] = []
|
29 |
+
t2i_record[image_id].append(text_id)
|
30 |
+
|
31 |
+
with open(args.input.replace(".jsonl", "") + ".tr.jsonl", "w", encoding="utf-8") as fout:
|
32 |
+
for image_id, text_ids in t2i_record.items():
|
33 |
+
out_obj = {"image_id": image_id, "text_ids": text_ids}
|
34 |
+
fout.write("{}\n".format(json.dumps(out_obj)))
|
35 |
+
|
36 |
+
print("Done!")
|
eval/zeroshot_evaluation.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This script performs zero-shot evaluation on ImageNet-1K. (with single-GPU)
|
4 |
+
'''
|
5 |
+
|
6 |
+
import os
|
7 |
+
import argparse
|
8 |
+
from pathlib import Path
|
9 |
+
import json
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from clip.model import convert_weights, CLIP
|
15 |
+
from clip import tokenize
|
16 |
+
from clip.utils import image_transform
|
17 |
+
from eval.data import get_zeroshot_dataset, _preprocess_text
|
18 |
+
from eval.cvinw_zeroshot_templates import (
|
19 |
+
openai_templates,
|
20 |
+
flower_templates,
|
21 |
+
food_templates,
|
22 |
+
aircraft_templates,
|
23 |
+
eurosat_templates,
|
24 |
+
country211_templates,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
def parse_args():
|
29 |
+
parser = argparse.ArgumentParser()
|
30 |
+
parser.add_argument(
|
31 |
+
"--vision-model",
|
32 |
+
choices=["ViT-B-16", "ViT-L-14", "RN50"],
|
33 |
+
default="ViT-B-16",
|
34 |
+
help="Name of the vision backbone to use.",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--text-model",
|
38 |
+
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
|
39 |
+
default="RoBERTa-wwm-ext-base-chinese",
|
40 |
+
help="Name of the text backbone to use.",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--precision",
|
44 |
+
choices=["amp", "fp16", "fp32"],
|
45 |
+
default="amp",
|
46 |
+
help="Floating point precition."
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--label-file",
|
50 |
+
type=str,
|
51 |
+
help="file for labels",
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--datapath",
|
55 |
+
type=str,
|
56 |
+
required=True,
|
57 |
+
help="Path to the test set for conducting zero shot evaluation.",
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--dataset",
|
61 |
+
type=str,
|
62 |
+
default="imagenet",
|
63 |
+
help="Specified dataset.",
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--index",
|
67 |
+
type=str,
|
68 |
+
default="",
|
69 |
+
help="Specify image paths.",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--save-dir",
|
73 |
+
type=str,
|
74 |
+
default="",
|
75 |
+
help="Specified dataset.",
|
76 |
+
)
|
77 |
+
# parser.add_argument(
|
78 |
+
# "--imagenet-val",
|
79 |
+
# type=str,
|
80 |
+
# required=True,
|
81 |
+
# help="Path to imagenet val set for conducting zero shot evaluation.",
|
82 |
+
# )
|
83 |
+
parser.add_argument(
|
84 |
+
"--img-batch-size", type=int, default=64, help="Image batch size."
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--context-length",
|
88 |
+
type=int,
|
89 |
+
default=52,
|
90 |
+
help="The maximum length of input text (include [CLS] & [SEP] tokens)."
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--resume",
|
94 |
+
default=None,
|
95 |
+
type=str,
|
96 |
+
help="path to latest checkpoint (default: none)",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--num-workers", type=int, default=4, help="Number of workers for ImageNet dataloader."
|
100 |
+
)
|
101 |
+
args = parser.parse_args()
|
102 |
+
|
103 |
+
return args
|
104 |
+
|
105 |
+
# Used by https://github.com/openai/CLIP/issues/83 but not below.
|
106 |
+
# Keeping it incase needed.
|
107 |
+
def convert_models_to_fp32(model):
|
108 |
+
for p in model.parameters():
|
109 |
+
p.data = p.data.float()
|
110 |
+
if p.grad:
|
111 |
+
p.grad.data = p.grad.data.float()
|
112 |
+
|
113 |
+
|
114 |
+
def zero_shot_classifier(model, classnames, templates, args):
|
115 |
+
with torch.no_grad():
|
116 |
+
zeroshot_weights = []
|
117 |
+
for classname in tqdm(classnames):
|
118 |
+
texts = [_preprocess_text(template(classname)) for template in templates] # format with class
|
119 |
+
texts = tokenize(texts, context_length=args.context_length).to(args.gpu) # tokenize
|
120 |
+
class_embeddings = model(None, texts)
|
121 |
+
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
122 |
+
class_embedding = class_embeddings.mean(dim=0)
|
123 |
+
class_embedding /= class_embedding.norm()
|
124 |
+
zeroshot_weights.append(class_embedding)
|
125 |
+
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.gpu)
|
126 |
+
return zeroshot_weights
|
127 |
+
|
128 |
+
|
129 |
+
def accuracy(output, target, topk=(1,)):
|
130 |
+
pred = output.topk(max(topk), 1, True, True)[1].t()
|
131 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
132 |
+
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
|
133 |
+
|
134 |
+
|
135 |
+
def run(model, classifier, dataloader, args):
|
136 |
+
total_logits = []
|
137 |
+
total_targets = []
|
138 |
+
with torch.no_grad():
|
139 |
+
top1, top5, n = 0.0, 0.0, 0.0
|
140 |
+
for images, target in tqdm(dataloader):
|
141 |
+
images = images.to(args.gpu)
|
142 |
+
target = target.to(args.gpu)
|
143 |
+
total_targets.append(target)
|
144 |
+
|
145 |
+
# predict
|
146 |
+
image_features = model(images, None)
|
147 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
148 |
+
logits = (100.0 * image_features @ classifier).softmax(dim=-1)
|
149 |
+
total_logits.append(logits)
|
150 |
+
|
151 |
+
# measure accuracy
|
152 |
+
acc1, acc5 = accuracy(logits, target, topk=(1, 1))
|
153 |
+
top1 += acc1
|
154 |
+
n += images.size(0)
|
155 |
+
|
156 |
+
outputs = torch.cat(total_logits, dim=0)
|
157 |
+
targets = torch.cat(total_targets, dim=0)
|
158 |
+
|
159 |
+
if getattr(args, "index", ""):
|
160 |
+
print("Use index to rearrange the logits...")
|
161 |
+
with open(args.index, "r", encoding="utf-8") as f:
|
162 |
+
index = json.load(f)
|
163 |
+
print(index)
|
164 |
+
outputs = outputs[index]
|
165 |
+
targets = targets[index]
|
166 |
+
print(targets)
|
167 |
+
|
168 |
+
top1 = top1 / n
|
169 |
+
|
170 |
+
return top1, outputs
|
171 |
+
|
172 |
+
|
173 |
+
if __name__ == "__main__":
|
174 |
+
args = parse_args()
|
175 |
+
|
176 |
+
# Log params.
|
177 |
+
print("Params:")
|
178 |
+
for name in sorted(vars(args)):
|
179 |
+
val = getattr(args, name)
|
180 |
+
print(f" {name}: {val}")
|
181 |
+
|
182 |
+
args.gpu = 0
|
183 |
+
torch.cuda.set_device(args.gpu)
|
184 |
+
|
185 |
+
# Initialize the model.
|
186 |
+
vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
|
187 |
+
print('Loading vision model config from', vision_model_config_file)
|
188 |
+
assert os.path.exists(vision_model_config_file)
|
189 |
+
|
190 |
+
text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
|
191 |
+
print('Loading text model config from', text_model_config_file)
|
192 |
+
assert os.path.exists(text_model_config_file)
|
193 |
+
|
194 |
+
with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
|
195 |
+
model_info = json.load(fv)
|
196 |
+
if isinstance(model_info['vision_layers'], str):
|
197 |
+
model_info['vision_layers'] = eval(model_info['vision_layers'])
|
198 |
+
for k, v in json.load(ft).items():
|
199 |
+
model_info[k] = v
|
200 |
+
|
201 |
+
model = CLIP(**model_info)
|
202 |
+
convert_weights(model)
|
203 |
+
|
204 |
+
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
|
205 |
+
if args.precision == "amp" or args.precision == "fp32":
|
206 |
+
convert_models_to_fp32(model)
|
207 |
+
model.cuda(args.gpu)
|
208 |
+
if args.precision == "fp16":
|
209 |
+
convert_weights(model)
|
210 |
+
|
211 |
+
# Get eval data.
|
212 |
+
print("Preparing zeroshot dataset.")
|
213 |
+
data = {}
|
214 |
+
print(f"{model_info['image_resolution']}")
|
215 |
+
data[args.dataset] = get_zeroshot_dataset(
|
216 |
+
args, image_transform(model_info["image_resolution"])
|
217 |
+
)
|
218 |
+
|
219 |
+
# Resume from a checkpoint.
|
220 |
+
print("Begin to load model checkpoint from {}.".format(args.resume))
|
221 |
+
assert os.path.exists(args.resume), "The checkpoint file {} not exists!".format(args.resume)
|
222 |
+
# Map model to be loaded to specified single gpu.
|
223 |
+
loc = "cuda:{}".format(args.gpu)
|
224 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
225 |
+
start_epoch = checkpoint["epoch"]
|
226 |
+
sd = checkpoint["state_dict"]
|
227 |
+
if next(iter(sd.items()))[0].startswith('module'):
|
228 |
+
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
|
229 |
+
model.load_state_dict(sd, strict=False)
|
230 |
+
print(
|
231 |
+
f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
|
232 |
+
)
|
233 |
+
|
234 |
+
# Compute ensembled class embeddings
|
235 |
+
print('Building zero-shot classifier')
|
236 |
+
|
237 |
+
model.eval()
|
238 |
+
|
239 |
+
f = open(args.label_file, "r", encoding="utf8")
|
240 |
+
classnames = [line.strip() for line in f.readlines()]
|
241 |
+
|
242 |
+
template_dict = {
|
243 |
+
"fgvc-aircraft-2013b-variants102": aircraft_templates,
|
244 |
+
"food-101": food_templates,
|
245 |
+
"oxford-flower-102": flower_templates,
|
246 |
+
"eurosat_clip": eurosat_templates,
|
247 |
+
"resisc45_clip": eurosat_templates,
|
248 |
+
"country211": country211_templates,
|
249 |
+
"openai": openai_templates,
|
250 |
+
}
|
251 |
+
if args.dataset in template_dict.keys():
|
252 |
+
templates = template_dict[args.dataset]
|
253 |
+
else:
|
254 |
+
templates = template_dict['openai']
|
255 |
+
|
256 |
+
# Make inference and evaluation
|
257 |
+
print('Using classifier')
|
258 |
+
classifier = zero_shot_classifier(model, classnames, templates, args)
|
259 |
+
results = {}
|
260 |
+
top1, logits = run(model, classifier, data[args.dataset].dataloader, args)
|
261 |
+
|
262 |
+
|
263 |
+
results["zeroshot-top1"] = top1
|
264 |
+
|
265 |
+
print('Result:')
|
266 |
+
print(", ".join(["{}: {}".format(k, v) for k, v in results.items()]))
|
267 |
+
print('Finished.')
|
examples/pokemon.jpeg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
tqdm
|
3 |
+
six
|
4 |
+
timm
|
5 |
+
lmdb==1.3.0
|
6 |
+
torch>=1.7.1
|
7 |
+
torchvision
|
8 |
+
webdataset
|
9 |
+
pandas
|
10 |
+
transformers
|
scripts/zeroshot_eval.sh
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Usage: see example script below.
|
4 |
+
# bash run_scripts/zeroshot_eval.sh 0 \
|
5 |
+
# ${path_to_dataset} ${dataset_name} \
|
6 |
+
# ViT-B-16 RoBERTa-wwm-ext-base-chinese \
|
7 |
+
# ${ckpt_path}
|
8 |
+
|
9 |
+
# only supports single-GPU inference
|
10 |
+
export CUDA_VISIBLE_DEVICES=${1}
|
11 |
+
export PYTHONPATH=${PYTHONPATH}:`pwd`/QA-CLIP-main
|
12 |
+
|
13 |
+
path=${2}
|
14 |
+
dataset=${3}
|
15 |
+
datapath=${path}
|
16 |
+
savedir=`pwd`/save_predictions
|
17 |
+
vision_model=${4} # ViT-B-16
|
18 |
+
text_model=${5}
|
19 |
+
resume=${6}
|
20 |
+
label_file=`pwd`/label_cn.txt
|
21 |
+
index=${7:-}
|
22 |
+
|
23 |
+
mkdir -p ${savedir}
|
24 |
+
|
25 |
+
python -u eval/zeroshot_evaluation.py \
|
26 |
+
--datapath="${datapath}" \
|
27 |
+
--label-file=${label_file} \
|
28 |
+
--save-dir=${savedir} \
|
29 |
+
--dataset=${dataset} \
|
30 |
+
--index=${index} \
|
31 |
+
--img-batch-size=64 \
|
32 |
+
--resume=${resume} \
|
33 |
+
--vision-model=${vision_model} \
|
34 |
+
--text-model=${text_model}
|