Spaces:
Build error
Build error
AK391
commited on
Commit
·
794924b
1
Parent(s):
bbbc401
files
Browse files- LICENSE.txt +12 -0
- configs/caption_coco.yaml +33 -0
- configs/med_config.json +21 -0
- configs/nlvr.yaml +21 -0
- configs/nocaps.yaml +15 -0
- configs/pretrain.yaml +27 -0
- configs/retrieval_coco.yaml +34 -0
- configs/retrieval_flickr.yaml +34 -0
- configs/vqa.yaml +25 -0
- data/__init__.py +101 -0
- data/coco_karpathy_dataset.py +126 -0
- data/flickr30k_dataset.py +93 -0
- data/nlvr_dataset.py +78 -0
- data/nocaps_dataset.py +32 -0
- data/pretrain_dataset.py +59 -0
- data/utils.py +112 -0
- data/vqa_dataset.py +88 -0
- eval_nocaps.py +118 -0
- models/__init__.py +0 -0
- models/blip.py +238 -0
- models/blip_nlvr.py +103 -0
- models/blip_pretrain.py +339 -0
- models/blip_retrieval.py +322 -0
- models/blip_vqa.py +186 -0
- models/med.py +955 -0
- models/nlvr_encoder.py +843 -0
- models/vit.py +305 -0
- pretrain.py +173 -0
- train_caption.py +206 -0
- train_nlvr.py +213 -0
- train_retrieval.py +345 -0
- train_vqa.py +202 -0
- transform/randaugment.py +340 -0
- utils.py +278 -0
LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2022, Salesforce.com, Inc.
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
5 |
+
|
6 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
7 |
+
|
8 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
9 |
+
|
10 |
+
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
11 |
+
|
12 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
configs/caption_coco.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/coco/images/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
coco_gt_root: 'annotation/coco_gt'
|
4 |
+
|
5 |
+
# set pretrained as a file path or an url
|
6 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
|
7 |
+
|
8 |
+
# size of vit model; base or large
|
9 |
+
vit: 'base'
|
10 |
+
vit_grad_ckpt: False
|
11 |
+
vit_ckpt_layer: 0
|
12 |
+
batch_size: 32
|
13 |
+
init_lr: 1e-5
|
14 |
+
|
15 |
+
# vit: 'large'
|
16 |
+
# vit_grad_ckpt: True
|
17 |
+
# vit_ckpt_layer: 5
|
18 |
+
# batch_size: 16
|
19 |
+
# init_lr: 2e-6
|
20 |
+
|
21 |
+
image_size: 384
|
22 |
+
|
23 |
+
# generation configs
|
24 |
+
max_length: 20
|
25 |
+
min_length: 5
|
26 |
+
num_beams: 3
|
27 |
+
prompt: 'a picture of '
|
28 |
+
|
29 |
+
# optimizer
|
30 |
+
weight_decay: 0.05
|
31 |
+
min_lr: 0
|
32 |
+
max_epoch: 5
|
33 |
+
|
configs/med_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertModel"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"hidden_act": "gelu",
|
7 |
+
"hidden_dropout_prob": 0.1,
|
8 |
+
"hidden_size": 768,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 3072,
|
11 |
+
"layer_norm_eps": 1e-12,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "bert",
|
14 |
+
"num_attention_heads": 12,
|
15 |
+
"num_hidden_layers": 12,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"type_vocab_size": 2,
|
18 |
+
"vocab_size": 30524,
|
19 |
+
"encoder_width": 768,
|
20 |
+
"add_cross_attention": true
|
21 |
+
}
|
configs/nlvr.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/NLVR2/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
|
4 |
+
# set pretrained as a file path or an url
|
5 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
|
6 |
+
|
7 |
+
#size of vit model; base or large
|
8 |
+
vit: 'base'
|
9 |
+
batch_size_train: 16
|
10 |
+
batch_size_test: 64
|
11 |
+
vit_grad_ckpt: False
|
12 |
+
vit_ckpt_layer: 0
|
13 |
+
max_epoch: 15
|
14 |
+
|
15 |
+
image_size: 384
|
16 |
+
|
17 |
+
# optimizer
|
18 |
+
weight_decay: 0.05
|
19 |
+
init_lr: 3e-5
|
20 |
+
min_lr: 0
|
21 |
+
|
configs/nocaps.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/nocaps/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
|
4 |
+
# set pretrained as a file path or an url
|
5 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
|
6 |
+
|
7 |
+
vit: 'base'
|
8 |
+
batch_size: 32
|
9 |
+
|
10 |
+
image_size: 384
|
11 |
+
|
12 |
+
max_length: 20
|
13 |
+
min_length: 5
|
14 |
+
num_beams: 3
|
15 |
+
prompt: 'a picture of '
|
configs/pretrain.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
|
2 |
+
'/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
|
3 |
+
]
|
4 |
+
laion_path: ''
|
5 |
+
|
6 |
+
# size of vit model; base or large
|
7 |
+
vit: 'base'
|
8 |
+
vit_grad_ckpt: False
|
9 |
+
vit_ckpt_layer: 0
|
10 |
+
|
11 |
+
image_size: 224
|
12 |
+
batch_size: 75
|
13 |
+
|
14 |
+
queue_size: 57600
|
15 |
+
alpha: 0.4
|
16 |
+
|
17 |
+
# optimizer
|
18 |
+
weight_decay: 0.05
|
19 |
+
init_lr: 3e-4
|
20 |
+
min_lr: 1e-6
|
21 |
+
warmup_lr: 1e-6
|
22 |
+
lr_decay_rate: 0.9
|
23 |
+
max_epoch: 20
|
24 |
+
warmup_steps: 3000
|
25 |
+
|
26 |
+
|
27 |
+
|
configs/retrieval_coco.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/coco/images/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
dataset: 'coco'
|
4 |
+
|
5 |
+
# set pretrained as a file path or an url
|
6 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
|
7 |
+
|
8 |
+
# size of vit model; base or large
|
9 |
+
|
10 |
+
vit: 'base'
|
11 |
+
batch_size_train: 32
|
12 |
+
batch_size_test: 64
|
13 |
+
vit_grad_ckpt: True
|
14 |
+
vit_ckpt_layer: 4
|
15 |
+
init_lr: 1e-5
|
16 |
+
|
17 |
+
# vit: 'large'
|
18 |
+
# batch_size_train: 16
|
19 |
+
# batch_size_test: 32
|
20 |
+
# vit_grad_ckpt: True
|
21 |
+
# vit_ckpt_layer: 12
|
22 |
+
# init_lr: 5e-6
|
23 |
+
|
24 |
+
image_size: 384
|
25 |
+
queue_size: 57600
|
26 |
+
alpha: 0.4
|
27 |
+
k_test: 256
|
28 |
+
negative_all_rank: True
|
29 |
+
|
30 |
+
# optimizer
|
31 |
+
weight_decay: 0.05
|
32 |
+
min_lr: 0
|
33 |
+
max_epoch: 6
|
34 |
+
|
configs/retrieval_flickr.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image_root: '/export/share/datasets/vision/flickr30k/'
|
2 |
+
ann_root: 'annotation'
|
3 |
+
dataset: 'flickr'
|
4 |
+
|
5 |
+
# set pretrained as a file path or an url
|
6 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
|
7 |
+
|
8 |
+
# size of vit model; base or large
|
9 |
+
|
10 |
+
vit: 'base'
|
11 |
+
batch_size_train: 32
|
12 |
+
batch_size_test: 64
|
13 |
+
vit_grad_ckpt: True
|
14 |
+
vit_ckpt_layer: 4
|
15 |
+
init_lr: 1e-5
|
16 |
+
|
17 |
+
# vit: 'large'
|
18 |
+
# batch_size_train: 16
|
19 |
+
# batch_size_test: 32
|
20 |
+
# vit_grad_ckpt: True
|
21 |
+
# vit_ckpt_layer: 10
|
22 |
+
# init_lr: 5e-6
|
23 |
+
|
24 |
+
image_size: 384
|
25 |
+
queue_size: 57600
|
26 |
+
alpha: 0.4
|
27 |
+
k_test: 128
|
28 |
+
negative_all_rank: False
|
29 |
+
|
30 |
+
# optimizer
|
31 |
+
weight_decay: 0.05
|
32 |
+
min_lr: 0
|
33 |
+
max_epoch: 6
|
34 |
+
|
configs/vqa.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
|
2 |
+
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
|
3 |
+
train_files: ['vqa_train','vqa_val','vg_qa']
|
4 |
+
ann_root: 'annotation'
|
5 |
+
|
6 |
+
# set pretrained as a file path or an url
|
7 |
+
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
|
8 |
+
|
9 |
+
# size of vit model; base or large
|
10 |
+
vit: 'base'
|
11 |
+
batch_size_train: 16
|
12 |
+
batch_size_test: 32
|
13 |
+
vit_grad_ckpt: False
|
14 |
+
vit_ckpt_layer: 0
|
15 |
+
init_lr: 2e-5
|
16 |
+
|
17 |
+
image_size: 480
|
18 |
+
|
19 |
+
k_test: 128
|
20 |
+
inference: 'rank'
|
21 |
+
|
22 |
+
# optimizer
|
23 |
+
weight_decay: 0.05
|
24 |
+
min_lr: 0
|
25 |
+
max_epoch: 10
|
data/__init__.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from torchvision import transforms
|
4 |
+
from torchvision.transforms.functional import InterpolationMode
|
5 |
+
|
6 |
+
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
|
7 |
+
from data.nocaps_dataset import nocaps_eval
|
8 |
+
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
|
9 |
+
from data.vqa_dataset import vqa_dataset
|
10 |
+
from data.nlvr_dataset import nlvr_dataset
|
11 |
+
from data.pretrain_dataset import pretrain_dataset
|
12 |
+
from transform.randaugment import RandomAugment
|
13 |
+
|
14 |
+
def create_dataset(dataset, config, min_scale=0.5):
|
15 |
+
|
16 |
+
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
17 |
+
|
18 |
+
transform_train = transforms.Compose([
|
19 |
+
transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
|
20 |
+
transforms.RandomHorizontalFlip(),
|
21 |
+
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
|
22 |
+
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
|
23 |
+
transforms.ToTensor(),
|
24 |
+
normalize,
|
25 |
+
])
|
26 |
+
transform_test = transforms.Compose([
|
27 |
+
transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
|
28 |
+
transforms.ToTensor(),
|
29 |
+
normalize,
|
30 |
+
])
|
31 |
+
|
32 |
+
if dataset=='pretrain':
|
33 |
+
dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
|
34 |
+
return dataset
|
35 |
+
|
36 |
+
elif dataset=='caption_coco':
|
37 |
+
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
|
38 |
+
val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
39 |
+
test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
40 |
+
return train_dataset, val_dataset, test_dataset
|
41 |
+
|
42 |
+
elif dataset=='nocaps':
|
43 |
+
val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
44 |
+
test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
45 |
+
return val_dataset, test_dataset
|
46 |
+
|
47 |
+
elif dataset=='retrieval_coco':
|
48 |
+
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
|
49 |
+
val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
50 |
+
test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
51 |
+
return train_dataset, val_dataset, test_dataset
|
52 |
+
|
53 |
+
elif dataset=='retrieval_flickr':
|
54 |
+
train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
|
55 |
+
val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
56 |
+
test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
57 |
+
return train_dataset, val_dataset, test_dataset
|
58 |
+
|
59 |
+
elif dataset=='vqa':
|
60 |
+
train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
|
61 |
+
train_files = config['train_files'], split='train')
|
62 |
+
test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
|
63 |
+
return train_dataset, test_dataset
|
64 |
+
|
65 |
+
elif dataset=='nlvr':
|
66 |
+
train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
|
67 |
+
val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
|
68 |
+
test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
|
69 |
+
return train_dataset, val_dataset, test_dataset
|
70 |
+
|
71 |
+
|
72 |
+
def create_sampler(datasets, shuffles, num_tasks, global_rank):
|
73 |
+
samplers = []
|
74 |
+
for dataset,shuffle in zip(datasets,shuffles):
|
75 |
+
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
|
76 |
+
samplers.append(sampler)
|
77 |
+
return samplers
|
78 |
+
|
79 |
+
|
80 |
+
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
|
81 |
+
loaders = []
|
82 |
+
for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
|
83 |
+
if is_train:
|
84 |
+
shuffle = (sampler is None)
|
85 |
+
drop_last = True
|
86 |
+
else:
|
87 |
+
shuffle = False
|
88 |
+
drop_last = False
|
89 |
+
loader = DataLoader(
|
90 |
+
dataset,
|
91 |
+
batch_size=bs,
|
92 |
+
num_workers=n_worker,
|
93 |
+
pin_memory=True,
|
94 |
+
sampler=sampler,
|
95 |
+
shuffle=shuffle,
|
96 |
+
collate_fn=collate_fn,
|
97 |
+
drop_last=drop_last,
|
98 |
+
)
|
99 |
+
loaders.append(loader)
|
100 |
+
return loaders
|
101 |
+
|
data/coco_karpathy_dataset.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from torchvision.datasets.utils import download_url
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from data.utils import pre_caption
|
10 |
+
|
11 |
+
class coco_karpathy_train(Dataset):
|
12 |
+
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
|
13 |
+
'''
|
14 |
+
image_root (string): Root directory of images (e.g. coco/images/)
|
15 |
+
ann_root (string): directory to store the annotation file
|
16 |
+
'''
|
17 |
+
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
|
18 |
+
filename = 'coco_karpathy_train.json'
|
19 |
+
|
20 |
+
download_url(url,ann_root)
|
21 |
+
|
22 |
+
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
|
23 |
+
self.transform = transform
|
24 |
+
self.image_root = image_root
|
25 |
+
self.max_words = max_words
|
26 |
+
self.prompt = prompt
|
27 |
+
|
28 |
+
self.img_ids = {}
|
29 |
+
n = 0
|
30 |
+
for ann in self.annotation:
|
31 |
+
img_id = ann['image_id']
|
32 |
+
if img_id not in self.img_ids.keys():
|
33 |
+
self.img_ids[img_id] = n
|
34 |
+
n += 1
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return len(self.annotation)
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
|
41 |
+
ann = self.annotation[index]
|
42 |
+
|
43 |
+
image_path = os.path.join(self.image_root,ann['image'])
|
44 |
+
image = Image.open(image_path).convert('RGB')
|
45 |
+
image = self.transform(image)
|
46 |
+
|
47 |
+
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
|
48 |
+
|
49 |
+
return image, caption, self.img_ids[ann['image_id']]
|
50 |
+
|
51 |
+
|
52 |
+
class coco_karpathy_caption_eval(Dataset):
|
53 |
+
def __init__(self, transform, image_root, ann_root, split):
|
54 |
+
'''
|
55 |
+
image_root (string): Root directory of images (e.g. coco/images/)
|
56 |
+
ann_root (string): directory to store the annotation file
|
57 |
+
split (string): val or test
|
58 |
+
'''
|
59 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
|
60 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
|
61 |
+
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
|
62 |
+
|
63 |
+
download_url(urls[split],ann_root)
|
64 |
+
|
65 |
+
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
|
66 |
+
self.transform = transform
|
67 |
+
self.image_root = image_root
|
68 |
+
|
69 |
+
def __len__(self):
|
70 |
+
return len(self.annotation)
|
71 |
+
|
72 |
+
def __getitem__(self, index):
|
73 |
+
|
74 |
+
ann = self.annotation[index]
|
75 |
+
|
76 |
+
image_path = os.path.join(self.image_root,ann['image'])
|
77 |
+
image = Image.open(image_path).convert('RGB')
|
78 |
+
image = self.transform(image)
|
79 |
+
|
80 |
+
img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
|
81 |
+
|
82 |
+
return image, int(img_id)
|
83 |
+
|
84 |
+
|
85 |
+
class coco_karpathy_retrieval_eval(Dataset):
|
86 |
+
def __init__(self, transform, image_root, ann_root, split, max_words=30):
|
87 |
+
'''
|
88 |
+
image_root (string): Root directory of images (e.g. coco/images/)
|
89 |
+
ann_root (string): directory to store the annotation file
|
90 |
+
split (string): val or test
|
91 |
+
'''
|
92 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
|
93 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
|
94 |
+
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
|
95 |
+
|
96 |
+
download_url(urls[split],ann_root)
|
97 |
+
|
98 |
+
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
|
99 |
+
self.transform = transform
|
100 |
+
self.image_root = image_root
|
101 |
+
|
102 |
+
self.text = []
|
103 |
+
self.image = []
|
104 |
+
self.txt2img = {}
|
105 |
+
self.img2txt = {}
|
106 |
+
|
107 |
+
txt_id = 0
|
108 |
+
for img_id, ann in enumerate(self.annotation):
|
109 |
+
self.image.append(ann['image'])
|
110 |
+
self.img2txt[img_id] = []
|
111 |
+
for i, caption in enumerate(ann['caption']):
|
112 |
+
self.text.append(pre_caption(caption,max_words))
|
113 |
+
self.img2txt[img_id].append(txt_id)
|
114 |
+
self.txt2img[txt_id] = img_id
|
115 |
+
txt_id += 1
|
116 |
+
|
117 |
+
def __len__(self):
|
118 |
+
return len(self.annotation)
|
119 |
+
|
120 |
+
def __getitem__(self, index):
|
121 |
+
|
122 |
+
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
|
123 |
+
image = Image.open(image_path).convert('RGB')
|
124 |
+
image = self.transform(image)
|
125 |
+
|
126 |
+
return image, index
|
data/flickr30k_dataset.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from torchvision.datasets.utils import download_url
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from data.utils import pre_caption
|
10 |
+
|
11 |
+
class flickr30k_train(Dataset):
|
12 |
+
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
|
13 |
+
'''
|
14 |
+
image_root (string): Root directory of images (e.g. flickr30k/)
|
15 |
+
ann_root (string): directory to store the annotation file
|
16 |
+
'''
|
17 |
+
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
|
18 |
+
filename = 'flickr30k_train.json'
|
19 |
+
|
20 |
+
download_url(url,ann_root)
|
21 |
+
|
22 |
+
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
|
23 |
+
self.transform = transform
|
24 |
+
self.image_root = image_root
|
25 |
+
self.max_words = max_words
|
26 |
+
self.prompt = prompt
|
27 |
+
|
28 |
+
self.img_ids = {}
|
29 |
+
n = 0
|
30 |
+
for ann in self.annotation:
|
31 |
+
img_id = ann['image_id']
|
32 |
+
if img_id not in self.img_ids.keys():
|
33 |
+
self.img_ids[img_id] = n
|
34 |
+
n += 1
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return len(self.annotation)
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
|
41 |
+
ann = self.annotation[index]
|
42 |
+
|
43 |
+
image_path = os.path.join(self.image_root,ann['image'])
|
44 |
+
image = Image.open(image_path).convert('RGB')
|
45 |
+
image = self.transform(image)
|
46 |
+
|
47 |
+
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
|
48 |
+
|
49 |
+
return image, caption, self.img_ids[ann['image_id']]
|
50 |
+
|
51 |
+
|
52 |
+
class flickr30k_retrieval_eval(Dataset):
|
53 |
+
def __init__(self, transform, image_root, ann_root, split, max_words=30):
|
54 |
+
'''
|
55 |
+
image_root (string): Root directory of images (e.g. flickr30k/)
|
56 |
+
ann_root (string): directory to store the annotation file
|
57 |
+
split (string): val or test
|
58 |
+
'''
|
59 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
|
60 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
|
61 |
+
filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
|
62 |
+
|
63 |
+
download_url(urls[split],ann_root)
|
64 |
+
|
65 |
+
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
|
66 |
+
self.transform = transform
|
67 |
+
self.image_root = image_root
|
68 |
+
|
69 |
+
self.text = []
|
70 |
+
self.image = []
|
71 |
+
self.txt2img = {}
|
72 |
+
self.img2txt = {}
|
73 |
+
|
74 |
+
txt_id = 0
|
75 |
+
for img_id, ann in enumerate(self.annotation):
|
76 |
+
self.image.append(ann['image'])
|
77 |
+
self.img2txt[img_id] = []
|
78 |
+
for i, caption in enumerate(ann['caption']):
|
79 |
+
self.text.append(pre_caption(caption,max_words))
|
80 |
+
self.img2txt[img_id].append(txt_id)
|
81 |
+
self.txt2img[txt_id] = img_id
|
82 |
+
txt_id += 1
|
83 |
+
|
84 |
+
def __len__(self):
|
85 |
+
return len(self.annotation)
|
86 |
+
|
87 |
+
def __getitem__(self, index):
|
88 |
+
|
89 |
+
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
|
90 |
+
image = Image.open(image_path).convert('RGB')
|
91 |
+
image = self.transform(image)
|
92 |
+
|
93 |
+
return image, index
|
data/nlvr_dataset.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision.datasets.utils import download_url
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from data.utils import pre_caption
|
11 |
+
|
12 |
+
class nlvr_dataset(Dataset):
|
13 |
+
def __init__(self, transform, image_root, ann_root, split):
|
14 |
+
'''
|
15 |
+
image_root (string): Root directory of images
|
16 |
+
ann_root (string): directory to store the annotation file
|
17 |
+
split (string): train, val or test
|
18 |
+
'''
|
19 |
+
urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
|
20 |
+
'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
|
21 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
|
22 |
+
filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
|
23 |
+
|
24 |
+
download_url(urls[split],ann_root)
|
25 |
+
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
|
26 |
+
|
27 |
+
self.transform = transform
|
28 |
+
self.image_root = image_root
|
29 |
+
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.annotation)
|
33 |
+
|
34 |
+
|
35 |
+
def __getitem__(self, index):
|
36 |
+
|
37 |
+
ann = self.annotation[index]
|
38 |
+
|
39 |
+
image0_path = os.path.join(self.image_root,ann['images'][0])
|
40 |
+
image0 = Image.open(image0_path).convert('RGB')
|
41 |
+
image0 = self.transform(image0)
|
42 |
+
|
43 |
+
image1_path = os.path.join(self.image_root,ann['images'][1])
|
44 |
+
image1 = Image.open(image1_path).convert('RGB')
|
45 |
+
image1 = self.transform(image1)
|
46 |
+
|
47 |
+
sentence = pre_caption(ann['sentence'], 40)
|
48 |
+
|
49 |
+
if ann['label']=='True':
|
50 |
+
label = 1
|
51 |
+
else:
|
52 |
+
label = 0
|
53 |
+
|
54 |
+
words = sentence.split(' ')
|
55 |
+
|
56 |
+
if 'left' not in words and 'right' not in words:
|
57 |
+
if random.random()<0.5:
|
58 |
+
return image0, image1, sentence, label
|
59 |
+
else:
|
60 |
+
return image1, image0, sentence, label
|
61 |
+
else:
|
62 |
+
if random.random()<0.5:
|
63 |
+
return image0, image1, sentence, label
|
64 |
+
else:
|
65 |
+
new_words = []
|
66 |
+
for word in words:
|
67 |
+
if word=='left':
|
68 |
+
new_words.append('right')
|
69 |
+
elif word=='right':
|
70 |
+
new_words.append('left')
|
71 |
+
else:
|
72 |
+
new_words.append(word)
|
73 |
+
|
74 |
+
sentence = ' '.join(new_words)
|
75 |
+
return image1, image0, sentence, label
|
76 |
+
|
77 |
+
|
78 |
+
|
data/nocaps_dataset.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from torchvision.datasets.utils import download_url
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
class nocaps_eval(Dataset):
|
10 |
+
def __init__(self, transform, image_root, ann_root, split):
|
11 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
|
12 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
|
13 |
+
filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
|
14 |
+
|
15 |
+
download_url(urls[split],ann_root)
|
16 |
+
|
17 |
+
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
|
18 |
+
self.transform = transform
|
19 |
+
self.image_root = image_root
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
return len(self.annotation)
|
23 |
+
|
24 |
+
def __getitem__(self, index):
|
25 |
+
|
26 |
+
ann = self.annotation[index]
|
27 |
+
|
28 |
+
image_path = os.path.join(self.image_root,ann['image'])
|
29 |
+
image = Image.open(image_path).convert('RGB')
|
30 |
+
image = self.transform(image)
|
31 |
+
|
32 |
+
return image, int(ann['img_id'])
|
data/pretrain_dataset.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
from PIL import ImageFile
|
9 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
10 |
+
Image.MAX_IMAGE_PIXELS = None
|
11 |
+
|
12 |
+
from data.utils import pre_caption
|
13 |
+
import os,glob
|
14 |
+
|
15 |
+
class pretrain_dataset(Dataset):
|
16 |
+
def __init__(self, ann_file, laion_path, transform):
|
17 |
+
|
18 |
+
self.ann_pretrain = []
|
19 |
+
for f in ann_file:
|
20 |
+
print('loading '+f)
|
21 |
+
ann = json.load(open(f,'r'))
|
22 |
+
self.ann_pretrain += ann
|
23 |
+
|
24 |
+
self.laion_path = laion_path
|
25 |
+
if self.laion_path:
|
26 |
+
self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
|
27 |
+
|
28 |
+
print('loading '+self.laion_files[0])
|
29 |
+
with open(self.laion_files[0],'r') as f:
|
30 |
+
self.ann_laion = json.load(f)
|
31 |
+
|
32 |
+
self.annotation = self.ann_pretrain + self.ann_laion
|
33 |
+
else:
|
34 |
+
self.annotation = self.ann_pretrain
|
35 |
+
|
36 |
+
self.transform = transform
|
37 |
+
|
38 |
+
|
39 |
+
def reload_laion(self, epoch):
|
40 |
+
n = epoch%len(self.laion_files)
|
41 |
+
print('loading '+self.laion_files[n])
|
42 |
+
with open(self.laion_files[n],'r') as f:
|
43 |
+
self.ann_laion = json.load(f)
|
44 |
+
|
45 |
+
self.annotation = self.ann_pretrain + self.ann_laion
|
46 |
+
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.annotation)
|
50 |
+
|
51 |
+
def __getitem__(self, index):
|
52 |
+
|
53 |
+
ann = self.annotation[index]
|
54 |
+
|
55 |
+
image = Image.open(ann['image']).convert('RGB')
|
56 |
+
image = self.transform(image)
|
57 |
+
caption = pre_caption(ann['caption'],30)
|
58 |
+
|
59 |
+
return image, caption
|
data/utils.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
|
8 |
+
import utils
|
9 |
+
|
10 |
+
def pre_caption(caption,max_words=50):
|
11 |
+
caption = re.sub(
|
12 |
+
r"([.!\"()*#:;~])",
|
13 |
+
' ',
|
14 |
+
caption.lower(),
|
15 |
+
)
|
16 |
+
caption = re.sub(
|
17 |
+
r"\s{2,}",
|
18 |
+
' ',
|
19 |
+
caption,
|
20 |
+
)
|
21 |
+
caption = caption.rstrip('\n')
|
22 |
+
caption = caption.strip(' ')
|
23 |
+
|
24 |
+
#truncate caption
|
25 |
+
caption_words = caption.split(' ')
|
26 |
+
if len(caption_words)>max_words:
|
27 |
+
caption = ' '.join(caption_words[:max_words])
|
28 |
+
|
29 |
+
return caption
|
30 |
+
|
31 |
+
def pre_question(question,max_ques_words=50):
|
32 |
+
question = re.sub(
|
33 |
+
r"([.!\"()*#:;~])",
|
34 |
+
'',
|
35 |
+
question.lower(),
|
36 |
+
)
|
37 |
+
question = question.rstrip(' ')
|
38 |
+
|
39 |
+
#truncate question
|
40 |
+
question_words = question.split(' ')
|
41 |
+
if len(question_words)>max_ques_words:
|
42 |
+
question = ' '.join(question_words[:max_ques_words])
|
43 |
+
|
44 |
+
return question
|
45 |
+
|
46 |
+
|
47 |
+
def save_result(result, result_dir, filename, remove_duplicate=''):
|
48 |
+
result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
|
49 |
+
final_result_file = os.path.join(result_dir, '%s.json'%filename)
|
50 |
+
|
51 |
+
json.dump(result,open(result_file,'w'))
|
52 |
+
|
53 |
+
dist.barrier()
|
54 |
+
|
55 |
+
if utils.is_main_process():
|
56 |
+
# combine results from all processes
|
57 |
+
result = []
|
58 |
+
|
59 |
+
for rank in range(utils.get_world_size()):
|
60 |
+
result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
|
61 |
+
res = json.load(open(result_file,'r'))
|
62 |
+
result += res
|
63 |
+
|
64 |
+
if remove_duplicate:
|
65 |
+
result_new = []
|
66 |
+
id_list = []
|
67 |
+
for res in result:
|
68 |
+
if res[remove_duplicate] not in id_list:
|
69 |
+
id_list.append(res[remove_duplicate])
|
70 |
+
result_new.append(res)
|
71 |
+
result = result_new
|
72 |
+
|
73 |
+
json.dump(result,open(final_result_file,'w'))
|
74 |
+
print('result file saved to %s'%final_result_file)
|
75 |
+
|
76 |
+
return final_result_file
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
from pycocotools.coco import COCO
|
81 |
+
from pycocoevalcap.eval import COCOEvalCap
|
82 |
+
from torchvision.datasets.utils import download_url
|
83 |
+
|
84 |
+
def coco_caption_eval(coco_gt_root, results_file, split):
|
85 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
|
86 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
|
87 |
+
filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
|
88 |
+
|
89 |
+
download_url(urls[split],coco_gt_root)
|
90 |
+
annotation_file = os.path.join(coco_gt_root,filenames[split])
|
91 |
+
|
92 |
+
# create coco object and coco_result object
|
93 |
+
coco = COCO(annotation_file)
|
94 |
+
coco_result = coco.loadRes(results_file)
|
95 |
+
|
96 |
+
# create coco_eval object by taking coco and coco_result
|
97 |
+
coco_eval = COCOEvalCap(coco, coco_result)
|
98 |
+
|
99 |
+
# evaluate on a subset of images by setting
|
100 |
+
# coco_eval.params['image_id'] = coco_result.getImgIds()
|
101 |
+
# please remove this line when evaluating the full validation set
|
102 |
+
# coco_eval.params['image_id'] = coco_result.getImgIds()
|
103 |
+
|
104 |
+
# evaluate results
|
105 |
+
# SPICE will take a few minutes the first time, but speeds up due to caching
|
106 |
+
coco_eval.evaluate()
|
107 |
+
|
108 |
+
# print output evaluation scores
|
109 |
+
for metric, score in coco_eval.eval.items():
|
110 |
+
print(f'{metric}: {score:.3f}')
|
111 |
+
|
112 |
+
return coco_eval
|
data/vqa_dataset.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from data.utils import pre_question
|
9 |
+
|
10 |
+
from torchvision.datasets.utils import download_url
|
11 |
+
|
12 |
+
class vqa_dataset(Dataset):
|
13 |
+
def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
|
14 |
+
self.split = split
|
15 |
+
|
16 |
+
self.transform = transform
|
17 |
+
self.vqa_root = vqa_root
|
18 |
+
self.vg_root = vg_root
|
19 |
+
|
20 |
+
if split=='train':
|
21 |
+
urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
|
22 |
+
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
|
23 |
+
'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
|
24 |
+
|
25 |
+
self.annotation = []
|
26 |
+
for f in train_files:
|
27 |
+
download_url(urls[f],ann_root)
|
28 |
+
self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
|
29 |
+
else:
|
30 |
+
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
|
31 |
+
self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
|
32 |
+
|
33 |
+
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
|
34 |
+
self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
|
35 |
+
|
36 |
+
|
37 |
+
def __len__(self):
|
38 |
+
return len(self.annotation)
|
39 |
+
|
40 |
+
def __getitem__(self, index):
|
41 |
+
|
42 |
+
ann = self.annotation[index]
|
43 |
+
|
44 |
+
if ann['dataset']=='vqa':
|
45 |
+
image_path = os.path.join(self.vqa_root,ann['image'])
|
46 |
+
elif ann['dataset']=='vg':
|
47 |
+
image_path = os.path.join(self.vg_root,ann['image'])
|
48 |
+
|
49 |
+
image = Image.open(image_path).convert('RGB')
|
50 |
+
image = self.transform(image)
|
51 |
+
|
52 |
+
if self.split == 'test':
|
53 |
+
question = pre_question(ann['question'])
|
54 |
+
question_id = ann['question_id']
|
55 |
+
return image, question, question_id
|
56 |
+
|
57 |
+
|
58 |
+
elif self.split=='train':
|
59 |
+
|
60 |
+
question = pre_question(ann['question'])
|
61 |
+
|
62 |
+
if ann['dataset']=='vqa':
|
63 |
+
answer_weight = {}
|
64 |
+
for answer in ann['answer']:
|
65 |
+
if answer in answer_weight.keys():
|
66 |
+
answer_weight[answer] += 1/len(ann['answer'])
|
67 |
+
else:
|
68 |
+
answer_weight[answer] = 1/len(ann['answer'])
|
69 |
+
|
70 |
+
answers = list(answer_weight.keys())
|
71 |
+
weights = list(answer_weight.values())
|
72 |
+
|
73 |
+
elif ann['dataset']=='vg':
|
74 |
+
answers = [ann['answer']]
|
75 |
+
weights = [0.2]
|
76 |
+
|
77 |
+
return image, question, answers, weights
|
78 |
+
|
79 |
+
|
80 |
+
def vqa_collate_fn(batch):
|
81 |
+
image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
|
82 |
+
for image, question, answer, weights in batch:
|
83 |
+
image_list.append(image)
|
84 |
+
question_list.append(question)
|
85 |
+
weight_list += weights
|
86 |
+
answer_list += answer
|
87 |
+
n.append(len(answer))
|
88 |
+
return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
|
eval_nocaps.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip import blip_decoder
|
26 |
+
import utils
|
27 |
+
from data import create_dataset, create_sampler, create_loader
|
28 |
+
from data.utils import save_result
|
29 |
+
|
30 |
+
@torch.no_grad()
|
31 |
+
def evaluate(model, data_loader, device, config):
|
32 |
+
# evaluate
|
33 |
+
model.eval()
|
34 |
+
|
35 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
36 |
+
header = 'Evaluation:'
|
37 |
+
print_freq = 10
|
38 |
+
|
39 |
+
result = []
|
40 |
+
for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
|
41 |
+
|
42 |
+
image = image.to(device)
|
43 |
+
|
44 |
+
captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
|
45 |
+
min_length=config['min_length'], repetition_penalty=1.1)
|
46 |
+
|
47 |
+
for caption, img_id in zip(captions, image_id):
|
48 |
+
result.append({"image_id": img_id.item(), "caption": caption})
|
49 |
+
|
50 |
+
return result
|
51 |
+
|
52 |
+
|
53 |
+
def main(args, config):
|
54 |
+
utils.init_distributed_mode(args)
|
55 |
+
|
56 |
+
device = torch.device(args.device)
|
57 |
+
|
58 |
+
# fix the seed for reproducibility
|
59 |
+
seed = args.seed + utils.get_rank()
|
60 |
+
torch.manual_seed(seed)
|
61 |
+
np.random.seed(seed)
|
62 |
+
random.seed(seed)
|
63 |
+
cudnn.benchmark = True
|
64 |
+
|
65 |
+
#### Dataset ####
|
66 |
+
print("Creating captioning dataset")
|
67 |
+
val_dataset, test_dataset = create_dataset('nocaps', config)
|
68 |
+
|
69 |
+
if args.distributed:
|
70 |
+
num_tasks = utils.get_world_size()
|
71 |
+
global_rank = utils.get_rank()
|
72 |
+
samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
|
73 |
+
else:
|
74 |
+
samplers = [None,None]
|
75 |
+
|
76 |
+
val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
|
77 |
+
batch_size=[config['batch_size']]*2,num_workers=[4,4],
|
78 |
+
is_trains=[False, False], collate_fns=[None,None])
|
79 |
+
|
80 |
+
#### Model ####
|
81 |
+
print("Creating model")
|
82 |
+
model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
|
83 |
+
prompt=config['prompt'])
|
84 |
+
|
85 |
+
model = model.to(device)
|
86 |
+
|
87 |
+
model_without_ddp = model
|
88 |
+
if args.distributed:
|
89 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
90 |
+
model_without_ddp = model.module
|
91 |
+
|
92 |
+
val_result = evaluate(model_without_ddp, val_loader, device, config)
|
93 |
+
val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
|
94 |
+
test_result = evaluate(model_without_ddp, test_loader, device, config)
|
95 |
+
test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == '__main__':
|
99 |
+
parser = argparse.ArgumentParser()
|
100 |
+
parser.add_argument('--config', default='./configs/nocaps.yaml')
|
101 |
+
parser.add_argument('--output_dir', default='output/NoCaps')
|
102 |
+
parser.add_argument('--device', default='cuda')
|
103 |
+
parser.add_argument('--seed', default=42, type=int)
|
104 |
+
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
|
105 |
+
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
106 |
+
parser.add_argument('--distributed', default=True, type=bool)
|
107 |
+
args = parser.parse_args()
|
108 |
+
|
109 |
+
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
|
110 |
+
|
111 |
+
args.result_dir = os.path.join(args.output_dir, 'result')
|
112 |
+
|
113 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
114 |
+
Path(args.result_dir).mkdir(parents=True, exist_ok=True)
|
115 |
+
|
116 |
+
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
|
117 |
+
|
118 |
+
main(args, config)
|
models/__init__.py
ADDED
File without changes
|
models/blip.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import warnings
|
9 |
+
warnings.filterwarnings("ignore")
|
10 |
+
|
11 |
+
from models.vit import VisionTransformer, interpolate_pos_embed
|
12 |
+
from models.med import BertConfig, BertModel, BertLMHeadModel
|
13 |
+
from transformers import BertTokenizer
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
import os
|
20 |
+
from urllib.parse import urlparse
|
21 |
+
from timm.models.hub import download_cached_file
|
22 |
+
|
23 |
+
class BLIP_Base(nn.Module):
|
24 |
+
def __init__(self,
|
25 |
+
med_config = 'configs/med_config.json',
|
26 |
+
image_size = 224,
|
27 |
+
vit = 'base',
|
28 |
+
vit_grad_ckpt = False,
|
29 |
+
vit_ckpt_layer = 0,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
34 |
+
image_size (int): input image size
|
35 |
+
vit (str): model size of vision transformer
|
36 |
+
"""
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
40 |
+
self.tokenizer = init_tokenizer()
|
41 |
+
med_config = BertConfig.from_json_file(med_config)
|
42 |
+
med_config.encoder_width = vision_width
|
43 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
44 |
+
|
45 |
+
|
46 |
+
def forward(self, image, caption, mode):
|
47 |
+
|
48 |
+
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
49 |
+
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
50 |
+
|
51 |
+
if mode=='image':
|
52 |
+
# return image features
|
53 |
+
image_embeds = self.visual_encoder(image)
|
54 |
+
return image_embeds
|
55 |
+
|
56 |
+
elif mode=='text':
|
57 |
+
# return text features
|
58 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
59 |
+
return_dict = True, mode = 'text')
|
60 |
+
return text_output.last_hidden_state
|
61 |
+
|
62 |
+
elif mode=='multimodal':
|
63 |
+
# return multimodel features
|
64 |
+
image_embeds = self.visual_encoder(image)
|
65 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
66 |
+
|
67 |
+
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
68 |
+
output = self.text_encoder(text.input_ids,
|
69 |
+
attention_mask = text.attention_mask,
|
70 |
+
encoder_hidden_states = image_embeds,
|
71 |
+
encoder_attention_mask = image_atts,
|
72 |
+
return_dict = True,
|
73 |
+
)
|
74 |
+
return output.last_hidden_state
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
class BLIP_Decoder(nn.Module):
|
79 |
+
def __init__(self,
|
80 |
+
med_config = 'configs/med_config.json',
|
81 |
+
image_size = 384,
|
82 |
+
vit = 'base',
|
83 |
+
vit_grad_ckpt = False,
|
84 |
+
vit_ckpt_layer = 0,
|
85 |
+
prompt = 'a picture of ',
|
86 |
+
):
|
87 |
+
"""
|
88 |
+
Args:
|
89 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
90 |
+
image_size (int): input image size
|
91 |
+
vit (str): model size of vision transformer
|
92 |
+
"""
|
93 |
+
super().__init__()
|
94 |
+
|
95 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
96 |
+
self.tokenizer = init_tokenizer()
|
97 |
+
med_config = BertConfig.from_json_file(med_config)
|
98 |
+
med_config.encoder_width = vision_width
|
99 |
+
self.text_decoder = BertLMHeadModel(config=med_config)
|
100 |
+
|
101 |
+
self.prompt = prompt
|
102 |
+
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
|
103 |
+
|
104 |
+
|
105 |
+
def forward(self, image, caption):
|
106 |
+
|
107 |
+
image_embeds = self.visual_encoder(image)
|
108 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
109 |
+
|
110 |
+
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
|
111 |
+
|
112 |
+
text.input_ids[:,0] = self.tokenizer.bos_token_id
|
113 |
+
|
114 |
+
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
115 |
+
decoder_targets[:,:self.prompt_length] = -100
|
116 |
+
|
117 |
+
decoder_output = self.text_decoder(text.input_ids,
|
118 |
+
attention_mask = text.attention_mask,
|
119 |
+
encoder_hidden_states = image_embeds,
|
120 |
+
encoder_attention_mask = image_atts,
|
121 |
+
labels = decoder_targets,
|
122 |
+
return_dict = True,
|
123 |
+
)
|
124 |
+
loss_lm = decoder_output.loss
|
125 |
+
|
126 |
+
return loss_lm
|
127 |
+
|
128 |
+
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
|
129 |
+
image_embeds = self.visual_encoder(image)
|
130 |
+
|
131 |
+
if not sample:
|
132 |
+
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
|
133 |
+
|
134 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
135 |
+
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
|
136 |
+
|
137 |
+
prompt = [self.prompt] * image.size(0)
|
138 |
+
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
139 |
+
input_ids[:,0] = self.tokenizer.bos_token_id
|
140 |
+
input_ids = input_ids[:, :-1]
|
141 |
+
|
142 |
+
if sample:
|
143 |
+
#nucleus sampling
|
144 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
145 |
+
max_length=max_length,
|
146 |
+
min_length=min_length,
|
147 |
+
do_sample=True,
|
148 |
+
top_p=top_p,
|
149 |
+
num_return_sequences=1,
|
150 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
151 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
152 |
+
repetition_penalty=1.1,
|
153 |
+
**model_kwargs)
|
154 |
+
else:
|
155 |
+
#beam search
|
156 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
157 |
+
max_length=max_length,
|
158 |
+
min_length=min_length,
|
159 |
+
num_beams=num_beams,
|
160 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
161 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
162 |
+
repetition_penalty=repetition_penalty,
|
163 |
+
**model_kwargs)
|
164 |
+
|
165 |
+
captions = []
|
166 |
+
for output in outputs:
|
167 |
+
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
168 |
+
captions.append(caption[len(self.prompt):])
|
169 |
+
return captions
|
170 |
+
|
171 |
+
|
172 |
+
def blip_decoder(pretrained='',**kwargs):
|
173 |
+
model = BLIP_Decoder(**kwargs)
|
174 |
+
if pretrained:
|
175 |
+
model,msg = load_checkpoint(model,pretrained)
|
176 |
+
assert(len(msg.missing_keys)==0)
|
177 |
+
return model
|
178 |
+
|
179 |
+
def blip_feature_extractor(pretrained='',**kwargs):
|
180 |
+
model = BLIP_Base(**kwargs)
|
181 |
+
if pretrained:
|
182 |
+
model,msg = load_checkpoint(model,pretrained)
|
183 |
+
assert(len(msg.missing_keys)==0)
|
184 |
+
return model
|
185 |
+
|
186 |
+
def init_tokenizer():
|
187 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
188 |
+
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
189 |
+
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
190 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
191 |
+
return tokenizer
|
192 |
+
|
193 |
+
|
194 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
195 |
+
|
196 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
197 |
+
if vit=='base':
|
198 |
+
vision_width = 768
|
199 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
200 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
201 |
+
drop_path_rate=0 or drop_path_rate
|
202 |
+
)
|
203 |
+
elif vit=='large':
|
204 |
+
vision_width = 1024
|
205 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
206 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
207 |
+
drop_path_rate=0.1 or drop_path_rate
|
208 |
+
)
|
209 |
+
return visual_encoder, vision_width
|
210 |
+
|
211 |
+
def is_url(url_or_filename):
|
212 |
+
parsed = urlparse(url_or_filename)
|
213 |
+
return parsed.scheme in ("http", "https")
|
214 |
+
|
215 |
+
def load_checkpoint(model,url_or_filename):
|
216 |
+
if is_url(url_or_filename):
|
217 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
218 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
219 |
+
elif os.path.isfile(url_or_filename):
|
220 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
221 |
+
else:
|
222 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
223 |
+
|
224 |
+
state_dict = checkpoint['model']
|
225 |
+
|
226 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
227 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
228 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
229 |
+
model.visual_encoder_m)
|
230 |
+
for key in model.state_dict().keys():
|
231 |
+
if key in state_dict.keys():
|
232 |
+
if state_dict[key].shape!=model.state_dict()[key].shape:
|
233 |
+
del state_dict[key]
|
234 |
+
|
235 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
236 |
+
print('load checkpoint from %s'%url_or_filename)
|
237 |
+
return model,msg
|
238 |
+
|
models/blip_nlvr.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.med import BertConfig
|
2 |
+
from models.nlvr_encoder import BertModel
|
3 |
+
from models.vit import interpolate_pos_embed
|
4 |
+
from models.blip import create_vit, init_tokenizer, is_url
|
5 |
+
|
6 |
+
from timm.models.hub import download_cached_file
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from transformers import BertTokenizer
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
class BLIP_NLVR(nn.Module):
|
15 |
+
def __init__(self,
|
16 |
+
med_config = 'configs/med_config.json',
|
17 |
+
image_size = 480,
|
18 |
+
vit = 'base',
|
19 |
+
vit_grad_ckpt = False,
|
20 |
+
vit_ckpt_layer = 0,
|
21 |
+
):
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
25 |
+
image_size (int): input image size
|
26 |
+
vit (str): model size of vision transformer
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
|
31 |
+
self.tokenizer = init_tokenizer()
|
32 |
+
med_config = BertConfig.from_json_file(med_config)
|
33 |
+
med_config.encoder_width = vision_width
|
34 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
35 |
+
|
36 |
+
self.cls_head = nn.Sequential(
|
37 |
+
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
|
38 |
+
nn.ReLU(),
|
39 |
+
nn.Linear(self.text_encoder.config.hidden_size, 2)
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, image, text, targets, train=True):
|
43 |
+
|
44 |
+
image_embeds = self.visual_encoder(image)
|
45 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
46 |
+
image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
|
47 |
+
|
48 |
+
text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
|
49 |
+
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
50 |
+
|
51 |
+
output = self.text_encoder(text.input_ids,
|
52 |
+
attention_mask = text.attention_mask,
|
53 |
+
encoder_hidden_states = [image0_embeds,image1_embeds],
|
54 |
+
encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
|
55 |
+
image_atts[image0_embeds.size(0):]],
|
56 |
+
return_dict = True,
|
57 |
+
)
|
58 |
+
hidden_state = output.last_hidden_state[:,0,:]
|
59 |
+
prediction = self.cls_head(hidden_state)
|
60 |
+
|
61 |
+
if train:
|
62 |
+
loss = F.cross_entropy(prediction, targets)
|
63 |
+
return loss
|
64 |
+
else:
|
65 |
+
return prediction
|
66 |
+
|
67 |
+
def blip_nlvr(pretrained='',**kwargs):
|
68 |
+
model = BLIP_NLVR(**kwargs)
|
69 |
+
if pretrained:
|
70 |
+
model,msg = load_checkpoint(model,pretrained)
|
71 |
+
print("missing keys:")
|
72 |
+
print(msg.missing_keys)
|
73 |
+
return model
|
74 |
+
|
75 |
+
|
76 |
+
def load_checkpoint(model,url_or_filename):
|
77 |
+
if is_url(url_or_filename):
|
78 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
79 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
80 |
+
elif os.path.isfile(url_or_filename):
|
81 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
82 |
+
else:
|
83 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
84 |
+
state_dict = checkpoint['model']
|
85 |
+
|
86 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
87 |
+
|
88 |
+
for key in list(state_dict.keys()):
|
89 |
+
if 'crossattention.self.' in key:
|
90 |
+
new_key0 = key.replace('self','self0')
|
91 |
+
new_key1 = key.replace('self','self1')
|
92 |
+
state_dict[new_key0] = state_dict[key]
|
93 |
+
state_dict[new_key1] = state_dict[key]
|
94 |
+
elif 'crossattention.output.dense.' in key:
|
95 |
+
new_key0 = key.replace('dense','dense0')
|
96 |
+
new_key1 = key.replace('dense','dense1')
|
97 |
+
state_dict[new_key0] = state_dict[key]
|
98 |
+
state_dict[new_key1] = state_dict[key]
|
99 |
+
|
100 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
101 |
+
print('load checkpoint from %s'%url_or_filename)
|
102 |
+
return model,msg
|
103 |
+
|
models/blip_pretrain.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
from models.med import BertConfig, BertModel, BertLMHeadModel
|
9 |
+
from transformers import BertTokenizer
|
10 |
+
import transformers
|
11 |
+
transformers.logging.set_verbosity_error()
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from models.blip import create_vit, init_tokenizer, load_checkpoint
|
18 |
+
|
19 |
+
class BLIP_Pretrain(nn.Module):
|
20 |
+
def __init__(self,
|
21 |
+
med_config = 'configs/bert_config.json',
|
22 |
+
image_size = 224,
|
23 |
+
vit = 'base',
|
24 |
+
vit_grad_ckpt = False,
|
25 |
+
vit_ckpt_layer = 0,
|
26 |
+
embed_dim = 256,
|
27 |
+
queue_size = 57600,
|
28 |
+
momentum = 0.995,
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Args:
|
32 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
33 |
+
image_size (int): input image size
|
34 |
+
vit (str): model size of vision transformer
|
35 |
+
"""
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
39 |
+
|
40 |
+
if vit=='base':
|
41 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
42 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
43 |
+
map_location="cpu", check_hash=True)
|
44 |
+
state_dict = checkpoint["model"]
|
45 |
+
msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
|
46 |
+
elif vit=='large':
|
47 |
+
from timm.models.helpers import load_custom_pretrained
|
48 |
+
from timm.models.vision_transformer import default_cfgs
|
49 |
+
load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
|
50 |
+
|
51 |
+
self.tokenizer = init_tokenizer()
|
52 |
+
encoder_config = BertConfig.from_json_file(med_config)
|
53 |
+
encoder_config.encoder_width = vision_width
|
54 |
+
self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
|
55 |
+
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
56 |
+
|
57 |
+
text_width = self.text_encoder.config.hidden_size
|
58 |
+
|
59 |
+
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
60 |
+
self.text_proj = nn.Linear(text_width, embed_dim)
|
61 |
+
|
62 |
+
self.itm_head = nn.Linear(text_width, 2)
|
63 |
+
|
64 |
+
# create momentum encoders
|
65 |
+
self.visual_encoder_m, vision_width = create_vit(vit,image_size)
|
66 |
+
self.vision_proj_m = nn.Linear(vision_width, embed_dim)
|
67 |
+
self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
|
68 |
+
self.text_proj_m = nn.Linear(text_width, embed_dim)
|
69 |
+
|
70 |
+
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
|
71 |
+
[self.vision_proj,self.vision_proj_m],
|
72 |
+
[self.text_encoder,self.text_encoder_m],
|
73 |
+
[self.text_proj,self.text_proj_m],
|
74 |
+
]
|
75 |
+
self.copy_params()
|
76 |
+
|
77 |
+
# create the queue
|
78 |
+
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
|
79 |
+
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
|
80 |
+
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
|
81 |
+
|
82 |
+
self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
|
83 |
+
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
|
84 |
+
|
85 |
+
self.queue_size = queue_size
|
86 |
+
self.momentum = momentum
|
87 |
+
self.temp = nn.Parameter(0.07*torch.ones([]))
|
88 |
+
|
89 |
+
# create the decoder
|
90 |
+
decoder_config = BertConfig.from_json_file(med_config)
|
91 |
+
decoder_config.encoder_width = vision_width
|
92 |
+
self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
|
93 |
+
self.text_decoder.resize_token_embeddings(len(self.tokenizer))
|
94 |
+
tie_encoder_decoder_weights(self.text_decoder.bert,self.text_encoder,'','/attention')
|
95 |
+
|
96 |
+
|
97 |
+
def forward(self, image, caption, alpha):
|
98 |
+
with torch.no_grad():
|
99 |
+
self.temp.clamp_(0.001,0.5)
|
100 |
+
|
101 |
+
image_embeds = self.visual_encoder(image)
|
102 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
103 |
+
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
|
104 |
+
|
105 |
+
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
|
106 |
+
return_tensors="pt").to(image.device)
|
107 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
108 |
+
return_dict = True, mode = 'text')
|
109 |
+
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
|
110 |
+
|
111 |
+
# get momentum features
|
112 |
+
with torch.no_grad():
|
113 |
+
self._momentum_update()
|
114 |
+
image_embeds_m = self.visual_encoder_m(image)
|
115 |
+
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
|
116 |
+
image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
|
117 |
+
|
118 |
+
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
|
119 |
+
return_dict = True, mode = 'text')
|
120 |
+
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
|
121 |
+
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
|
122 |
+
|
123 |
+
sim_i2t_m = image_feat_m @ text_feat_all / self.temp
|
124 |
+
sim_t2i_m = text_feat_m @ image_feat_all / self.temp
|
125 |
+
|
126 |
+
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
|
127 |
+
sim_targets.fill_diagonal_(1)
|
128 |
+
|
129 |
+
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
|
130 |
+
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
|
131 |
+
|
132 |
+
sim_i2t = image_feat @ text_feat_all / self.temp
|
133 |
+
sim_t2i = text_feat @ image_feat_all / self.temp
|
134 |
+
|
135 |
+
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
|
136 |
+
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
|
137 |
+
|
138 |
+
loss_ita = (loss_i2t+loss_t2i)/2
|
139 |
+
|
140 |
+
self._dequeue_and_enqueue(image_feat_m, text_feat_m)
|
141 |
+
|
142 |
+
###============== Image-text Matching ===================###
|
143 |
+
encoder_input_ids = text.input_ids.clone()
|
144 |
+
encoder_input_ids[:,0] = self.tokenizer.enc_token_id
|
145 |
+
|
146 |
+
# forward the positve image-text pair
|
147 |
+
bs = image.size(0)
|
148 |
+
output_pos = self.text_encoder(encoder_input_ids,
|
149 |
+
attention_mask = text.attention_mask,
|
150 |
+
encoder_hidden_states = image_embeds,
|
151 |
+
encoder_attention_mask = image_atts,
|
152 |
+
return_dict = True,
|
153 |
+
)
|
154 |
+
with torch.no_grad():
|
155 |
+
weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
|
156 |
+
weights_t2i.fill_diagonal_(0)
|
157 |
+
weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
|
158 |
+
weights_i2t.fill_diagonal_(0)
|
159 |
+
|
160 |
+
# select a negative image for each text
|
161 |
+
image_embeds_neg = []
|
162 |
+
for b in range(bs):
|
163 |
+
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
|
164 |
+
image_embeds_neg.append(image_embeds[neg_idx])
|
165 |
+
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
|
166 |
+
|
167 |
+
# select a negative text for each image
|
168 |
+
text_ids_neg = []
|
169 |
+
text_atts_neg = []
|
170 |
+
for b in range(bs):
|
171 |
+
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
|
172 |
+
text_ids_neg.append(encoder_input_ids[neg_idx])
|
173 |
+
text_atts_neg.append(text.attention_mask[neg_idx])
|
174 |
+
|
175 |
+
text_ids_neg = torch.stack(text_ids_neg,dim=0)
|
176 |
+
text_atts_neg = torch.stack(text_atts_neg,dim=0)
|
177 |
+
|
178 |
+
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
|
179 |
+
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
|
180 |
+
|
181 |
+
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
|
182 |
+
image_atts_all = torch.cat([image_atts,image_atts],dim=0)
|
183 |
+
|
184 |
+
output_neg = self.text_encoder(text_ids_all,
|
185 |
+
attention_mask = text_atts_all,
|
186 |
+
encoder_hidden_states = image_embeds_all,
|
187 |
+
encoder_attention_mask = image_atts_all,
|
188 |
+
return_dict = True,
|
189 |
+
)
|
190 |
+
|
191 |
+
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
|
192 |
+
vl_output = self.itm_head(vl_embeddings)
|
193 |
+
|
194 |
+
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
|
195 |
+
dim=0).to(image.device)
|
196 |
+
loss_itm = F.cross_entropy(vl_output, itm_labels)
|
197 |
+
|
198 |
+
##================= LM ========================##
|
199 |
+
decoder_input_ids = text.input_ids.clone()
|
200 |
+
decoder_input_ids[:,0] = self.tokenizer.bos_token_id
|
201 |
+
decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
|
202 |
+
|
203 |
+
decoder_output = self.text_decoder(decoder_input_ids,
|
204 |
+
attention_mask = text.attention_mask,
|
205 |
+
encoder_hidden_states = image_embeds,
|
206 |
+
encoder_attention_mask = image_atts,
|
207 |
+
labels = decoder_targets,
|
208 |
+
return_dict = True,
|
209 |
+
)
|
210 |
+
|
211 |
+
loss_lm = decoder_output.loss
|
212 |
+
return loss_ita, loss_itm, loss_lm
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
@torch.no_grad()
|
217 |
+
def copy_params(self):
|
218 |
+
for model_pair in self.model_pairs:
|
219 |
+
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
220 |
+
param_m.data.copy_(param.data) # initialize
|
221 |
+
param_m.requires_grad = False # not update by gradient
|
222 |
+
|
223 |
+
|
224 |
+
@torch.no_grad()
|
225 |
+
def _momentum_update(self):
|
226 |
+
for model_pair in self.model_pairs:
|
227 |
+
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
228 |
+
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
|
229 |
+
|
230 |
+
|
231 |
+
@torch.no_grad()
|
232 |
+
def _dequeue_and_enqueue(self, image_feat, text_feat):
|
233 |
+
# gather keys before updating queue
|
234 |
+
image_feats = concat_all_gather(image_feat)
|
235 |
+
text_feats = concat_all_gather(text_feat)
|
236 |
+
|
237 |
+
batch_size = image_feats.shape[0]
|
238 |
+
|
239 |
+
ptr = int(self.queue_ptr)
|
240 |
+
assert self.queue_size % batch_size == 0 # for simplicity
|
241 |
+
|
242 |
+
# replace the keys at ptr (dequeue and enqueue)
|
243 |
+
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
|
244 |
+
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
|
245 |
+
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
246 |
+
|
247 |
+
self.queue_ptr[0] = ptr
|
248 |
+
|
249 |
+
|
250 |
+
def blip_pretrain(**kwargs):
|
251 |
+
model = BLIP_Pretrain(**kwargs)
|
252 |
+
return model
|
253 |
+
|
254 |
+
|
255 |
+
@torch.no_grad()
|
256 |
+
def concat_all_gather(tensor):
|
257 |
+
"""
|
258 |
+
Performs all_gather operation on the provided tensors.
|
259 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
260 |
+
"""
|
261 |
+
tensors_gather = [torch.ones_like(tensor)
|
262 |
+
for _ in range(torch.distributed.get_world_size())]
|
263 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
264 |
+
|
265 |
+
output = torch.cat(tensors_gather, dim=0)
|
266 |
+
return output
|
267 |
+
|
268 |
+
|
269 |
+
from typing import List
|
270 |
+
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
|
271 |
+
uninitialized_encoder_weights: List[str] = []
|
272 |
+
if decoder.__class__ != encoder.__class__:
|
273 |
+
logger.info(
|
274 |
+
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
|
275 |
+
)
|
276 |
+
|
277 |
+
def tie_encoder_to_decoder_recursively(
|
278 |
+
decoder_pointer: nn.Module,
|
279 |
+
encoder_pointer: nn.Module,
|
280 |
+
module_name: str,
|
281 |
+
uninitialized_encoder_weights: List[str],
|
282 |
+
skip_key: str,
|
283 |
+
depth=0,
|
284 |
+
):
|
285 |
+
assert isinstance(decoder_pointer, nn.Module) and isinstance(
|
286 |
+
encoder_pointer, nn.Module
|
287 |
+
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
|
288 |
+
if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
|
289 |
+
assert hasattr(encoder_pointer, "weight")
|
290 |
+
encoder_pointer.weight = decoder_pointer.weight
|
291 |
+
if hasattr(decoder_pointer, "bias"):
|
292 |
+
assert hasattr(encoder_pointer, "bias")
|
293 |
+
encoder_pointer.bias = decoder_pointer.bias
|
294 |
+
print(module_name+' is tied')
|
295 |
+
return
|
296 |
+
|
297 |
+
encoder_modules = encoder_pointer._modules
|
298 |
+
decoder_modules = decoder_pointer._modules
|
299 |
+
if len(decoder_modules) > 0:
|
300 |
+
assert (
|
301 |
+
len(encoder_modules) > 0
|
302 |
+
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
303 |
+
|
304 |
+
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
|
305 |
+
encoder_layer_pos = 0
|
306 |
+
for name, module in decoder_modules.items():
|
307 |
+
if name.isdigit():
|
308 |
+
encoder_name = str(int(name) + encoder_layer_pos)
|
309 |
+
decoder_name = name
|
310 |
+
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
|
311 |
+
encoder_modules
|
312 |
+
) != len(decoder_modules):
|
313 |
+
# this can happen if the name corresponds to the position in a list module list of layers
|
314 |
+
# in this case the decoder has added a cross-attention that the encoder does not have
|
315 |
+
# thus skip this step and subtract one layer pos from encoder
|
316 |
+
encoder_layer_pos -= 1
|
317 |
+
continue
|
318 |
+
elif name not in encoder_modules:
|
319 |
+
continue
|
320 |
+
elif depth > 500:
|
321 |
+
raise ValueError(
|
322 |
+
"Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
decoder_name = encoder_name = name
|
326 |
+
tie_encoder_to_decoder_recursively(
|
327 |
+
decoder_modules[decoder_name],
|
328 |
+
encoder_modules[encoder_name],
|
329 |
+
module_name + "/" + name,
|
330 |
+
uninitialized_encoder_weights,
|
331 |
+
skip_key,
|
332 |
+
depth=depth + 1,
|
333 |
+
)
|
334 |
+
all_encoder_weights.remove(module_name + "/" + encoder_name)
|
335 |
+
|
336 |
+
uninitialized_encoder_weights += list(all_encoder_weights)
|
337 |
+
|
338 |
+
# tie weights recursively
|
339 |
+
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
|
models/blip_retrieval.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.med import BertConfig, BertModel
|
2 |
+
from transformers import BertTokenizer
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from models.blip import create_vit, init_tokenizer, load_checkpoint
|
9 |
+
|
10 |
+
class BLIP_Retrieval(nn.Module):
|
11 |
+
def __init__(self,
|
12 |
+
med_config = 'configs/med_config.json',
|
13 |
+
image_size = 384,
|
14 |
+
vit = 'base',
|
15 |
+
vit_grad_ckpt = False,
|
16 |
+
vit_ckpt_layer = 0,
|
17 |
+
embed_dim = 256,
|
18 |
+
queue_size = 57600,
|
19 |
+
momentum = 0.995,
|
20 |
+
negative_all_rank = False,
|
21 |
+
):
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
25 |
+
image_size (int): input image size
|
26 |
+
vit (str): model size of vision transformer
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
31 |
+
self.tokenizer = init_tokenizer()
|
32 |
+
med_config = BertConfig.from_json_file(med_config)
|
33 |
+
med_config.encoder_width = vision_width
|
34 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
35 |
+
|
36 |
+
text_width = self.text_encoder.config.hidden_size
|
37 |
+
|
38 |
+
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
39 |
+
self.text_proj = nn.Linear(text_width, embed_dim)
|
40 |
+
|
41 |
+
self.itm_head = nn.Linear(text_width, 2)
|
42 |
+
|
43 |
+
# create momentum encoders
|
44 |
+
self.visual_encoder_m, vision_width = create_vit(vit,image_size)
|
45 |
+
self.vision_proj_m = nn.Linear(vision_width, embed_dim)
|
46 |
+
self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
|
47 |
+
self.text_proj_m = nn.Linear(text_width, embed_dim)
|
48 |
+
|
49 |
+
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
|
50 |
+
[self.vision_proj,self.vision_proj_m],
|
51 |
+
[self.text_encoder,self.text_encoder_m],
|
52 |
+
[self.text_proj,self.text_proj_m],
|
53 |
+
]
|
54 |
+
self.copy_params()
|
55 |
+
|
56 |
+
# create the queue
|
57 |
+
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
|
58 |
+
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
|
59 |
+
self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
|
60 |
+
self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
|
61 |
+
|
62 |
+
self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
|
63 |
+
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
|
64 |
+
|
65 |
+
self.queue_size = queue_size
|
66 |
+
self.momentum = momentum
|
67 |
+
self.temp = nn.Parameter(0.07*torch.ones([]))
|
68 |
+
|
69 |
+
self.negative_all_rank = negative_all_rank
|
70 |
+
|
71 |
+
|
72 |
+
def forward(self, image, caption, alpha, idx):
|
73 |
+
with torch.no_grad():
|
74 |
+
self.temp.clamp_(0.001,0.5)
|
75 |
+
|
76 |
+
image_embeds = self.visual_encoder(image)
|
77 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
78 |
+
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
|
79 |
+
|
80 |
+
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
|
81 |
+
return_tensors="pt").to(image.device)
|
82 |
+
|
83 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
84 |
+
return_dict = True, mode = 'text')
|
85 |
+
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
|
86 |
+
|
87 |
+
###============== Image-text Contrastive Learning ===================###
|
88 |
+
idx = idx.view(-1,1)
|
89 |
+
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
|
90 |
+
pos_idx = torch.eq(idx, idx_all).float()
|
91 |
+
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
|
92 |
+
|
93 |
+
# get momentum features
|
94 |
+
with torch.no_grad():
|
95 |
+
self._momentum_update()
|
96 |
+
image_embeds_m = self.visual_encoder_m(image)
|
97 |
+
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
|
98 |
+
image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
|
99 |
+
|
100 |
+
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
|
101 |
+
return_dict = True, mode = 'text')
|
102 |
+
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
|
103 |
+
text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
|
104 |
+
|
105 |
+
sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
|
106 |
+
sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
|
107 |
+
|
108 |
+
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
|
109 |
+
sim_targets.fill_diagonal_(1)
|
110 |
+
|
111 |
+
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
|
112 |
+
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
|
113 |
+
|
114 |
+
sim_i2t = image_feat @ text_feat_m_all / self.temp
|
115 |
+
sim_t2i = text_feat @ image_feat_m_all / self.temp
|
116 |
+
|
117 |
+
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
|
118 |
+
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
|
119 |
+
|
120 |
+
loss_ita = (loss_i2t+loss_t2i)/2
|
121 |
+
|
122 |
+
idxs = concat_all_gather(idx)
|
123 |
+
self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
|
124 |
+
|
125 |
+
###============== Image-text Matching ===================###
|
126 |
+
encoder_input_ids = text.input_ids.clone()
|
127 |
+
encoder_input_ids[:,0] = self.tokenizer.enc_token_id
|
128 |
+
|
129 |
+
# forward the positve image-text pair
|
130 |
+
bs = image.size(0)
|
131 |
+
output_pos = self.text_encoder(encoder_input_ids,
|
132 |
+
attention_mask = text.attention_mask,
|
133 |
+
encoder_hidden_states = image_embeds,
|
134 |
+
encoder_attention_mask = image_atts,
|
135 |
+
return_dict = True,
|
136 |
+
)
|
137 |
+
|
138 |
+
|
139 |
+
if self.negative_all_rank:
|
140 |
+
# compute sample similarity
|
141 |
+
with torch.no_grad():
|
142 |
+
mask = torch.eq(idx, idxs.t())
|
143 |
+
|
144 |
+
image_feat_world = concat_all_gather(image_feat)
|
145 |
+
text_feat_world = concat_all_gather(text_feat)
|
146 |
+
|
147 |
+
sim_i2t = image_feat @ text_feat_world.t() / self.temp
|
148 |
+
sim_t2i = text_feat @ image_feat_world.t() / self.temp
|
149 |
+
|
150 |
+
weights_i2t = F.softmax(sim_i2t,dim=1)
|
151 |
+
weights_i2t.masked_fill_(mask, 0)
|
152 |
+
|
153 |
+
weights_t2i = F.softmax(sim_t2i,dim=1)
|
154 |
+
weights_t2i.masked_fill_(mask, 0)
|
155 |
+
|
156 |
+
image_embeds_world = all_gather_with_grad(image_embeds)
|
157 |
+
|
158 |
+
# select a negative image (from all ranks) for each text
|
159 |
+
image_embeds_neg = []
|
160 |
+
for b in range(bs):
|
161 |
+
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
|
162 |
+
image_embeds_neg.append(image_embeds_world[neg_idx])
|
163 |
+
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
|
164 |
+
|
165 |
+
# select a negative text (from all ranks) for each image
|
166 |
+
input_ids_world = concat_all_gather(encoder_input_ids)
|
167 |
+
att_mask_world = concat_all_gather(text.attention_mask)
|
168 |
+
|
169 |
+
text_ids_neg = []
|
170 |
+
text_atts_neg = []
|
171 |
+
for b in range(bs):
|
172 |
+
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
|
173 |
+
text_ids_neg.append(input_ids_world[neg_idx])
|
174 |
+
text_atts_neg.append(att_mask_world[neg_idx])
|
175 |
+
|
176 |
+
else:
|
177 |
+
with torch.no_grad():
|
178 |
+
mask = torch.eq(idx, idx.t())
|
179 |
+
|
180 |
+
sim_i2t = image_feat @ text_feat.t() / self.temp
|
181 |
+
sim_t2i = text_feat @ image_feat.t() / self.temp
|
182 |
+
|
183 |
+
weights_i2t = F.softmax(sim_i2t,dim=1)
|
184 |
+
weights_i2t.masked_fill_(mask, 0)
|
185 |
+
|
186 |
+
weights_t2i = F.softmax(sim_t2i,dim=1)
|
187 |
+
weights_t2i.masked_fill_(mask, 0)
|
188 |
+
|
189 |
+
# select a negative image (from same rank) for each text
|
190 |
+
image_embeds_neg = []
|
191 |
+
for b in range(bs):
|
192 |
+
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
|
193 |
+
image_embeds_neg.append(image_embeds[neg_idx])
|
194 |
+
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
|
195 |
+
|
196 |
+
# select a negative text (from same rank) for each image
|
197 |
+
text_ids_neg = []
|
198 |
+
text_atts_neg = []
|
199 |
+
for b in range(bs):
|
200 |
+
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
|
201 |
+
text_ids_neg.append(encoder_input_ids[neg_idx])
|
202 |
+
text_atts_neg.append(text.attention_mask[neg_idx])
|
203 |
+
|
204 |
+
text_ids_neg = torch.stack(text_ids_neg,dim=0)
|
205 |
+
text_atts_neg = torch.stack(text_atts_neg,dim=0)
|
206 |
+
|
207 |
+
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
|
208 |
+
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
|
209 |
+
|
210 |
+
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
|
211 |
+
image_atts_all = torch.cat([image_atts,image_atts],dim=0)
|
212 |
+
|
213 |
+
output_neg = self.text_encoder(text_ids_all,
|
214 |
+
attention_mask = text_atts_all,
|
215 |
+
encoder_hidden_states = image_embeds_all,
|
216 |
+
encoder_attention_mask = image_atts_all,
|
217 |
+
return_dict = True,
|
218 |
+
)
|
219 |
+
|
220 |
+
|
221 |
+
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
|
222 |
+
vl_output = self.itm_head(vl_embeddings)
|
223 |
+
|
224 |
+
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
|
225 |
+
dim=0).to(image.device)
|
226 |
+
loss_itm = F.cross_entropy(vl_output, itm_labels)
|
227 |
+
|
228 |
+
return loss_ita, loss_itm
|
229 |
+
|
230 |
+
|
231 |
+
@torch.no_grad()
|
232 |
+
def copy_params(self):
|
233 |
+
for model_pair in self.model_pairs:
|
234 |
+
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
235 |
+
param_m.data.copy_(param.data) # initialize
|
236 |
+
param_m.requires_grad = False # not update by gradient
|
237 |
+
|
238 |
+
|
239 |
+
@torch.no_grad()
|
240 |
+
def _momentum_update(self):
|
241 |
+
for model_pair in self.model_pairs:
|
242 |
+
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
243 |
+
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
|
244 |
+
|
245 |
+
|
246 |
+
@torch.no_grad()
|
247 |
+
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
|
248 |
+
# gather keys before updating queue
|
249 |
+
image_feats = concat_all_gather(image_feat)
|
250 |
+
text_feats = concat_all_gather(text_feat)
|
251 |
+
|
252 |
+
|
253 |
+
batch_size = image_feats.shape[0]
|
254 |
+
|
255 |
+
ptr = int(self.ptr_queue)
|
256 |
+
assert self.queue_size % batch_size == 0 # for simplicity
|
257 |
+
|
258 |
+
# replace the keys at ptr (dequeue and enqueue)
|
259 |
+
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
|
260 |
+
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
|
261 |
+
self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
|
262 |
+
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
263 |
+
|
264 |
+
self.ptr_queue[0] = ptr
|
265 |
+
|
266 |
+
|
267 |
+
def blip_retrieval(pretrained='',**kwargs):
|
268 |
+
model = BLIP_Retrieval(**kwargs)
|
269 |
+
if pretrained:
|
270 |
+
model,msg = load_checkpoint(model,pretrained)
|
271 |
+
print("missing keys:")
|
272 |
+
print(msg.missing_keys)
|
273 |
+
return model
|
274 |
+
|
275 |
+
|
276 |
+
@torch.no_grad()
|
277 |
+
def concat_all_gather(tensor):
|
278 |
+
"""
|
279 |
+
Performs all_gather operation on the provided tensors.
|
280 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
281 |
+
"""
|
282 |
+
tensors_gather = [torch.ones_like(tensor)
|
283 |
+
for _ in range(torch.distributed.get_world_size())]
|
284 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
285 |
+
|
286 |
+
output = torch.cat(tensors_gather, dim=0)
|
287 |
+
return output
|
288 |
+
|
289 |
+
|
290 |
+
class GatherLayer(torch.autograd.Function):
|
291 |
+
"""
|
292 |
+
Gather tensors from all workers with support for backward propagation:
|
293 |
+
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
294 |
+
"""
|
295 |
+
|
296 |
+
@staticmethod
|
297 |
+
def forward(ctx, x):
|
298 |
+
output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
|
299 |
+
torch.distributed.all_gather(output, x)
|
300 |
+
return tuple(output)
|
301 |
+
|
302 |
+
@staticmethod
|
303 |
+
def backward(ctx, *grads):
|
304 |
+
all_gradients = torch.stack(grads)
|
305 |
+
torch.distributed.all_reduce(all_gradients)
|
306 |
+
return all_gradients[torch.distributed.get_rank()]
|
307 |
+
|
308 |
+
|
309 |
+
def all_gather_with_grad(tensors):
|
310 |
+
"""
|
311 |
+
Performs all_gather operation on the provided tensors.
|
312 |
+
Graph remains connected for backward grad computation.
|
313 |
+
"""
|
314 |
+
# Queue the gathered tensors
|
315 |
+
world_size = torch.distributed.get_world_size()
|
316 |
+
# There is no need for reduction in the single-proc case
|
317 |
+
if world_size == 1:
|
318 |
+
return tensors
|
319 |
+
|
320 |
+
tensor_all = GatherLayer.apply(tensors)
|
321 |
+
|
322 |
+
return torch.cat(tensor_all, dim=0)
|
models/blip_vqa.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.med import BertConfig, BertModel, BertLMHeadModel
|
2 |
+
from models.blip import create_vit, init_tokenizer, load_checkpoint
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from transformers import BertTokenizer
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
class BLIP_VQA(nn.Module):
|
11 |
+
def __init__(self,
|
12 |
+
med_config = 'configs/med_config.json',
|
13 |
+
image_size = 480,
|
14 |
+
vit = 'base',
|
15 |
+
vit_grad_ckpt = False,
|
16 |
+
vit_ckpt_layer = 0,
|
17 |
+
):
|
18 |
+
"""
|
19 |
+
Args:
|
20 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
21 |
+
image_size (int): input image size
|
22 |
+
vit (str): model size of vision transformer
|
23 |
+
"""
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
|
27 |
+
self.tokenizer = init_tokenizer()
|
28 |
+
|
29 |
+
encoder_config = BertConfig.from_json_file(med_config)
|
30 |
+
encoder_config.encoder_width = vision_width
|
31 |
+
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
32 |
+
|
33 |
+
decoder_config = BertConfig.from_json_file(med_config)
|
34 |
+
self.text_decoder = BertLMHeadModel(config=decoder_config)
|
35 |
+
|
36 |
+
|
37 |
+
def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
|
38 |
+
|
39 |
+
image_embeds = self.visual_encoder(image)
|
40 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
41 |
+
|
42 |
+
question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
|
43 |
+
return_tensors="pt").to(image.device)
|
44 |
+
question.input_ids[:,0] = self.tokenizer.enc_token_id
|
45 |
+
|
46 |
+
if train:
|
47 |
+
'''
|
48 |
+
n: number of answers for each question
|
49 |
+
weights: weight for each answer
|
50 |
+
'''
|
51 |
+
answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
|
52 |
+
answer.input_ids[:,0] = self.tokenizer.bos_token_id
|
53 |
+
answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
|
54 |
+
|
55 |
+
question_output = self.text_encoder(question.input_ids,
|
56 |
+
attention_mask = question.attention_mask,
|
57 |
+
encoder_hidden_states = image_embeds,
|
58 |
+
encoder_attention_mask = image_atts,
|
59 |
+
return_dict = True)
|
60 |
+
|
61 |
+
question_states = []
|
62 |
+
question_atts = []
|
63 |
+
for b, n in enumerate(n):
|
64 |
+
question_states += [question_output.last_hidden_state[b]]*n
|
65 |
+
question_atts += [question.attention_mask[b]]*n
|
66 |
+
question_states = torch.stack(question_states,0)
|
67 |
+
question_atts = torch.stack(question_atts,0)
|
68 |
+
|
69 |
+
answer_output = self.text_decoder(answer.input_ids,
|
70 |
+
attention_mask = answer.attention_mask,
|
71 |
+
encoder_hidden_states = question_states,
|
72 |
+
encoder_attention_mask = question_atts,
|
73 |
+
labels = answer_targets,
|
74 |
+
return_dict = True,
|
75 |
+
reduction = 'none',
|
76 |
+
)
|
77 |
+
|
78 |
+
loss = weights * answer_output.loss
|
79 |
+
loss = loss.sum()/image.size(0)
|
80 |
+
|
81 |
+
return loss
|
82 |
+
|
83 |
+
|
84 |
+
else:
|
85 |
+
question_output = self.text_encoder(question.input_ids,
|
86 |
+
attention_mask = question.attention_mask,
|
87 |
+
encoder_hidden_states = image_embeds,
|
88 |
+
encoder_attention_mask = image_atts,
|
89 |
+
return_dict = True)
|
90 |
+
|
91 |
+
if inference=='generate':
|
92 |
+
num_beams = 3
|
93 |
+
question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
|
94 |
+
question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
|
95 |
+
model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
|
96 |
+
|
97 |
+
bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
|
98 |
+
|
99 |
+
outputs = self.text_decoder.generate(input_ids=bos_ids,
|
100 |
+
max_length=10,
|
101 |
+
min_length=1,
|
102 |
+
num_beams=num_beams,
|
103 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
104 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
105 |
+
**model_kwargs)
|
106 |
+
|
107 |
+
answers = []
|
108 |
+
for output in outputs:
|
109 |
+
answer = self.tokenizer.decode(output, skip_special_tokens=True)
|
110 |
+
answers.append(answer)
|
111 |
+
return answers
|
112 |
+
|
113 |
+
elif inference=='rank':
|
114 |
+
max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
|
115 |
+
answer.input_ids, answer.attention_mask, k_test)
|
116 |
+
return max_ids
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
|
121 |
+
|
122 |
+
num_ques = question_states.size(0)
|
123 |
+
start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
|
124 |
+
|
125 |
+
start_output = self.text_decoder(start_ids,
|
126 |
+
encoder_hidden_states = question_states,
|
127 |
+
encoder_attention_mask = question_atts,
|
128 |
+
return_dict = True,
|
129 |
+
reduction = 'none')
|
130 |
+
logits = start_output.logits[:,0,:] # first token's logit
|
131 |
+
|
132 |
+
# topk_probs: top-k probability
|
133 |
+
# topk_ids: [num_question, k]
|
134 |
+
answer_first_token = answer_ids[:,1]
|
135 |
+
prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
|
136 |
+
topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
|
137 |
+
|
138 |
+
# answer input: [num_question*k, answer_len]
|
139 |
+
input_ids = []
|
140 |
+
input_atts = []
|
141 |
+
for b, topk_id in enumerate(topk_ids):
|
142 |
+
input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
|
143 |
+
input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
|
144 |
+
input_ids = torch.cat(input_ids,dim=0)
|
145 |
+
input_atts = torch.cat(input_atts,dim=0)
|
146 |
+
|
147 |
+
targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
|
148 |
+
|
149 |
+
# repeat encoder's output for top-k answers
|
150 |
+
question_states = tile(question_states, 0, k)
|
151 |
+
question_atts = tile(question_atts, 0, k)
|
152 |
+
|
153 |
+
output = self.text_decoder(input_ids,
|
154 |
+
attention_mask = input_atts,
|
155 |
+
encoder_hidden_states = question_states,
|
156 |
+
encoder_attention_mask = question_atts,
|
157 |
+
labels = targets_ids,
|
158 |
+
return_dict = True,
|
159 |
+
reduction = 'none')
|
160 |
+
|
161 |
+
log_probs_sum = -output.loss
|
162 |
+
log_probs_sum = log_probs_sum.view(num_ques,k)
|
163 |
+
|
164 |
+
max_topk_ids = log_probs_sum.argmax(dim=1)
|
165 |
+
max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
|
166 |
+
|
167 |
+
return max_ids
|
168 |
+
|
169 |
+
|
170 |
+
def blip_vqa(pretrained='',**kwargs):
|
171 |
+
model = BLIP_VQA(**kwargs)
|
172 |
+
if pretrained:
|
173 |
+
model,msg = load_checkpoint(model,pretrained)
|
174 |
+
# assert(len(msg.missing_keys)==0)
|
175 |
+
return model
|
176 |
+
|
177 |
+
|
178 |
+
def tile(x, dim, n_tile):
|
179 |
+
init_dim = x.size(dim)
|
180 |
+
repeat_idx = [1] * x.dim()
|
181 |
+
repeat_idx[dim] = n_tile
|
182 |
+
x = x.repeat(*(repeat_idx))
|
183 |
+
order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
|
184 |
+
return torch.index_select(x, dim, order_index.to(x.device))
|
185 |
+
|
186 |
+
|
models/med.py
ADDED
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on huggingface code base
|
8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
9 |
+
'''
|
10 |
+
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, device, dtype, nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from transformers.activations import ACT2FN
|
25 |
+
from transformers.file_utils import (
|
26 |
+
ModelOutput,
|
27 |
+
)
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
+
CausalLMOutputWithCrossAttentions,
|
32 |
+
MaskedLMOutput,
|
33 |
+
MultipleChoiceModelOutput,
|
34 |
+
NextSentencePredictorOutput,
|
35 |
+
QuestionAnsweringModelOutput,
|
36 |
+
SequenceClassifierOutput,
|
37 |
+
TokenClassifierOutput,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import (
|
40 |
+
PreTrainedModel,
|
41 |
+
apply_chunking_to_forward,
|
42 |
+
find_pruneable_heads_and_indices,
|
43 |
+
prune_linear_layer,
|
44 |
+
)
|
45 |
+
from transformers.utils import logging
|
46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class BertEmbeddings(nn.Module):
|
53 |
+
"""Construct the embeddings from word and position embeddings."""
|
54 |
+
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
58 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
59 |
+
|
60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
61 |
+
# any TensorFlow checkpoint file
|
62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
64 |
+
|
65 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
66 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
67 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
68 |
+
|
69 |
+
self.config = config
|
70 |
+
|
71 |
+
def forward(
|
72 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
73 |
+
):
|
74 |
+
if input_ids is not None:
|
75 |
+
input_shape = input_ids.size()
|
76 |
+
else:
|
77 |
+
input_shape = inputs_embeds.size()[:-1]
|
78 |
+
|
79 |
+
seq_length = input_shape[1]
|
80 |
+
|
81 |
+
if position_ids is None:
|
82 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
83 |
+
|
84 |
+
if inputs_embeds is None:
|
85 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
86 |
+
|
87 |
+
embeddings = inputs_embeds
|
88 |
+
|
89 |
+
if self.position_embedding_type == "absolute":
|
90 |
+
position_embeddings = self.position_embeddings(position_ids)
|
91 |
+
embeddings += position_embeddings
|
92 |
+
embeddings = self.LayerNorm(embeddings)
|
93 |
+
embeddings = self.dropout(embeddings)
|
94 |
+
return embeddings
|
95 |
+
|
96 |
+
|
97 |
+
class BertSelfAttention(nn.Module):
|
98 |
+
def __init__(self, config, is_cross_attention):
|
99 |
+
super().__init__()
|
100 |
+
self.config = config
|
101 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
102 |
+
raise ValueError(
|
103 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
104 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
105 |
+
)
|
106 |
+
|
107 |
+
self.num_attention_heads = config.num_attention_heads
|
108 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
109 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
110 |
+
|
111 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
112 |
+
if is_cross_attention:
|
113 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
114 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
115 |
+
else:
|
116 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
117 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
118 |
+
|
119 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
120 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
121 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
122 |
+
self.max_position_embeddings = config.max_position_embeddings
|
123 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
124 |
+
self.save_attention = False
|
125 |
+
|
126 |
+
def save_attn_gradients(self, attn_gradients):
|
127 |
+
self.attn_gradients = attn_gradients
|
128 |
+
|
129 |
+
def get_attn_gradients(self):
|
130 |
+
return self.attn_gradients
|
131 |
+
|
132 |
+
def save_attention_map(self, attention_map):
|
133 |
+
self.attention_map = attention_map
|
134 |
+
|
135 |
+
def get_attention_map(self):
|
136 |
+
return self.attention_map
|
137 |
+
|
138 |
+
def transpose_for_scores(self, x):
|
139 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
140 |
+
x = x.view(*new_x_shape)
|
141 |
+
return x.permute(0, 2, 1, 3)
|
142 |
+
|
143 |
+
def forward(
|
144 |
+
self,
|
145 |
+
hidden_states,
|
146 |
+
attention_mask=None,
|
147 |
+
head_mask=None,
|
148 |
+
encoder_hidden_states=None,
|
149 |
+
encoder_attention_mask=None,
|
150 |
+
past_key_value=None,
|
151 |
+
output_attentions=False,
|
152 |
+
):
|
153 |
+
mixed_query_layer = self.query(hidden_states)
|
154 |
+
|
155 |
+
# If this is instantiated as a cross-attention module, the keys
|
156 |
+
# and values come from an encoder; the attention mask needs to be
|
157 |
+
# such that the encoder's padding tokens are not attended to.
|
158 |
+
is_cross_attention = encoder_hidden_states is not None
|
159 |
+
|
160 |
+
if is_cross_attention:
|
161 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
162 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
163 |
+
attention_mask = encoder_attention_mask
|
164 |
+
elif past_key_value is not None:
|
165 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
166 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
167 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
168 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
169 |
+
else:
|
170 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
171 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
172 |
+
|
173 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
174 |
+
|
175 |
+
past_key_value = (key_layer, value_layer)
|
176 |
+
|
177 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
178 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
179 |
+
|
180 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
181 |
+
seq_length = hidden_states.size()[1]
|
182 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
183 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
184 |
+
distance = position_ids_l - position_ids_r
|
185 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
186 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
187 |
+
|
188 |
+
if self.position_embedding_type == "relative_key":
|
189 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
190 |
+
attention_scores = attention_scores + relative_position_scores
|
191 |
+
elif self.position_embedding_type == "relative_key_query":
|
192 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
193 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
194 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
195 |
+
|
196 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
197 |
+
if attention_mask is not None:
|
198 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
199 |
+
attention_scores = attention_scores + attention_mask
|
200 |
+
|
201 |
+
# Normalize the attention scores to probabilities.
|
202 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
203 |
+
|
204 |
+
if is_cross_attention and self.save_attention:
|
205 |
+
self.save_attention_map(attention_probs)
|
206 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
207 |
+
|
208 |
+
# This is actually dropping out entire tokens to attend to, which might
|
209 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
210 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
211 |
+
|
212 |
+
# Mask heads if we want to
|
213 |
+
if head_mask is not None:
|
214 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
215 |
+
|
216 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
217 |
+
|
218 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
219 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
220 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
221 |
+
|
222 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
223 |
+
|
224 |
+
outputs = outputs + (past_key_value,)
|
225 |
+
return outputs
|
226 |
+
|
227 |
+
|
228 |
+
class BertSelfOutput(nn.Module):
|
229 |
+
def __init__(self, config):
|
230 |
+
super().__init__()
|
231 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
232 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
233 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
234 |
+
|
235 |
+
def forward(self, hidden_states, input_tensor):
|
236 |
+
hidden_states = self.dense(hidden_states)
|
237 |
+
hidden_states = self.dropout(hidden_states)
|
238 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
239 |
+
return hidden_states
|
240 |
+
|
241 |
+
|
242 |
+
class BertAttention(nn.Module):
|
243 |
+
def __init__(self, config, is_cross_attention=False):
|
244 |
+
super().__init__()
|
245 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
246 |
+
self.output = BertSelfOutput(config)
|
247 |
+
self.pruned_heads = set()
|
248 |
+
|
249 |
+
def prune_heads(self, heads):
|
250 |
+
if len(heads) == 0:
|
251 |
+
return
|
252 |
+
heads, index = find_pruneable_heads_and_indices(
|
253 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
254 |
+
)
|
255 |
+
|
256 |
+
# Prune linear layers
|
257 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
258 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
259 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
260 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
261 |
+
|
262 |
+
# Update hyper params and store pruned heads
|
263 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
264 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
265 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
266 |
+
|
267 |
+
def forward(
|
268 |
+
self,
|
269 |
+
hidden_states,
|
270 |
+
attention_mask=None,
|
271 |
+
head_mask=None,
|
272 |
+
encoder_hidden_states=None,
|
273 |
+
encoder_attention_mask=None,
|
274 |
+
past_key_value=None,
|
275 |
+
output_attentions=False,
|
276 |
+
):
|
277 |
+
self_outputs = self.self(
|
278 |
+
hidden_states,
|
279 |
+
attention_mask,
|
280 |
+
head_mask,
|
281 |
+
encoder_hidden_states,
|
282 |
+
encoder_attention_mask,
|
283 |
+
past_key_value,
|
284 |
+
output_attentions,
|
285 |
+
)
|
286 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
287 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
288 |
+
return outputs
|
289 |
+
|
290 |
+
|
291 |
+
class BertIntermediate(nn.Module):
|
292 |
+
def __init__(self, config):
|
293 |
+
super().__init__()
|
294 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
295 |
+
if isinstance(config.hidden_act, str):
|
296 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
297 |
+
else:
|
298 |
+
self.intermediate_act_fn = config.hidden_act
|
299 |
+
|
300 |
+
def forward(self, hidden_states):
|
301 |
+
hidden_states = self.dense(hidden_states)
|
302 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
303 |
+
return hidden_states
|
304 |
+
|
305 |
+
|
306 |
+
class BertOutput(nn.Module):
|
307 |
+
def __init__(self, config):
|
308 |
+
super().__init__()
|
309 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
310 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
311 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
312 |
+
|
313 |
+
def forward(self, hidden_states, input_tensor):
|
314 |
+
hidden_states = self.dense(hidden_states)
|
315 |
+
hidden_states = self.dropout(hidden_states)
|
316 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
317 |
+
return hidden_states
|
318 |
+
|
319 |
+
|
320 |
+
class BertLayer(nn.Module):
|
321 |
+
def __init__(self, config, layer_num):
|
322 |
+
super().__init__()
|
323 |
+
self.config = config
|
324 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
325 |
+
self.seq_len_dim = 1
|
326 |
+
self.attention = BertAttention(config)
|
327 |
+
self.layer_num = layer_num
|
328 |
+
if self.config.add_cross_attention:
|
329 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
330 |
+
self.intermediate = BertIntermediate(config)
|
331 |
+
self.output = BertOutput(config)
|
332 |
+
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
hidden_states,
|
336 |
+
attention_mask=None,
|
337 |
+
head_mask=None,
|
338 |
+
encoder_hidden_states=None,
|
339 |
+
encoder_attention_mask=None,
|
340 |
+
past_key_value=None,
|
341 |
+
output_attentions=False,
|
342 |
+
mode=None,
|
343 |
+
):
|
344 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
345 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
346 |
+
self_attention_outputs = self.attention(
|
347 |
+
hidden_states,
|
348 |
+
attention_mask,
|
349 |
+
head_mask,
|
350 |
+
output_attentions=output_attentions,
|
351 |
+
past_key_value=self_attn_past_key_value,
|
352 |
+
)
|
353 |
+
attention_output = self_attention_outputs[0]
|
354 |
+
|
355 |
+
outputs = self_attention_outputs[1:-1]
|
356 |
+
present_key_value = self_attention_outputs[-1]
|
357 |
+
|
358 |
+
if mode=='multimodal':
|
359 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
360 |
+
|
361 |
+
cross_attention_outputs = self.crossattention(
|
362 |
+
attention_output,
|
363 |
+
attention_mask,
|
364 |
+
head_mask,
|
365 |
+
encoder_hidden_states,
|
366 |
+
encoder_attention_mask,
|
367 |
+
output_attentions=output_attentions,
|
368 |
+
)
|
369 |
+
attention_output = cross_attention_outputs[0]
|
370 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
371 |
+
layer_output = apply_chunking_to_forward(
|
372 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
373 |
+
)
|
374 |
+
outputs = (layer_output,) + outputs
|
375 |
+
|
376 |
+
outputs = outputs + (present_key_value,)
|
377 |
+
|
378 |
+
return outputs
|
379 |
+
|
380 |
+
def feed_forward_chunk(self, attention_output):
|
381 |
+
intermediate_output = self.intermediate(attention_output)
|
382 |
+
layer_output = self.output(intermediate_output, attention_output)
|
383 |
+
return layer_output
|
384 |
+
|
385 |
+
|
386 |
+
class BertEncoder(nn.Module):
|
387 |
+
def __init__(self, config):
|
388 |
+
super().__init__()
|
389 |
+
self.config = config
|
390 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
391 |
+
self.gradient_checkpointing = False
|
392 |
+
|
393 |
+
def forward(
|
394 |
+
self,
|
395 |
+
hidden_states,
|
396 |
+
attention_mask=None,
|
397 |
+
head_mask=None,
|
398 |
+
encoder_hidden_states=None,
|
399 |
+
encoder_attention_mask=None,
|
400 |
+
past_key_values=None,
|
401 |
+
use_cache=None,
|
402 |
+
output_attentions=False,
|
403 |
+
output_hidden_states=False,
|
404 |
+
return_dict=True,
|
405 |
+
mode='multimodal',
|
406 |
+
):
|
407 |
+
all_hidden_states = () if output_hidden_states else None
|
408 |
+
all_self_attentions = () if output_attentions else None
|
409 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
410 |
+
|
411 |
+
next_decoder_cache = () if use_cache else None
|
412 |
+
|
413 |
+
for i in range(self.config.num_hidden_layers):
|
414 |
+
layer_module = self.layer[i]
|
415 |
+
if output_hidden_states:
|
416 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
417 |
+
|
418 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
419 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
420 |
+
|
421 |
+
if self.gradient_checkpointing and self.training:
|
422 |
+
|
423 |
+
if use_cache:
|
424 |
+
logger.warn(
|
425 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
426 |
+
)
|
427 |
+
use_cache = False
|
428 |
+
|
429 |
+
def create_custom_forward(module):
|
430 |
+
def custom_forward(*inputs):
|
431 |
+
return module(*inputs, past_key_value, output_attentions)
|
432 |
+
|
433 |
+
return custom_forward
|
434 |
+
|
435 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
436 |
+
create_custom_forward(layer_module),
|
437 |
+
hidden_states,
|
438 |
+
attention_mask,
|
439 |
+
layer_head_mask,
|
440 |
+
encoder_hidden_states,
|
441 |
+
encoder_attention_mask,
|
442 |
+
mode=mode,
|
443 |
+
)
|
444 |
+
else:
|
445 |
+
layer_outputs = layer_module(
|
446 |
+
hidden_states,
|
447 |
+
attention_mask,
|
448 |
+
layer_head_mask,
|
449 |
+
encoder_hidden_states,
|
450 |
+
encoder_attention_mask,
|
451 |
+
past_key_value,
|
452 |
+
output_attentions,
|
453 |
+
mode=mode,
|
454 |
+
)
|
455 |
+
|
456 |
+
hidden_states = layer_outputs[0]
|
457 |
+
if use_cache:
|
458 |
+
next_decoder_cache += (layer_outputs[-1],)
|
459 |
+
if output_attentions:
|
460 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
461 |
+
|
462 |
+
if output_hidden_states:
|
463 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
464 |
+
|
465 |
+
if not return_dict:
|
466 |
+
return tuple(
|
467 |
+
v
|
468 |
+
for v in [
|
469 |
+
hidden_states,
|
470 |
+
next_decoder_cache,
|
471 |
+
all_hidden_states,
|
472 |
+
all_self_attentions,
|
473 |
+
all_cross_attentions,
|
474 |
+
]
|
475 |
+
if v is not None
|
476 |
+
)
|
477 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
478 |
+
last_hidden_state=hidden_states,
|
479 |
+
past_key_values=next_decoder_cache,
|
480 |
+
hidden_states=all_hidden_states,
|
481 |
+
attentions=all_self_attentions,
|
482 |
+
cross_attentions=all_cross_attentions,
|
483 |
+
)
|
484 |
+
|
485 |
+
|
486 |
+
class BertPooler(nn.Module):
|
487 |
+
def __init__(self, config):
|
488 |
+
super().__init__()
|
489 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
490 |
+
self.activation = nn.Tanh()
|
491 |
+
|
492 |
+
def forward(self, hidden_states):
|
493 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
494 |
+
# to the first token.
|
495 |
+
first_token_tensor = hidden_states[:, 0]
|
496 |
+
pooled_output = self.dense(first_token_tensor)
|
497 |
+
pooled_output = self.activation(pooled_output)
|
498 |
+
return pooled_output
|
499 |
+
|
500 |
+
|
501 |
+
class BertPredictionHeadTransform(nn.Module):
|
502 |
+
def __init__(self, config):
|
503 |
+
super().__init__()
|
504 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
505 |
+
if isinstance(config.hidden_act, str):
|
506 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
507 |
+
else:
|
508 |
+
self.transform_act_fn = config.hidden_act
|
509 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
510 |
+
|
511 |
+
def forward(self, hidden_states):
|
512 |
+
hidden_states = self.dense(hidden_states)
|
513 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
514 |
+
hidden_states = self.LayerNorm(hidden_states)
|
515 |
+
return hidden_states
|
516 |
+
|
517 |
+
|
518 |
+
class BertLMPredictionHead(nn.Module):
|
519 |
+
def __init__(self, config):
|
520 |
+
super().__init__()
|
521 |
+
self.transform = BertPredictionHeadTransform(config)
|
522 |
+
|
523 |
+
# The output weights are the same as the input embeddings, but there is
|
524 |
+
# an output-only bias for each token.
|
525 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
526 |
+
|
527 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
528 |
+
|
529 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
530 |
+
self.decoder.bias = self.bias
|
531 |
+
|
532 |
+
def forward(self, hidden_states):
|
533 |
+
hidden_states = self.transform(hidden_states)
|
534 |
+
hidden_states = self.decoder(hidden_states)
|
535 |
+
return hidden_states
|
536 |
+
|
537 |
+
|
538 |
+
class BertOnlyMLMHead(nn.Module):
|
539 |
+
def __init__(self, config):
|
540 |
+
super().__init__()
|
541 |
+
self.predictions = BertLMPredictionHead(config)
|
542 |
+
|
543 |
+
def forward(self, sequence_output):
|
544 |
+
prediction_scores = self.predictions(sequence_output)
|
545 |
+
return prediction_scores
|
546 |
+
|
547 |
+
|
548 |
+
class BertPreTrainedModel(PreTrainedModel):
|
549 |
+
"""
|
550 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
551 |
+
models.
|
552 |
+
"""
|
553 |
+
|
554 |
+
config_class = BertConfig
|
555 |
+
base_model_prefix = "bert"
|
556 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
557 |
+
|
558 |
+
def _init_weights(self, module):
|
559 |
+
""" Initialize the weights """
|
560 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
561 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
562 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
563 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
564 |
+
elif isinstance(module, nn.LayerNorm):
|
565 |
+
module.bias.data.zero_()
|
566 |
+
module.weight.data.fill_(1.0)
|
567 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
568 |
+
module.bias.data.zero_()
|
569 |
+
|
570 |
+
|
571 |
+
class BertModel(BertPreTrainedModel):
|
572 |
+
"""
|
573 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
574 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
575 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
576 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
577 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
578 |
+
input to the forward pass.
|
579 |
+
"""
|
580 |
+
|
581 |
+
def __init__(self, config, add_pooling_layer=True):
|
582 |
+
super().__init__(config)
|
583 |
+
self.config = config
|
584 |
+
|
585 |
+
self.embeddings = BertEmbeddings(config)
|
586 |
+
|
587 |
+
self.encoder = BertEncoder(config)
|
588 |
+
|
589 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
590 |
+
|
591 |
+
self.init_weights()
|
592 |
+
|
593 |
+
|
594 |
+
def get_input_embeddings(self):
|
595 |
+
return self.embeddings.word_embeddings
|
596 |
+
|
597 |
+
def set_input_embeddings(self, value):
|
598 |
+
self.embeddings.word_embeddings = value
|
599 |
+
|
600 |
+
def _prune_heads(self, heads_to_prune):
|
601 |
+
"""
|
602 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
603 |
+
class PreTrainedModel
|
604 |
+
"""
|
605 |
+
for layer, heads in heads_to_prune.items():
|
606 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
607 |
+
|
608 |
+
|
609 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
610 |
+
"""
|
611 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
612 |
+
|
613 |
+
Arguments:
|
614 |
+
attention_mask (:obj:`torch.Tensor`):
|
615 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
616 |
+
input_shape (:obj:`Tuple[int]`):
|
617 |
+
The shape of the input to the model.
|
618 |
+
device: (:obj:`torch.device`):
|
619 |
+
The device of the input to the model.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
623 |
+
"""
|
624 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
625 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
626 |
+
if attention_mask.dim() == 3:
|
627 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
628 |
+
elif attention_mask.dim() == 2:
|
629 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
630 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
631 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
632 |
+
if is_decoder:
|
633 |
+
batch_size, seq_length = input_shape
|
634 |
+
|
635 |
+
seq_ids = torch.arange(seq_length, device=device)
|
636 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
637 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
638 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
639 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
640 |
+
|
641 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
642 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
643 |
+
causal_mask = torch.cat(
|
644 |
+
[
|
645 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
646 |
+
causal_mask,
|
647 |
+
],
|
648 |
+
axis=-1,
|
649 |
+
)
|
650 |
+
|
651 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
652 |
+
else:
|
653 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
654 |
+
else:
|
655 |
+
raise ValueError(
|
656 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
657 |
+
input_shape, attention_mask.shape
|
658 |
+
)
|
659 |
+
)
|
660 |
+
|
661 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
662 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
663 |
+
# positions we want to attend and -10000.0 for masked positions.
|
664 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
665 |
+
# effectively the same as removing these entirely.
|
666 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
667 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
668 |
+
return extended_attention_mask
|
669 |
+
|
670 |
+
def forward(
|
671 |
+
self,
|
672 |
+
input_ids=None,
|
673 |
+
attention_mask=None,
|
674 |
+
position_ids=None,
|
675 |
+
head_mask=None,
|
676 |
+
inputs_embeds=None,
|
677 |
+
encoder_embeds=None,
|
678 |
+
encoder_hidden_states=None,
|
679 |
+
encoder_attention_mask=None,
|
680 |
+
past_key_values=None,
|
681 |
+
use_cache=None,
|
682 |
+
output_attentions=None,
|
683 |
+
output_hidden_states=None,
|
684 |
+
return_dict=None,
|
685 |
+
is_decoder=False,
|
686 |
+
mode='multimodal',
|
687 |
+
):
|
688 |
+
r"""
|
689 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
690 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
691 |
+
the model is configured as a decoder.
|
692 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
693 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
694 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
695 |
+
- 1 for tokens that are **not masked**,
|
696 |
+
- 0 for tokens that are **masked**.
|
697 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
698 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
699 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
700 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
701 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
702 |
+
use_cache (:obj:`bool`, `optional`):
|
703 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
704 |
+
decoding (see :obj:`past_key_values`).
|
705 |
+
"""
|
706 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
707 |
+
output_hidden_states = (
|
708 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
709 |
+
)
|
710 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
711 |
+
|
712 |
+
if is_decoder:
|
713 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
714 |
+
else:
|
715 |
+
use_cache = False
|
716 |
+
|
717 |
+
if input_ids is not None and inputs_embeds is not None:
|
718 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
719 |
+
elif input_ids is not None:
|
720 |
+
input_shape = input_ids.size()
|
721 |
+
batch_size, seq_length = input_shape
|
722 |
+
device = input_ids.device
|
723 |
+
elif inputs_embeds is not None:
|
724 |
+
input_shape = inputs_embeds.size()[:-1]
|
725 |
+
batch_size, seq_length = input_shape
|
726 |
+
device = inputs_embeds.device
|
727 |
+
elif encoder_embeds is not None:
|
728 |
+
input_shape = encoder_embeds.size()[:-1]
|
729 |
+
batch_size, seq_length = input_shape
|
730 |
+
device = encoder_embeds.device
|
731 |
+
else:
|
732 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
733 |
+
|
734 |
+
# past_key_values_length
|
735 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
736 |
+
|
737 |
+
if attention_mask is None:
|
738 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
739 |
+
|
740 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
741 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
742 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
743 |
+
device, is_decoder)
|
744 |
+
|
745 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
746 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
747 |
+
if encoder_hidden_states is not None:
|
748 |
+
if type(encoder_hidden_states) == list:
|
749 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
750 |
+
else:
|
751 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
752 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
753 |
+
|
754 |
+
if type(encoder_attention_mask) == list:
|
755 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
756 |
+
elif encoder_attention_mask is None:
|
757 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
758 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
759 |
+
else:
|
760 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
761 |
+
else:
|
762 |
+
encoder_extended_attention_mask = None
|
763 |
+
|
764 |
+
# Prepare head mask if needed
|
765 |
+
# 1.0 in head_mask indicate we keep the head
|
766 |
+
# attention_probs has shape bsz x n_heads x N x N
|
767 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
768 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
769 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
770 |
+
|
771 |
+
if encoder_embeds is None:
|
772 |
+
embedding_output = self.embeddings(
|
773 |
+
input_ids=input_ids,
|
774 |
+
position_ids=position_ids,
|
775 |
+
inputs_embeds=inputs_embeds,
|
776 |
+
past_key_values_length=past_key_values_length,
|
777 |
+
)
|
778 |
+
else:
|
779 |
+
embedding_output = encoder_embeds
|
780 |
+
|
781 |
+
encoder_outputs = self.encoder(
|
782 |
+
embedding_output,
|
783 |
+
attention_mask=extended_attention_mask,
|
784 |
+
head_mask=head_mask,
|
785 |
+
encoder_hidden_states=encoder_hidden_states,
|
786 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
787 |
+
past_key_values=past_key_values,
|
788 |
+
use_cache=use_cache,
|
789 |
+
output_attentions=output_attentions,
|
790 |
+
output_hidden_states=output_hidden_states,
|
791 |
+
return_dict=return_dict,
|
792 |
+
mode=mode,
|
793 |
+
)
|
794 |
+
sequence_output = encoder_outputs[0]
|
795 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
796 |
+
|
797 |
+
if not return_dict:
|
798 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
799 |
+
|
800 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
801 |
+
last_hidden_state=sequence_output,
|
802 |
+
pooler_output=pooled_output,
|
803 |
+
past_key_values=encoder_outputs.past_key_values,
|
804 |
+
hidden_states=encoder_outputs.hidden_states,
|
805 |
+
attentions=encoder_outputs.attentions,
|
806 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
807 |
+
)
|
808 |
+
|
809 |
+
|
810 |
+
|
811 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
812 |
+
|
813 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
814 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
815 |
+
|
816 |
+
def __init__(self, config):
|
817 |
+
super().__init__(config)
|
818 |
+
|
819 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
820 |
+
self.cls = BertOnlyMLMHead(config)
|
821 |
+
|
822 |
+
self.init_weights()
|
823 |
+
|
824 |
+
def get_output_embeddings(self):
|
825 |
+
return self.cls.predictions.decoder
|
826 |
+
|
827 |
+
def set_output_embeddings(self, new_embeddings):
|
828 |
+
self.cls.predictions.decoder = new_embeddings
|
829 |
+
|
830 |
+
def forward(
|
831 |
+
self,
|
832 |
+
input_ids=None,
|
833 |
+
attention_mask=None,
|
834 |
+
position_ids=None,
|
835 |
+
head_mask=None,
|
836 |
+
inputs_embeds=None,
|
837 |
+
encoder_hidden_states=None,
|
838 |
+
encoder_attention_mask=None,
|
839 |
+
labels=None,
|
840 |
+
past_key_values=None,
|
841 |
+
use_cache=None,
|
842 |
+
output_attentions=None,
|
843 |
+
output_hidden_states=None,
|
844 |
+
return_dict=None,
|
845 |
+
return_logits=False,
|
846 |
+
is_decoder=True,
|
847 |
+
reduction='mean',
|
848 |
+
mode='multimodal',
|
849 |
+
):
|
850 |
+
r"""
|
851 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
852 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
853 |
+
the model is configured as a decoder.
|
854 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
855 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
856 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
857 |
+
- 1 for tokens that are **not masked**,
|
858 |
+
- 0 for tokens that are **masked**.
|
859 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
860 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
861 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
862 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
863 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
864 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
865 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
866 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
867 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
868 |
+
use_cache (:obj:`bool`, `optional`):
|
869 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
870 |
+
decoding (see :obj:`past_key_values`).
|
871 |
+
Returns:
|
872 |
+
Example::
|
873 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
874 |
+
>>> import torch
|
875 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
876 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
877 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
878 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
879 |
+
>>> outputs = model(**inputs)
|
880 |
+
>>> prediction_logits = outputs.logits
|
881 |
+
"""
|
882 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
883 |
+
if labels is not None:
|
884 |
+
use_cache = False
|
885 |
+
|
886 |
+
outputs = self.bert(
|
887 |
+
input_ids,
|
888 |
+
attention_mask=attention_mask,
|
889 |
+
position_ids=position_ids,
|
890 |
+
head_mask=head_mask,
|
891 |
+
inputs_embeds=inputs_embeds,
|
892 |
+
encoder_hidden_states=encoder_hidden_states,
|
893 |
+
encoder_attention_mask=encoder_attention_mask,
|
894 |
+
past_key_values=past_key_values,
|
895 |
+
use_cache=use_cache,
|
896 |
+
output_attentions=output_attentions,
|
897 |
+
output_hidden_states=output_hidden_states,
|
898 |
+
return_dict=return_dict,
|
899 |
+
is_decoder=is_decoder,
|
900 |
+
mode=mode,
|
901 |
+
)
|
902 |
+
|
903 |
+
sequence_output = outputs[0]
|
904 |
+
prediction_scores = self.cls(sequence_output)
|
905 |
+
|
906 |
+
if return_logits:
|
907 |
+
return prediction_scores[:, :-1, :].contiguous()
|
908 |
+
|
909 |
+
lm_loss = None
|
910 |
+
if labels is not None:
|
911 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
912 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
913 |
+
labels = labels[:, 1:].contiguous()
|
914 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
915 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
916 |
+
if reduction=='none':
|
917 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
918 |
+
|
919 |
+
if not return_dict:
|
920 |
+
output = (prediction_scores,) + outputs[2:]
|
921 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
922 |
+
|
923 |
+
return CausalLMOutputWithCrossAttentions(
|
924 |
+
loss=lm_loss,
|
925 |
+
logits=prediction_scores,
|
926 |
+
past_key_values=outputs.past_key_values,
|
927 |
+
hidden_states=outputs.hidden_states,
|
928 |
+
attentions=outputs.attentions,
|
929 |
+
cross_attentions=outputs.cross_attentions,
|
930 |
+
)
|
931 |
+
|
932 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
933 |
+
input_shape = input_ids.shape
|
934 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
935 |
+
if attention_mask is None:
|
936 |
+
attention_mask = input_ids.new_ones(input_shape)
|
937 |
+
|
938 |
+
# cut decoder_input_ids if past is used
|
939 |
+
if past is not None:
|
940 |
+
input_ids = input_ids[:, -1:]
|
941 |
+
|
942 |
+
return {
|
943 |
+
"input_ids": input_ids,
|
944 |
+
"attention_mask": attention_mask,
|
945 |
+
"past_key_values": past,
|
946 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
947 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
948 |
+
"is_decoder": True,
|
949 |
+
}
|
950 |
+
|
951 |
+
def _reorder_cache(self, past, beam_idx):
|
952 |
+
reordered_past = ()
|
953 |
+
for layer_past in past:
|
954 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
955 |
+
return reordered_past
|
models/nlvr_encoder.py
ADDED
@@ -0,0 +1,843 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import warnings
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Tuple
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import Tensor, device, dtype, nn
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import CrossEntropyLoss
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from transformers.activations import ACT2FN
|
15 |
+
from transformers.file_utils import (
|
16 |
+
ModelOutput,
|
17 |
+
)
|
18 |
+
from transformers.modeling_outputs import (
|
19 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
20 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
21 |
+
CausalLMOutputWithCrossAttentions,
|
22 |
+
MaskedLMOutput,
|
23 |
+
MultipleChoiceModelOutput,
|
24 |
+
NextSentencePredictorOutput,
|
25 |
+
QuestionAnsweringModelOutput,
|
26 |
+
SequenceClassifierOutput,
|
27 |
+
TokenClassifierOutput,
|
28 |
+
)
|
29 |
+
from transformers.modeling_utils import (
|
30 |
+
PreTrainedModel,
|
31 |
+
apply_chunking_to_forward,
|
32 |
+
find_pruneable_heads_and_indices,
|
33 |
+
prune_linear_layer,
|
34 |
+
)
|
35 |
+
from transformers.utils import logging
|
36 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__)
|
40 |
+
|
41 |
+
|
42 |
+
class BertEmbeddings(nn.Module):
|
43 |
+
"""Construct the embeddings from word and position embeddings."""
|
44 |
+
|
45 |
+
def __init__(self, config):
|
46 |
+
super().__init__()
|
47 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
48 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
49 |
+
|
50 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
51 |
+
# any TensorFlow checkpoint file
|
52 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
53 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
54 |
+
|
55 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
56 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
57 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
58 |
+
|
59 |
+
self.config = config
|
60 |
+
|
61 |
+
def forward(
|
62 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
63 |
+
):
|
64 |
+
if input_ids is not None:
|
65 |
+
input_shape = input_ids.size()
|
66 |
+
else:
|
67 |
+
input_shape = inputs_embeds.size()[:-1]
|
68 |
+
|
69 |
+
seq_length = input_shape[1]
|
70 |
+
|
71 |
+
if position_ids is None:
|
72 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
73 |
+
|
74 |
+
if inputs_embeds is None:
|
75 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
76 |
+
|
77 |
+
embeddings = inputs_embeds
|
78 |
+
|
79 |
+
if self.position_embedding_type == "absolute":
|
80 |
+
position_embeddings = self.position_embeddings(position_ids)
|
81 |
+
embeddings += position_embeddings
|
82 |
+
embeddings = self.LayerNorm(embeddings)
|
83 |
+
embeddings = self.dropout(embeddings)
|
84 |
+
return embeddings
|
85 |
+
|
86 |
+
|
87 |
+
class BertSelfAttention(nn.Module):
|
88 |
+
def __init__(self, config, is_cross_attention):
|
89 |
+
super().__init__()
|
90 |
+
self.config = config
|
91 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
92 |
+
raise ValueError(
|
93 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
94 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
95 |
+
)
|
96 |
+
|
97 |
+
self.num_attention_heads = config.num_attention_heads
|
98 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
99 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
100 |
+
|
101 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
102 |
+
if is_cross_attention:
|
103 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
104 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
105 |
+
else:
|
106 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
107 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
108 |
+
|
109 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
110 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
111 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
112 |
+
self.max_position_embeddings = config.max_position_embeddings
|
113 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
114 |
+
self.save_attention = False
|
115 |
+
|
116 |
+
def save_attn_gradients(self, attn_gradients):
|
117 |
+
self.attn_gradients = attn_gradients
|
118 |
+
|
119 |
+
def get_attn_gradients(self):
|
120 |
+
return self.attn_gradients
|
121 |
+
|
122 |
+
def save_attention_map(self, attention_map):
|
123 |
+
self.attention_map = attention_map
|
124 |
+
|
125 |
+
def get_attention_map(self):
|
126 |
+
return self.attention_map
|
127 |
+
|
128 |
+
def transpose_for_scores(self, x):
|
129 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
130 |
+
x = x.view(*new_x_shape)
|
131 |
+
return x.permute(0, 2, 1, 3)
|
132 |
+
|
133 |
+
def forward(
|
134 |
+
self,
|
135 |
+
hidden_states,
|
136 |
+
attention_mask=None,
|
137 |
+
head_mask=None,
|
138 |
+
encoder_hidden_states=None,
|
139 |
+
encoder_attention_mask=None,
|
140 |
+
past_key_value=None,
|
141 |
+
output_attentions=False,
|
142 |
+
):
|
143 |
+
mixed_query_layer = self.query(hidden_states)
|
144 |
+
|
145 |
+
# If this is instantiated as a cross-attention module, the keys
|
146 |
+
# and values come from an encoder; the attention mask needs to be
|
147 |
+
# such that the encoder's padding tokens are not attended to.
|
148 |
+
is_cross_attention = encoder_hidden_states is not None
|
149 |
+
|
150 |
+
if is_cross_attention:
|
151 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
152 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
153 |
+
attention_mask = encoder_attention_mask
|
154 |
+
elif past_key_value is not None:
|
155 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
156 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
157 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
158 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
159 |
+
else:
|
160 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
161 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
162 |
+
|
163 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
164 |
+
|
165 |
+
past_key_value = (key_layer, value_layer)
|
166 |
+
|
167 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
168 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
169 |
+
|
170 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
171 |
+
seq_length = hidden_states.size()[1]
|
172 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
173 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
174 |
+
distance = position_ids_l - position_ids_r
|
175 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
176 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
177 |
+
|
178 |
+
if self.position_embedding_type == "relative_key":
|
179 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
180 |
+
attention_scores = attention_scores + relative_position_scores
|
181 |
+
elif self.position_embedding_type == "relative_key_query":
|
182 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
183 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
184 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
185 |
+
|
186 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
187 |
+
if attention_mask is not None:
|
188 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
189 |
+
attention_scores = attention_scores + attention_mask
|
190 |
+
|
191 |
+
# Normalize the attention scores to probabilities.
|
192 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
193 |
+
|
194 |
+
if is_cross_attention and self.save_attention:
|
195 |
+
self.save_attention_map(attention_probs)
|
196 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
197 |
+
|
198 |
+
# This is actually dropping out entire tokens to attend to, which might
|
199 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
200 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
201 |
+
|
202 |
+
# Mask heads if we want to
|
203 |
+
if head_mask is not None:
|
204 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
205 |
+
|
206 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
207 |
+
|
208 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
209 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
210 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
211 |
+
|
212 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
213 |
+
|
214 |
+
outputs = outputs + (past_key_value,)
|
215 |
+
return outputs
|
216 |
+
|
217 |
+
|
218 |
+
class BertSelfOutput(nn.Module):
|
219 |
+
def __init__(self, config, twin=False, merge=False):
|
220 |
+
super().__init__()
|
221 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
222 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
223 |
+
if twin:
|
224 |
+
self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
|
225 |
+
self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
|
226 |
+
else:
|
227 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
228 |
+
if merge:
|
229 |
+
self.act = ACT2FN[config.hidden_act]
|
230 |
+
self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
|
231 |
+
self.merge = True
|
232 |
+
else:
|
233 |
+
self.merge = False
|
234 |
+
|
235 |
+
def forward(self, hidden_states, input_tensor):
|
236 |
+
if type(hidden_states) == list:
|
237 |
+
hidden_states0 = self.dense0(hidden_states[0])
|
238 |
+
hidden_states1 = self.dense1(hidden_states[1])
|
239 |
+
if self.merge:
|
240 |
+
#hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
|
241 |
+
hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
|
242 |
+
else:
|
243 |
+
hidden_states = (hidden_states0+hidden_states1)/2
|
244 |
+
else:
|
245 |
+
hidden_states = self.dense(hidden_states)
|
246 |
+
hidden_states = self.dropout(hidden_states)
|
247 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
248 |
+
return hidden_states
|
249 |
+
|
250 |
+
|
251 |
+
class BertAttention(nn.Module):
|
252 |
+
def __init__(self, config, is_cross_attention=False, layer_num=-1):
|
253 |
+
super().__init__()
|
254 |
+
if is_cross_attention:
|
255 |
+
self.self0 = BertSelfAttention(config, is_cross_attention)
|
256 |
+
self.self1 = BertSelfAttention(config, is_cross_attention)
|
257 |
+
else:
|
258 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
259 |
+
self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
|
260 |
+
self.pruned_heads = set()
|
261 |
+
|
262 |
+
def prune_heads(self, heads):
|
263 |
+
if len(heads) == 0:
|
264 |
+
return
|
265 |
+
heads, index = find_pruneable_heads_and_indices(
|
266 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
267 |
+
)
|
268 |
+
|
269 |
+
# Prune linear layers
|
270 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
271 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
272 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
273 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
274 |
+
|
275 |
+
# Update hyper params and store pruned heads
|
276 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
277 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
278 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
279 |
+
|
280 |
+
def forward(
|
281 |
+
self,
|
282 |
+
hidden_states,
|
283 |
+
attention_mask=None,
|
284 |
+
head_mask=None,
|
285 |
+
encoder_hidden_states=None,
|
286 |
+
encoder_attention_mask=None,
|
287 |
+
past_key_value=None,
|
288 |
+
output_attentions=False,
|
289 |
+
):
|
290 |
+
if type(encoder_hidden_states)==list:
|
291 |
+
self_outputs0 = self.self0(
|
292 |
+
hidden_states,
|
293 |
+
attention_mask,
|
294 |
+
head_mask,
|
295 |
+
encoder_hidden_states[0],
|
296 |
+
encoder_attention_mask[0],
|
297 |
+
past_key_value,
|
298 |
+
output_attentions,
|
299 |
+
)
|
300 |
+
self_outputs1 = self.self1(
|
301 |
+
hidden_states,
|
302 |
+
attention_mask,
|
303 |
+
head_mask,
|
304 |
+
encoder_hidden_states[1],
|
305 |
+
encoder_attention_mask[1],
|
306 |
+
past_key_value,
|
307 |
+
output_attentions,
|
308 |
+
)
|
309 |
+
attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
|
310 |
+
|
311 |
+
outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them
|
312 |
+
else:
|
313 |
+
self_outputs = self.self(
|
314 |
+
hidden_states,
|
315 |
+
attention_mask,
|
316 |
+
head_mask,
|
317 |
+
encoder_hidden_states,
|
318 |
+
encoder_attention_mask,
|
319 |
+
past_key_value,
|
320 |
+
output_attentions,
|
321 |
+
)
|
322 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
323 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
324 |
+
return outputs
|
325 |
+
|
326 |
+
|
327 |
+
class BertIntermediate(nn.Module):
|
328 |
+
def __init__(self, config):
|
329 |
+
super().__init__()
|
330 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
331 |
+
if isinstance(config.hidden_act, str):
|
332 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
333 |
+
else:
|
334 |
+
self.intermediate_act_fn = config.hidden_act
|
335 |
+
|
336 |
+
def forward(self, hidden_states):
|
337 |
+
hidden_states = self.dense(hidden_states)
|
338 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
339 |
+
return hidden_states
|
340 |
+
|
341 |
+
|
342 |
+
class BertOutput(nn.Module):
|
343 |
+
def __init__(self, config):
|
344 |
+
super().__init__()
|
345 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
346 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
347 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
348 |
+
|
349 |
+
def forward(self, hidden_states, input_tensor):
|
350 |
+
hidden_states = self.dense(hidden_states)
|
351 |
+
hidden_states = self.dropout(hidden_states)
|
352 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
353 |
+
return hidden_states
|
354 |
+
|
355 |
+
|
356 |
+
class BertLayer(nn.Module):
|
357 |
+
def __init__(self, config, layer_num):
|
358 |
+
super().__init__()
|
359 |
+
self.config = config
|
360 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
361 |
+
self.seq_len_dim = 1
|
362 |
+
self.attention = BertAttention(config)
|
363 |
+
self.layer_num = layer_num
|
364 |
+
if self.config.add_cross_attention:
|
365 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
|
366 |
+
self.intermediate = BertIntermediate(config)
|
367 |
+
self.output = BertOutput(config)
|
368 |
+
|
369 |
+
def forward(
|
370 |
+
self,
|
371 |
+
hidden_states,
|
372 |
+
attention_mask=None,
|
373 |
+
head_mask=None,
|
374 |
+
encoder_hidden_states=None,
|
375 |
+
encoder_attention_mask=None,
|
376 |
+
past_key_value=None,
|
377 |
+
output_attentions=False,
|
378 |
+
mode=None,
|
379 |
+
):
|
380 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
381 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
382 |
+
self_attention_outputs = self.attention(
|
383 |
+
hidden_states,
|
384 |
+
attention_mask,
|
385 |
+
head_mask,
|
386 |
+
output_attentions=output_attentions,
|
387 |
+
past_key_value=self_attn_past_key_value,
|
388 |
+
)
|
389 |
+
attention_output = self_attention_outputs[0]
|
390 |
+
|
391 |
+
outputs = self_attention_outputs[1:-1]
|
392 |
+
present_key_value = self_attention_outputs[-1]
|
393 |
+
|
394 |
+
if mode=='multimodal':
|
395 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
396 |
+
cross_attention_outputs = self.crossattention(
|
397 |
+
attention_output,
|
398 |
+
attention_mask,
|
399 |
+
head_mask,
|
400 |
+
encoder_hidden_states,
|
401 |
+
encoder_attention_mask,
|
402 |
+
output_attentions=output_attentions,
|
403 |
+
)
|
404 |
+
attention_output = cross_attention_outputs[0]
|
405 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
406 |
+
layer_output = apply_chunking_to_forward(
|
407 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
408 |
+
)
|
409 |
+
outputs = (layer_output,) + outputs
|
410 |
+
|
411 |
+
outputs = outputs + (present_key_value,)
|
412 |
+
|
413 |
+
return outputs
|
414 |
+
|
415 |
+
def feed_forward_chunk(self, attention_output):
|
416 |
+
intermediate_output = self.intermediate(attention_output)
|
417 |
+
layer_output = self.output(intermediate_output, attention_output)
|
418 |
+
return layer_output
|
419 |
+
|
420 |
+
|
421 |
+
class BertEncoder(nn.Module):
|
422 |
+
def __init__(self, config):
|
423 |
+
super().__init__()
|
424 |
+
self.config = config
|
425 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
426 |
+
self.gradient_checkpointing = False
|
427 |
+
|
428 |
+
def forward(
|
429 |
+
self,
|
430 |
+
hidden_states,
|
431 |
+
attention_mask=None,
|
432 |
+
head_mask=None,
|
433 |
+
encoder_hidden_states=None,
|
434 |
+
encoder_attention_mask=None,
|
435 |
+
past_key_values=None,
|
436 |
+
use_cache=None,
|
437 |
+
output_attentions=False,
|
438 |
+
output_hidden_states=False,
|
439 |
+
return_dict=True,
|
440 |
+
mode='multimodal',
|
441 |
+
):
|
442 |
+
all_hidden_states = () if output_hidden_states else None
|
443 |
+
all_self_attentions = () if output_attentions else None
|
444 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
445 |
+
|
446 |
+
next_decoder_cache = () if use_cache else None
|
447 |
+
|
448 |
+
for i in range(self.config.num_hidden_layers):
|
449 |
+
layer_module = self.layer[i]
|
450 |
+
if output_hidden_states:
|
451 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
452 |
+
|
453 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
454 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
455 |
+
|
456 |
+
if self.gradient_checkpointing and self.training:
|
457 |
+
|
458 |
+
if use_cache:
|
459 |
+
logger.warn(
|
460 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
461 |
+
)
|
462 |
+
use_cache = False
|
463 |
+
|
464 |
+
def create_custom_forward(module):
|
465 |
+
def custom_forward(*inputs):
|
466 |
+
return module(*inputs, past_key_value, output_attentions)
|
467 |
+
|
468 |
+
return custom_forward
|
469 |
+
|
470 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
471 |
+
create_custom_forward(layer_module),
|
472 |
+
hidden_states,
|
473 |
+
attention_mask,
|
474 |
+
layer_head_mask,
|
475 |
+
encoder_hidden_states,
|
476 |
+
encoder_attention_mask,
|
477 |
+
mode=mode,
|
478 |
+
)
|
479 |
+
else:
|
480 |
+
layer_outputs = layer_module(
|
481 |
+
hidden_states,
|
482 |
+
attention_mask,
|
483 |
+
layer_head_mask,
|
484 |
+
encoder_hidden_states,
|
485 |
+
encoder_attention_mask,
|
486 |
+
past_key_value,
|
487 |
+
output_attentions,
|
488 |
+
mode=mode,
|
489 |
+
)
|
490 |
+
|
491 |
+
hidden_states = layer_outputs[0]
|
492 |
+
if use_cache:
|
493 |
+
next_decoder_cache += (layer_outputs[-1],)
|
494 |
+
if output_attentions:
|
495 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
496 |
+
|
497 |
+
if output_hidden_states:
|
498 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
499 |
+
|
500 |
+
if not return_dict:
|
501 |
+
return tuple(
|
502 |
+
v
|
503 |
+
for v in [
|
504 |
+
hidden_states,
|
505 |
+
next_decoder_cache,
|
506 |
+
all_hidden_states,
|
507 |
+
all_self_attentions,
|
508 |
+
all_cross_attentions,
|
509 |
+
]
|
510 |
+
if v is not None
|
511 |
+
)
|
512 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
513 |
+
last_hidden_state=hidden_states,
|
514 |
+
past_key_values=next_decoder_cache,
|
515 |
+
hidden_states=all_hidden_states,
|
516 |
+
attentions=all_self_attentions,
|
517 |
+
cross_attentions=all_cross_attentions,
|
518 |
+
)
|
519 |
+
|
520 |
+
|
521 |
+
class BertPooler(nn.Module):
|
522 |
+
def __init__(self, config):
|
523 |
+
super().__init__()
|
524 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
525 |
+
self.activation = nn.Tanh()
|
526 |
+
|
527 |
+
def forward(self, hidden_states):
|
528 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
529 |
+
# to the first token.
|
530 |
+
first_token_tensor = hidden_states[:, 0]
|
531 |
+
pooled_output = self.dense(first_token_tensor)
|
532 |
+
pooled_output = self.activation(pooled_output)
|
533 |
+
return pooled_output
|
534 |
+
|
535 |
+
|
536 |
+
class BertPredictionHeadTransform(nn.Module):
|
537 |
+
def __init__(self, config):
|
538 |
+
super().__init__()
|
539 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
540 |
+
if isinstance(config.hidden_act, str):
|
541 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
542 |
+
else:
|
543 |
+
self.transform_act_fn = config.hidden_act
|
544 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
545 |
+
|
546 |
+
def forward(self, hidden_states):
|
547 |
+
hidden_states = self.dense(hidden_states)
|
548 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
549 |
+
hidden_states = self.LayerNorm(hidden_states)
|
550 |
+
return hidden_states
|
551 |
+
|
552 |
+
|
553 |
+
class BertLMPredictionHead(nn.Module):
|
554 |
+
def __init__(self, config):
|
555 |
+
super().__init__()
|
556 |
+
self.transform = BertPredictionHeadTransform(config)
|
557 |
+
|
558 |
+
# The output weights are the same as the input embeddings, but there is
|
559 |
+
# an output-only bias for each token.
|
560 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
561 |
+
|
562 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
563 |
+
|
564 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
565 |
+
self.decoder.bias = self.bias
|
566 |
+
|
567 |
+
def forward(self, hidden_states):
|
568 |
+
hidden_states = self.transform(hidden_states)
|
569 |
+
hidden_states = self.decoder(hidden_states)
|
570 |
+
return hidden_states
|
571 |
+
|
572 |
+
|
573 |
+
class BertOnlyMLMHead(nn.Module):
|
574 |
+
def __init__(self, config):
|
575 |
+
super().__init__()
|
576 |
+
self.predictions = BertLMPredictionHead(config)
|
577 |
+
|
578 |
+
def forward(self, sequence_output):
|
579 |
+
prediction_scores = self.predictions(sequence_output)
|
580 |
+
return prediction_scores
|
581 |
+
|
582 |
+
|
583 |
+
class BertPreTrainedModel(PreTrainedModel):
|
584 |
+
"""
|
585 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
586 |
+
models.
|
587 |
+
"""
|
588 |
+
|
589 |
+
config_class = BertConfig
|
590 |
+
base_model_prefix = "bert"
|
591 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
592 |
+
|
593 |
+
def _init_weights(self, module):
|
594 |
+
""" Initialize the weights """
|
595 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
596 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
597 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
598 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
599 |
+
elif isinstance(module, nn.LayerNorm):
|
600 |
+
module.bias.data.zero_()
|
601 |
+
module.weight.data.fill_(1.0)
|
602 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
603 |
+
module.bias.data.zero_()
|
604 |
+
|
605 |
+
|
606 |
+
class BertModel(BertPreTrainedModel):
|
607 |
+
"""
|
608 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
609 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
610 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
611 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
612 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
613 |
+
input to the forward pass.
|
614 |
+
"""
|
615 |
+
|
616 |
+
def __init__(self, config, add_pooling_layer=True):
|
617 |
+
super().__init__(config)
|
618 |
+
self.config = config
|
619 |
+
|
620 |
+
self.embeddings = BertEmbeddings(config)
|
621 |
+
|
622 |
+
self.encoder = BertEncoder(config)
|
623 |
+
|
624 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
625 |
+
|
626 |
+
self.init_weights()
|
627 |
+
|
628 |
+
|
629 |
+
def get_input_embeddings(self):
|
630 |
+
return self.embeddings.word_embeddings
|
631 |
+
|
632 |
+
def set_input_embeddings(self, value):
|
633 |
+
self.embeddings.word_embeddings = value
|
634 |
+
|
635 |
+
def _prune_heads(self, heads_to_prune):
|
636 |
+
"""
|
637 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
638 |
+
class PreTrainedModel
|
639 |
+
"""
|
640 |
+
for layer, heads in heads_to_prune.items():
|
641 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
642 |
+
|
643 |
+
|
644 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
645 |
+
"""
|
646 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
647 |
+
|
648 |
+
Arguments:
|
649 |
+
attention_mask (:obj:`torch.Tensor`):
|
650 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
651 |
+
input_shape (:obj:`Tuple[int]`):
|
652 |
+
The shape of the input to the model.
|
653 |
+
device: (:obj:`torch.device`):
|
654 |
+
The device of the input to the model.
|
655 |
+
|
656 |
+
Returns:
|
657 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
658 |
+
"""
|
659 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
660 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
661 |
+
if attention_mask.dim() == 3:
|
662 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
663 |
+
elif attention_mask.dim() == 2:
|
664 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
665 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
666 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
667 |
+
if is_decoder:
|
668 |
+
batch_size, seq_length = input_shape
|
669 |
+
|
670 |
+
seq_ids = torch.arange(seq_length, device=device)
|
671 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
672 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
673 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
674 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
675 |
+
|
676 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
677 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
678 |
+
causal_mask = torch.cat(
|
679 |
+
[
|
680 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
681 |
+
causal_mask,
|
682 |
+
],
|
683 |
+
axis=-1,
|
684 |
+
)
|
685 |
+
|
686 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
687 |
+
else:
|
688 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
689 |
+
else:
|
690 |
+
raise ValueError(
|
691 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
692 |
+
input_shape, attention_mask.shape
|
693 |
+
)
|
694 |
+
)
|
695 |
+
|
696 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
697 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
698 |
+
# positions we want to attend and -10000.0 for masked positions.
|
699 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
700 |
+
# effectively the same as removing these entirely.
|
701 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
702 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
703 |
+
return extended_attention_mask
|
704 |
+
|
705 |
+
def forward(
|
706 |
+
self,
|
707 |
+
input_ids=None,
|
708 |
+
attention_mask=None,
|
709 |
+
position_ids=None,
|
710 |
+
head_mask=None,
|
711 |
+
inputs_embeds=None,
|
712 |
+
encoder_embeds=None,
|
713 |
+
encoder_hidden_states=None,
|
714 |
+
encoder_attention_mask=None,
|
715 |
+
past_key_values=None,
|
716 |
+
use_cache=None,
|
717 |
+
output_attentions=None,
|
718 |
+
output_hidden_states=None,
|
719 |
+
return_dict=None,
|
720 |
+
is_decoder=False,
|
721 |
+
mode='multimodal',
|
722 |
+
):
|
723 |
+
r"""
|
724 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
725 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
726 |
+
the model is configured as a decoder.
|
727 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
728 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
729 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
730 |
+
- 1 for tokens that are **not masked**,
|
731 |
+
- 0 for tokens that are **masked**.
|
732 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
733 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
734 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
735 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
736 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
737 |
+
use_cache (:obj:`bool`, `optional`):
|
738 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
739 |
+
decoding (see :obj:`past_key_values`).
|
740 |
+
"""
|
741 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
742 |
+
output_hidden_states = (
|
743 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
744 |
+
)
|
745 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
746 |
+
|
747 |
+
if is_decoder:
|
748 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
749 |
+
else:
|
750 |
+
use_cache = False
|
751 |
+
|
752 |
+
if input_ids is not None and inputs_embeds is not None:
|
753 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
754 |
+
elif input_ids is not None:
|
755 |
+
input_shape = input_ids.size()
|
756 |
+
batch_size, seq_length = input_shape
|
757 |
+
device = input_ids.device
|
758 |
+
elif inputs_embeds is not None:
|
759 |
+
input_shape = inputs_embeds.size()[:-1]
|
760 |
+
batch_size, seq_length = input_shape
|
761 |
+
device = inputs_embeds.device
|
762 |
+
elif encoder_embeds is not None:
|
763 |
+
input_shape = encoder_embeds.size()[:-1]
|
764 |
+
batch_size, seq_length = input_shape
|
765 |
+
device = encoder_embeds.device
|
766 |
+
else:
|
767 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
768 |
+
|
769 |
+
# past_key_values_length
|
770 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
771 |
+
|
772 |
+
if attention_mask is None:
|
773 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
774 |
+
|
775 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
776 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
777 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
778 |
+
device, is_decoder)
|
779 |
+
|
780 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
781 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
782 |
+
if encoder_hidden_states is not None:
|
783 |
+
if type(encoder_hidden_states) == list:
|
784 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
785 |
+
else:
|
786 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
787 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
788 |
+
|
789 |
+
if type(encoder_attention_mask) == list:
|
790 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
791 |
+
elif encoder_attention_mask is None:
|
792 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
793 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
794 |
+
else:
|
795 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
796 |
+
else:
|
797 |
+
encoder_extended_attention_mask = None
|
798 |
+
|
799 |
+
# Prepare head mask if needed
|
800 |
+
# 1.0 in head_mask indicate we keep the head
|
801 |
+
# attention_probs has shape bsz x n_heads x N x N
|
802 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
803 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
804 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
805 |
+
|
806 |
+
if encoder_embeds is None:
|
807 |
+
embedding_output = self.embeddings(
|
808 |
+
input_ids=input_ids,
|
809 |
+
position_ids=position_ids,
|
810 |
+
inputs_embeds=inputs_embeds,
|
811 |
+
past_key_values_length=past_key_values_length,
|
812 |
+
)
|
813 |
+
else:
|
814 |
+
embedding_output = encoder_embeds
|
815 |
+
|
816 |
+
encoder_outputs = self.encoder(
|
817 |
+
embedding_output,
|
818 |
+
attention_mask=extended_attention_mask,
|
819 |
+
head_mask=head_mask,
|
820 |
+
encoder_hidden_states=encoder_hidden_states,
|
821 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
822 |
+
past_key_values=past_key_values,
|
823 |
+
use_cache=use_cache,
|
824 |
+
output_attentions=output_attentions,
|
825 |
+
output_hidden_states=output_hidden_states,
|
826 |
+
return_dict=return_dict,
|
827 |
+
mode=mode,
|
828 |
+
)
|
829 |
+
sequence_output = encoder_outputs[0]
|
830 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
831 |
+
|
832 |
+
if not return_dict:
|
833 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
834 |
+
|
835 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
836 |
+
last_hidden_state=sequence_output,
|
837 |
+
pooler_output=pooled_output,
|
838 |
+
past_key_values=encoder_outputs.past_key_values,
|
839 |
+
hidden_states=encoder_outputs.hidden_states,
|
840 |
+
attentions=encoder_outputs.attentions,
|
841 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
842 |
+
)
|
843 |
+
|
models/vit.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on timm code base
|
8 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from functools import partial
|
15 |
+
|
16 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
17 |
+
from timm.models.registry import register_model
|
18 |
+
from timm.models.layers import trunc_normal_, DropPath
|
19 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
20 |
+
|
21 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
22 |
+
|
23 |
+
class Mlp(nn.Module):
|
24 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
25 |
+
"""
|
26 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
27 |
+
super().__init__()
|
28 |
+
out_features = out_features or in_features
|
29 |
+
hidden_features = hidden_features or in_features
|
30 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
31 |
+
self.act = act_layer()
|
32 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
33 |
+
self.drop = nn.Dropout(drop)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
x = self.fc1(x)
|
37 |
+
x = self.act(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class Attention(nn.Module):
|
45 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
50 |
+
self.scale = qk_scale or head_dim ** -0.5
|
51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
53 |
+
self.proj = nn.Linear(dim, dim)
|
54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
55 |
+
self.attn_gradients = None
|
56 |
+
self.attention_map = None
|
57 |
+
|
58 |
+
def save_attn_gradients(self, attn_gradients):
|
59 |
+
self.attn_gradients = attn_gradients
|
60 |
+
|
61 |
+
def get_attn_gradients(self):
|
62 |
+
return self.attn_gradients
|
63 |
+
|
64 |
+
def save_attention_map(self, attention_map):
|
65 |
+
self.attention_map = attention_map
|
66 |
+
|
67 |
+
def get_attention_map(self):
|
68 |
+
return self.attention_map
|
69 |
+
|
70 |
+
def forward(self, x, register_hook=False):
|
71 |
+
B, N, C = x.shape
|
72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
73 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
74 |
+
|
75 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
76 |
+
attn = attn.softmax(dim=-1)
|
77 |
+
attn = self.attn_drop(attn)
|
78 |
+
|
79 |
+
if register_hook:
|
80 |
+
self.save_attention_map(attn)
|
81 |
+
attn.register_hook(self.save_attn_gradients)
|
82 |
+
|
83 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
84 |
+
x = self.proj(x)
|
85 |
+
x = self.proj_drop(x)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class Block(nn.Module):
|
90 |
+
|
91 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
92 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
93 |
+
super().__init__()
|
94 |
+
self.norm1 = norm_layer(dim)
|
95 |
+
self.attn = Attention(
|
96 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
97 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
98 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
99 |
+
self.norm2 = norm_layer(dim)
|
100 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
101 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
102 |
+
|
103 |
+
if use_grad_checkpointing:
|
104 |
+
self.attn = checkpoint_wrapper(self.attn)
|
105 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
106 |
+
|
107 |
+
def forward(self, x, register_hook=False):
|
108 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
109 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class VisionTransformer(nn.Module):
|
114 |
+
""" Vision Transformer
|
115 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
116 |
+
https://arxiv.org/abs/2010.11929
|
117 |
+
"""
|
118 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
119 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
120 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
121 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
122 |
+
"""
|
123 |
+
Args:
|
124 |
+
img_size (int, tuple): input image size
|
125 |
+
patch_size (int, tuple): patch size
|
126 |
+
in_chans (int): number of input channels
|
127 |
+
num_classes (int): number of classes for classification head
|
128 |
+
embed_dim (int): embedding dimension
|
129 |
+
depth (int): depth of transformer
|
130 |
+
num_heads (int): number of attention heads
|
131 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
132 |
+
qkv_bias (bool): enable bias for qkv if True
|
133 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
134 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
135 |
+
drop_rate (float): dropout rate
|
136 |
+
attn_drop_rate (float): attention dropout rate
|
137 |
+
drop_path_rate (float): stochastic depth rate
|
138 |
+
norm_layer: (nn.Module): normalization layer
|
139 |
+
"""
|
140 |
+
super().__init__()
|
141 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
142 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
143 |
+
|
144 |
+
self.patch_embed = PatchEmbed(
|
145 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
146 |
+
|
147 |
+
num_patches = self.patch_embed.num_patches
|
148 |
+
|
149 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
150 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
151 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
152 |
+
|
153 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
154 |
+
self.blocks = nn.ModuleList([
|
155 |
+
Block(
|
156 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
157 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
158 |
+
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
159 |
+
)
|
160 |
+
for i in range(depth)])
|
161 |
+
self.norm = norm_layer(embed_dim)
|
162 |
+
|
163 |
+
trunc_normal_(self.pos_embed, std=.02)
|
164 |
+
trunc_normal_(self.cls_token, std=.02)
|
165 |
+
self.apply(self._init_weights)
|
166 |
+
|
167 |
+
def _init_weights(self, m):
|
168 |
+
if isinstance(m, nn.Linear):
|
169 |
+
trunc_normal_(m.weight, std=.02)
|
170 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
171 |
+
nn.init.constant_(m.bias, 0)
|
172 |
+
elif isinstance(m, nn.LayerNorm):
|
173 |
+
nn.init.constant_(m.bias, 0)
|
174 |
+
nn.init.constant_(m.weight, 1.0)
|
175 |
+
|
176 |
+
@torch.jit.ignore
|
177 |
+
def no_weight_decay(self):
|
178 |
+
return {'pos_embed', 'cls_token'}
|
179 |
+
|
180 |
+
def forward(self, x, register_blk=-1):
|
181 |
+
B = x.shape[0]
|
182 |
+
x = self.patch_embed(x)
|
183 |
+
|
184 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
185 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
186 |
+
|
187 |
+
x = x + self.pos_embed[:,:x.size(1),:]
|
188 |
+
x = self.pos_drop(x)
|
189 |
+
|
190 |
+
for i,blk in enumerate(self.blocks):
|
191 |
+
x = blk(x, register_blk==i)
|
192 |
+
x = self.norm(x)
|
193 |
+
|
194 |
+
return x
|
195 |
+
|
196 |
+
@torch.jit.ignore()
|
197 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
198 |
+
_load_weights(self, checkpoint_path, prefix)
|
199 |
+
|
200 |
+
|
201 |
+
@torch.no_grad()
|
202 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
203 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
204 |
+
"""
|
205 |
+
import numpy as np
|
206 |
+
|
207 |
+
def _n2p(w, t=True):
|
208 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
209 |
+
w = w.flatten()
|
210 |
+
if t:
|
211 |
+
if w.ndim == 4:
|
212 |
+
w = w.transpose([3, 2, 0, 1])
|
213 |
+
elif w.ndim == 3:
|
214 |
+
w = w.transpose([2, 0, 1])
|
215 |
+
elif w.ndim == 2:
|
216 |
+
w = w.transpose([1, 0])
|
217 |
+
return torch.from_numpy(w)
|
218 |
+
|
219 |
+
w = np.load(checkpoint_path)
|
220 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
221 |
+
prefix = 'opt/target/'
|
222 |
+
|
223 |
+
if hasattr(model.patch_embed, 'backbone'):
|
224 |
+
# hybrid
|
225 |
+
backbone = model.patch_embed.backbone
|
226 |
+
stem_only = not hasattr(backbone, 'stem')
|
227 |
+
stem = backbone if stem_only else backbone.stem
|
228 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
229 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
230 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
231 |
+
if not stem_only:
|
232 |
+
for i, stage in enumerate(backbone.stages):
|
233 |
+
for j, block in enumerate(stage.blocks):
|
234 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
235 |
+
for r in range(3):
|
236 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
237 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
238 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
239 |
+
if block.downsample is not None:
|
240 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
241 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
242 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
243 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
244 |
+
else:
|
245 |
+
embed_conv_w = adapt_input_conv(
|
246 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
247 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
248 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
249 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
250 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
251 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
252 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
253 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
254 |
+
model.pos_embed.copy_(pos_embed_w)
|
255 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
256 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
257 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
258 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
259 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
260 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
261 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
262 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
263 |
+
for i, block in enumerate(model.blocks.children()):
|
264 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
265 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
266 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
267 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
268 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
269 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
270 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
271 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
272 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
273 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
274 |
+
for r in range(2):
|
275 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
276 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
277 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
278 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
279 |
+
|
280 |
+
|
281 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
282 |
+
# interpolate position embedding
|
283 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
284 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
285 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
286 |
+
# height (== width) for the checkpoint position embedding
|
287 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
288 |
+
# height (== width) for the new position embedding
|
289 |
+
new_size = int(num_patches ** 0.5)
|
290 |
+
|
291 |
+
if orig_size!=new_size:
|
292 |
+
# class_token and dist_token are kept unchanged
|
293 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
294 |
+
# only the position tokens are interpolated
|
295 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
296 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
297 |
+
pos_tokens = torch.nn.functional.interpolate(
|
298 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
299 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
300 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
301 |
+
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
302 |
+
|
303 |
+
return new_pos_embed
|
304 |
+
else:
|
305 |
+
return pos_embed_checkpoint
|
pretrain.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip_pretrain import blip_pretrain
|
26 |
+
import utils
|
27 |
+
from utils import warmup_lr_schedule, step_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
|
30 |
+
def train(model, data_loader, optimizer, epoch, device, config):
|
31 |
+
# train
|
32 |
+
model.train()
|
33 |
+
|
34 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
35 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
|
36 |
+
metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
|
37 |
+
metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
|
38 |
+
metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
|
39 |
+
|
40 |
+
header = 'Train Epoch: [{}]'.format(epoch)
|
41 |
+
print_freq = 50
|
42 |
+
|
43 |
+
if config['laion_path']:
|
44 |
+
data_loader.dataset.reload_laion(epoch)
|
45 |
+
|
46 |
+
data_loader.sampler.set_epoch(epoch)
|
47 |
+
|
48 |
+
for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
49 |
+
|
50 |
+
if epoch==0:
|
51 |
+
warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
|
52 |
+
|
53 |
+
optimizer.zero_grad()
|
54 |
+
|
55 |
+
image = image.to(device,non_blocking=True)
|
56 |
+
|
57 |
+
# ramp up alpha in the first 2 epochs
|
58 |
+
alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader)))
|
59 |
+
|
60 |
+
loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha)
|
61 |
+
loss = loss_ita + loss_itm + loss_lm
|
62 |
+
|
63 |
+
loss.backward()
|
64 |
+
optimizer.step()
|
65 |
+
|
66 |
+
metric_logger.update(loss_ita=loss_ita.item())
|
67 |
+
metric_logger.update(loss_itm=loss_itm.item())
|
68 |
+
metric_logger.update(loss_lm=loss_lm.item())
|
69 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
70 |
+
|
71 |
+
|
72 |
+
# gather the stats from all processes
|
73 |
+
metric_logger.synchronize_between_processes()
|
74 |
+
print("Averaged stats:", metric_logger.global_avg())
|
75 |
+
return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
76 |
+
|
77 |
+
|
78 |
+
def main(args, config):
|
79 |
+
utils.init_distributed_mode(args)
|
80 |
+
|
81 |
+
device = torch.device(args.device)
|
82 |
+
|
83 |
+
# fix the seed for reproducibility
|
84 |
+
seed = args.seed + utils.get_rank()
|
85 |
+
torch.manual_seed(seed)
|
86 |
+
np.random.seed(seed)
|
87 |
+
random.seed(seed)
|
88 |
+
cudnn.benchmark = True
|
89 |
+
|
90 |
+
#### Dataset ####
|
91 |
+
print("Creating dataset")
|
92 |
+
datasets = [create_dataset('pretrain', config, min_scale=0.2)]
|
93 |
+
print('number of training samples: %d'%len(datasets[0]))
|
94 |
+
|
95 |
+
num_tasks = utils.get_world_size()
|
96 |
+
global_rank = utils.get_rank()
|
97 |
+
samplers = create_sampler(datasets, [True], num_tasks, global_rank)
|
98 |
+
|
99 |
+
data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
|
100 |
+
|
101 |
+
#### Model ####
|
102 |
+
print("Creating model")
|
103 |
+
model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
|
104 |
+
vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])
|
105 |
+
|
106 |
+
model = model.to(device)
|
107 |
+
|
108 |
+
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
|
109 |
+
|
110 |
+
start_epoch = 0
|
111 |
+
if args.checkpoint:
|
112 |
+
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
113 |
+
state_dict = checkpoint['model']
|
114 |
+
model.load_state_dict(state_dict)
|
115 |
+
|
116 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
117 |
+
start_epoch = checkpoint['epoch']+1
|
118 |
+
print('resume checkpoint from %s'%args.checkpoint)
|
119 |
+
|
120 |
+
model_without_ddp = model
|
121 |
+
if args.distributed:
|
122 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
123 |
+
model_without_ddp = model.module
|
124 |
+
|
125 |
+
print("Start training")
|
126 |
+
start_time = time.time()
|
127 |
+
for epoch in range(start_epoch, config['max_epoch']):
|
128 |
+
|
129 |
+
step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])
|
130 |
+
|
131 |
+
train_stats = train(model, data_loader, optimizer, epoch, device, config)
|
132 |
+
if utils.is_main_process():
|
133 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
134 |
+
'epoch': epoch,
|
135 |
+
}
|
136 |
+
save_obj = {
|
137 |
+
'model': model_without_ddp.state_dict(),
|
138 |
+
'optimizer': optimizer.state_dict(),
|
139 |
+
'config': config,
|
140 |
+
'epoch': epoch,
|
141 |
+
}
|
142 |
+
torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
|
143 |
+
|
144 |
+
with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
|
145 |
+
f.write(json.dumps(log_stats) + "\n")
|
146 |
+
|
147 |
+
dist.barrier()
|
148 |
+
|
149 |
+
total_time = time.time() - start_time
|
150 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
151 |
+
print('Training time {}'.format(total_time_str))
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == '__main__':
|
155 |
+
parser = argparse.ArgumentParser()
|
156 |
+
parser.add_argument('--config', default='./configs/pretrain.yaml')
|
157 |
+
parser.add_argument('--output_dir', default='output/Pretrain')
|
158 |
+
parser.add_argument('--checkpoint', default='')
|
159 |
+
parser.add_argument('--evaluate', action='store_true')
|
160 |
+
parser.add_argument('--device', default='cuda')
|
161 |
+
parser.add_argument('--seed', default=42, type=int)
|
162 |
+
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
|
163 |
+
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
164 |
+
parser.add_argument('--distributed', default=True, type=bool)
|
165 |
+
args = parser.parse_args()
|
166 |
+
|
167 |
+
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
|
168 |
+
|
169 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
170 |
+
|
171 |
+
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
|
172 |
+
|
173 |
+
main(args, config)
|
train_caption.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip import blip_decoder
|
26 |
+
import utils
|
27 |
+
from utils import cosine_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
from data.utils import save_result, coco_caption_eval
|
30 |
+
|
31 |
+
def train(model, data_loader, optimizer, epoch, device):
|
32 |
+
# train
|
33 |
+
model.train()
|
34 |
+
|
35 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
36 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
37 |
+
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
|
38 |
+
header = 'Train Caption Epoch: [{}]'.format(epoch)
|
39 |
+
print_freq = 50
|
40 |
+
|
41 |
+
for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
42 |
+
image = image.to(device)
|
43 |
+
|
44 |
+
loss = model(image, caption)
|
45 |
+
|
46 |
+
optimizer.zero_grad()
|
47 |
+
loss.backward()
|
48 |
+
optimizer.step()
|
49 |
+
|
50 |
+
metric_logger.update(loss=loss.item())
|
51 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
52 |
+
|
53 |
+
# gather the stats from all processes
|
54 |
+
metric_logger.synchronize_between_processes()
|
55 |
+
print("Averaged stats:", metric_logger.global_avg())
|
56 |
+
return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def evaluate(model, data_loader, device, config):
|
61 |
+
# evaluate
|
62 |
+
model.eval()
|
63 |
+
|
64 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
65 |
+
header = 'Caption generation:'
|
66 |
+
print_freq = 10
|
67 |
+
|
68 |
+
result = []
|
69 |
+
for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
|
70 |
+
|
71 |
+
image = image.to(device)
|
72 |
+
|
73 |
+
captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
|
74 |
+
min_length=config['min_length'])
|
75 |
+
|
76 |
+
for caption, img_id in zip(captions, image_id):
|
77 |
+
result.append({"image_id": img_id.item(), "caption": caption})
|
78 |
+
|
79 |
+
return result
|
80 |
+
|
81 |
+
|
82 |
+
def main(args, config):
|
83 |
+
utils.init_distributed_mode(args)
|
84 |
+
|
85 |
+
device = torch.device(args.device)
|
86 |
+
|
87 |
+
# fix the seed for reproducibility
|
88 |
+
seed = args.seed + utils.get_rank()
|
89 |
+
torch.manual_seed(seed)
|
90 |
+
np.random.seed(seed)
|
91 |
+
random.seed(seed)
|
92 |
+
cudnn.benchmark = True
|
93 |
+
|
94 |
+
#### Dataset ####
|
95 |
+
print("Creating captioning dataset")
|
96 |
+
train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
|
97 |
+
|
98 |
+
if args.distributed:
|
99 |
+
num_tasks = utils.get_world_size()
|
100 |
+
global_rank = utils.get_rank()
|
101 |
+
samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
|
102 |
+
else:
|
103 |
+
samplers = [None, None, None]
|
104 |
+
|
105 |
+
train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
|
106 |
+
batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
|
107 |
+
is_trains=[True, False, False], collate_fns=[None,None,None])
|
108 |
+
|
109 |
+
#### Model ####
|
110 |
+
print("Creating model")
|
111 |
+
model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
|
112 |
+
vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
|
113 |
+
prompt=config['prompt'])
|
114 |
+
|
115 |
+
model = model.to(device)
|
116 |
+
|
117 |
+
model_without_ddp = model
|
118 |
+
if args.distributed:
|
119 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
120 |
+
model_without_ddp = model.module
|
121 |
+
|
122 |
+
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
|
123 |
+
|
124 |
+
best = 0
|
125 |
+
best_epoch = 0
|
126 |
+
|
127 |
+
print("Start training")
|
128 |
+
start_time = time.time()
|
129 |
+
for epoch in range(0, config['max_epoch']):
|
130 |
+
if not args.evaluate:
|
131 |
+
if args.distributed:
|
132 |
+
train_loader.sampler.set_epoch(epoch)
|
133 |
+
|
134 |
+
cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
|
135 |
+
|
136 |
+
train_stats = train(model, train_loader, optimizer, epoch, device)
|
137 |
+
|
138 |
+
val_result = evaluate(model_without_ddp, val_loader, device, config)
|
139 |
+
val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
|
140 |
+
|
141 |
+
test_result = evaluate(model_without_ddp, test_loader, device, config)
|
142 |
+
test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
|
143 |
+
|
144 |
+
if utils.is_main_process():
|
145 |
+
coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
|
146 |
+
coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
|
147 |
+
|
148 |
+
if args.evaluate:
|
149 |
+
log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
|
150 |
+
**{f'test_{k}': v for k, v in coco_test.eval.items()},
|
151 |
+
}
|
152 |
+
with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
|
153 |
+
f.write(json.dumps(log_stats) + "\n")
|
154 |
+
else:
|
155 |
+
save_obj = {
|
156 |
+
'model': model_without_ddp.state_dict(),
|
157 |
+
'optimizer': optimizer.state_dict(),
|
158 |
+
'config': config,
|
159 |
+
'epoch': epoch,
|
160 |
+
}
|
161 |
+
|
162 |
+
if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
|
163 |
+
best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
|
164 |
+
best_epoch = epoch
|
165 |
+
torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
|
166 |
+
|
167 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
168 |
+
**{f'val_{k}': v for k, v in coco_val.eval.items()},
|
169 |
+
**{f'test_{k}': v for k, v in coco_test.eval.items()},
|
170 |
+
'epoch': epoch,
|
171 |
+
'best_epoch': best_epoch,
|
172 |
+
}
|
173 |
+
with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
|
174 |
+
f.write(json.dumps(log_stats) + "\n")
|
175 |
+
|
176 |
+
if args.evaluate:
|
177 |
+
break
|
178 |
+
dist.barrier()
|
179 |
+
|
180 |
+
total_time = time.time() - start_time
|
181 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
182 |
+
print('Training time {}'.format(total_time_str))
|
183 |
+
|
184 |
+
|
185 |
+
if __name__ == '__main__':
|
186 |
+
parser = argparse.ArgumentParser()
|
187 |
+
parser.add_argument('--config', default='./configs/caption_coco.yaml')
|
188 |
+
parser.add_argument('--output_dir', default='output/Caption_coco')
|
189 |
+
parser.add_argument('--evaluate', action='store_true')
|
190 |
+
parser.add_argument('--device', default='cuda')
|
191 |
+
parser.add_argument('--seed', default=42, type=int)
|
192 |
+
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
|
193 |
+
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
194 |
+
parser.add_argument('--distributed', default=True, type=bool)
|
195 |
+
args = parser.parse_args()
|
196 |
+
|
197 |
+
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
|
198 |
+
|
199 |
+
args.result_dir = os.path.join(args.output_dir, 'result')
|
200 |
+
|
201 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
202 |
+
Path(args.result_dir).mkdir(parents=True, exist_ok=True)
|
203 |
+
|
204 |
+
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
|
205 |
+
|
206 |
+
main(args, config)
|
train_nlvr.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
import json
|
18 |
+
import pickle
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
import torch.backends.cudnn as cudnn
|
25 |
+
import torch.distributed as dist
|
26 |
+
|
27 |
+
from models.blip_nlvr import blip_nlvr
|
28 |
+
|
29 |
+
import utils
|
30 |
+
from utils import cosine_lr_schedule, warmup_lr_schedule
|
31 |
+
from data import create_dataset, create_sampler, create_loader
|
32 |
+
|
33 |
+
def train(model, data_loader, optimizer, epoch, device, config):
|
34 |
+
# train
|
35 |
+
model.train()
|
36 |
+
|
37 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
38 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
|
39 |
+
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
|
40 |
+
|
41 |
+
header = 'Train Epoch: [{}]'.format(epoch)
|
42 |
+
print_freq = 50
|
43 |
+
step_size = 10
|
44 |
+
|
45 |
+
for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
46 |
+
|
47 |
+
images = torch.cat([image0, image1], dim=0)
|
48 |
+
images, targets = images.to(device), targets.to(device)
|
49 |
+
|
50 |
+
loss = model(images, text, targets=targets, train=True)
|
51 |
+
|
52 |
+
optimizer.zero_grad()
|
53 |
+
loss.backward()
|
54 |
+
optimizer.step()
|
55 |
+
|
56 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
57 |
+
metric_logger.update(loss=loss.item())
|
58 |
+
|
59 |
+
# gather the stats from all processes
|
60 |
+
metric_logger.synchronize_between_processes()
|
61 |
+
print("Averaged stats:", metric_logger.global_avg())
|
62 |
+
return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
63 |
+
|
64 |
+
|
65 |
+
@torch.no_grad()
|
66 |
+
def evaluate(model, data_loader, device, config):
|
67 |
+
# test
|
68 |
+
model.eval()
|
69 |
+
|
70 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
71 |
+
|
72 |
+
header = 'Evaluation:'
|
73 |
+
print_freq = 50
|
74 |
+
|
75 |
+
for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
|
76 |
+
images = torch.cat([image0, image1], dim=0)
|
77 |
+
images, targets = images.to(device), targets.to(device)
|
78 |
+
|
79 |
+
prediction = model(images, text, targets=targets, train=False)
|
80 |
+
|
81 |
+
_, pred_class = prediction.max(1)
|
82 |
+
accuracy = (targets==pred_class).sum() / targets.size(0)
|
83 |
+
|
84 |
+
metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))
|
85 |
+
|
86 |
+
# gather the stats from all processes
|
87 |
+
metric_logger.synchronize_between_processes()
|
88 |
+
|
89 |
+
print("Averaged stats:", metric_logger.global_avg())
|
90 |
+
return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
def main(args, config):
|
95 |
+
utils.init_distributed_mode(args)
|
96 |
+
|
97 |
+
device = torch.device(args.device)
|
98 |
+
|
99 |
+
# fix the seed for reproducibility
|
100 |
+
seed = args.seed + utils.get_rank()
|
101 |
+
torch.manual_seed(seed)
|
102 |
+
np.random.seed(seed)
|
103 |
+
random.seed(seed)
|
104 |
+
cudnn.benchmark = True
|
105 |
+
|
106 |
+
#### Dataset ####
|
107 |
+
print("Creating dataset")
|
108 |
+
datasets = create_dataset('nlvr', config)
|
109 |
+
|
110 |
+
if args.distributed:
|
111 |
+
num_tasks = utils.get_world_size()
|
112 |
+
global_rank = utils.get_rank()
|
113 |
+
samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank)
|
114 |
+
else:
|
115 |
+
samplers = [None, None, None]
|
116 |
+
|
117 |
+
batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']]
|
118 |
+
train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size,
|
119 |
+
num_workers=[4,4,4],is_trains=[True,False,False],
|
120 |
+
collate_fns=[None,None,None])
|
121 |
+
|
122 |
+
#### Model ####
|
123 |
+
print("Creating model")
|
124 |
+
model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
|
125 |
+
vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
|
126 |
+
|
127 |
+
model = model.to(device)
|
128 |
+
|
129 |
+
model_without_ddp = model
|
130 |
+
if args.distributed:
|
131 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
132 |
+
model_without_ddp = model.module
|
133 |
+
|
134 |
+
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
|
135 |
+
|
136 |
+
print("Start training")
|
137 |
+
start_time = time.time()
|
138 |
+
best = 0
|
139 |
+
best_epoch = 0
|
140 |
+
|
141 |
+
for epoch in range(0, config['max_epoch']):
|
142 |
+
if not args.evaluate:
|
143 |
+
if args.distributed:
|
144 |
+
train_loader.sampler.set_epoch(epoch)
|
145 |
+
|
146 |
+
cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
|
147 |
+
|
148 |
+
train_stats = train(model, train_loader, optimizer, epoch, device, config)
|
149 |
+
|
150 |
+
val_stats = evaluate(model, val_loader, device, config)
|
151 |
+
test_stats = evaluate(model, test_loader, device, config)
|
152 |
+
|
153 |
+
if utils.is_main_process():
|
154 |
+
if args.evaluate:
|
155 |
+
log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
|
156 |
+
**{f'test_{k}': v for k, v in test_stats.items()},
|
157 |
+
}
|
158 |
+
with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
|
159 |
+
f.write(json.dumps(log_stats) + "\n")
|
160 |
+
|
161 |
+
else:
|
162 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
163 |
+
**{f'val_{k}': v for k, v in val_stats.items()},
|
164 |
+
**{f'test_{k}': v for k, v in test_stats.items()},
|
165 |
+
'epoch': epoch,
|
166 |
+
}
|
167 |
+
|
168 |
+
if float(val_stats['acc'])>best:
|
169 |
+
save_obj = {
|
170 |
+
'model': model_without_ddp.state_dict(),
|
171 |
+
'optimizer': optimizer.state_dict(),
|
172 |
+
'config': config,
|
173 |
+
'epoch': epoch,
|
174 |
+
}
|
175 |
+
torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
|
176 |
+
best = float(val_stats['acc'])
|
177 |
+
best_epoch = epoch
|
178 |
+
|
179 |
+
with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
|
180 |
+
f.write(json.dumps(log_stats) + "\n")
|
181 |
+
if args.evaluate:
|
182 |
+
break
|
183 |
+
|
184 |
+
dist.barrier()
|
185 |
+
|
186 |
+
if utils.is_main_process():
|
187 |
+
with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
|
188 |
+
f.write("best epoch: %d"%best_epoch)
|
189 |
+
|
190 |
+
total_time = time.time() - start_time
|
191 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
192 |
+
print('Training time {}'.format(total_time_str))
|
193 |
+
|
194 |
+
|
195 |
+
if __name__ == '__main__':
|
196 |
+
parser = argparse.ArgumentParser()
|
197 |
+
parser.add_argument('--config', default='./configs/nlvr.yaml')
|
198 |
+
parser.add_argument('--output_dir', default='output/NLVR')
|
199 |
+
parser.add_argument('--evaluate', action='store_true')
|
200 |
+
parser.add_argument('--device', default='cuda')
|
201 |
+
parser.add_argument('--seed', default=42, type=int)
|
202 |
+
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
|
203 |
+
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
204 |
+
parser.add_argument('--distributed', default=True, type=bool)
|
205 |
+
args = parser.parse_args()
|
206 |
+
|
207 |
+
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
|
208 |
+
|
209 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
210 |
+
|
211 |
+
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
|
212 |
+
|
213 |
+
main(args, config)
|
train_retrieval.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip_retrieval import blip_retrieval
|
26 |
+
import utils
|
27 |
+
from utils import cosine_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
|
30 |
+
|
31 |
+
def train(model, data_loader, optimizer, epoch, device, config):
|
32 |
+
# train
|
33 |
+
model.train()
|
34 |
+
|
35 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
36 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
37 |
+
metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
|
38 |
+
metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
|
39 |
+
header = 'Train Epoch: [{}]'.format(epoch)
|
40 |
+
print_freq = 50
|
41 |
+
|
42 |
+
for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
43 |
+
image = image.to(device,non_blocking=True)
|
44 |
+
idx = idx.to(device,non_blocking=True)
|
45 |
+
|
46 |
+
if epoch>0:
|
47 |
+
alpha = config['alpha']
|
48 |
+
else:
|
49 |
+
alpha = config['alpha']*min(1,i/len(data_loader))
|
50 |
+
|
51 |
+
loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
|
52 |
+
loss = loss_ita + loss_itm
|
53 |
+
|
54 |
+
optimizer.zero_grad()
|
55 |
+
loss.backward()
|
56 |
+
optimizer.step()
|
57 |
+
|
58 |
+
metric_logger.update(loss_itm=loss_itm.item())
|
59 |
+
metric_logger.update(loss_ita=loss_ita.item())
|
60 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
61 |
+
|
62 |
+
# gather the stats from all processes
|
63 |
+
metric_logger.synchronize_between_processes()
|
64 |
+
print("Averaged stats:", metric_logger.global_avg())
|
65 |
+
return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
66 |
+
|
67 |
+
|
68 |
+
@torch.no_grad()
|
69 |
+
def evaluation(model, data_loader, device, config):
|
70 |
+
# test
|
71 |
+
model.eval()
|
72 |
+
|
73 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
74 |
+
header = 'Evaluation:'
|
75 |
+
|
76 |
+
print('Computing features for evaluation...')
|
77 |
+
start_time = time.time()
|
78 |
+
|
79 |
+
texts = data_loader.dataset.text
|
80 |
+
num_text = len(texts)
|
81 |
+
text_bs = 256
|
82 |
+
text_ids = []
|
83 |
+
text_embeds = []
|
84 |
+
text_atts = []
|
85 |
+
for i in range(0, num_text, text_bs):
|
86 |
+
text = texts[i: min(num_text, i+text_bs)]
|
87 |
+
text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
|
88 |
+
text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
|
89 |
+
text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
|
90 |
+
text_embeds.append(text_embed)
|
91 |
+
text_ids.append(text_input.input_ids)
|
92 |
+
text_atts.append(text_input.attention_mask)
|
93 |
+
|
94 |
+
text_embeds = torch.cat(text_embeds,dim=0)
|
95 |
+
text_ids = torch.cat(text_ids,dim=0)
|
96 |
+
text_atts = torch.cat(text_atts,dim=0)
|
97 |
+
text_ids[:,0] = model.tokenizer.enc_token_id
|
98 |
+
|
99 |
+
image_feats = []
|
100 |
+
image_embeds = []
|
101 |
+
for image, img_id in data_loader:
|
102 |
+
image = image.to(device)
|
103 |
+
image_feat = model.visual_encoder(image)
|
104 |
+
image_embed = model.vision_proj(image_feat[:,0,:])
|
105 |
+
image_embed = F.normalize(image_embed,dim=-1)
|
106 |
+
|
107 |
+
image_feats.append(image_feat.cpu())
|
108 |
+
image_embeds.append(image_embed)
|
109 |
+
|
110 |
+
image_feats = torch.cat(image_feats,dim=0)
|
111 |
+
image_embeds = torch.cat(image_embeds,dim=0)
|
112 |
+
|
113 |
+
sims_matrix = image_embeds @ text_embeds.t()
|
114 |
+
score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)
|
115 |
+
|
116 |
+
num_tasks = utils.get_world_size()
|
117 |
+
rank = utils.get_rank()
|
118 |
+
step = sims_matrix.size(0)//num_tasks + 1
|
119 |
+
start = rank*step
|
120 |
+
end = min(sims_matrix.size(0),start+step)
|
121 |
+
|
122 |
+
for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
|
123 |
+
topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
|
124 |
+
|
125 |
+
encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
|
126 |
+
encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
|
127 |
+
output = model.text_encoder(text_ids[topk_idx],
|
128 |
+
attention_mask = text_atts[topk_idx],
|
129 |
+
encoder_hidden_states = encoder_output,
|
130 |
+
encoder_attention_mask = encoder_att,
|
131 |
+
return_dict = True,
|
132 |
+
)
|
133 |
+
score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
|
134 |
+
score_matrix_i2t[start+i,topk_idx] = score + topk_sim
|
135 |
+
|
136 |
+
sims_matrix = sims_matrix.t()
|
137 |
+
score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)
|
138 |
+
|
139 |
+
step = sims_matrix.size(0)//num_tasks + 1
|
140 |
+
start = rank*step
|
141 |
+
end = min(sims_matrix.size(0),start+step)
|
142 |
+
|
143 |
+
for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
|
144 |
+
|
145 |
+
topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
|
146 |
+
encoder_output = image_feats[topk_idx].to(device)
|
147 |
+
encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
|
148 |
+
output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
|
149 |
+
attention_mask = text_atts[start+i].repeat(config['k_test'],1),
|
150 |
+
encoder_hidden_states = encoder_output,
|
151 |
+
encoder_attention_mask = encoder_att,
|
152 |
+
return_dict = True,
|
153 |
+
)
|
154 |
+
score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
|
155 |
+
score_matrix_t2i[start+i,topk_idx] = score + topk_sim
|
156 |
+
|
157 |
+
if args.distributed:
|
158 |
+
dist.barrier()
|
159 |
+
torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
|
160 |
+
torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
|
161 |
+
|
162 |
+
total_time = time.time() - start_time
|
163 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
164 |
+
print('Evaluation time {}'.format(total_time_str))
|
165 |
+
|
166 |
+
return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
@torch.no_grad()
|
171 |
+
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
|
172 |
+
|
173 |
+
#Images->Text
|
174 |
+
ranks = np.zeros(scores_i2t.shape[0])
|
175 |
+
for index,score in enumerate(scores_i2t):
|
176 |
+
inds = np.argsort(score)[::-1]
|
177 |
+
# Score
|
178 |
+
rank = 1e20
|
179 |
+
for i in img2txt[index]:
|
180 |
+
tmp = np.where(inds == i)[0][0]
|
181 |
+
if tmp < rank:
|
182 |
+
rank = tmp
|
183 |
+
ranks[index] = rank
|
184 |
+
|
185 |
+
# Compute metrics
|
186 |
+
tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
|
187 |
+
tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
|
188 |
+
tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
|
189 |
+
|
190 |
+
#Text->Images
|
191 |
+
ranks = np.zeros(scores_t2i.shape[0])
|
192 |
+
|
193 |
+
for index,score in enumerate(scores_t2i):
|
194 |
+
inds = np.argsort(score)[::-1]
|
195 |
+
ranks[index] = np.where(inds == txt2img[index])[0][0]
|
196 |
+
|
197 |
+
# Compute metrics
|
198 |
+
ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
|
199 |
+
ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
|
200 |
+
ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
|
201 |
+
|
202 |
+
tr_mean = (tr1 + tr5 + tr10) / 3
|
203 |
+
ir_mean = (ir1 + ir5 + ir10) / 3
|
204 |
+
r_mean = (tr_mean + ir_mean) / 2
|
205 |
+
|
206 |
+
eval_result = {'txt_r1': tr1,
|
207 |
+
'txt_r5': tr5,
|
208 |
+
'txt_r10': tr10,
|
209 |
+
'txt_r_mean': tr_mean,
|
210 |
+
'img_r1': ir1,
|
211 |
+
'img_r5': ir5,
|
212 |
+
'img_r10': ir10,
|
213 |
+
'img_r_mean': ir_mean,
|
214 |
+
'r_mean': r_mean}
|
215 |
+
return eval_result
|
216 |
+
|
217 |
+
|
218 |
+
def main(args, config):
|
219 |
+
utils.init_distributed_mode(args)
|
220 |
+
|
221 |
+
device = torch.device(args.device)
|
222 |
+
|
223 |
+
# fix the seed for reproducibility
|
224 |
+
seed = args.seed + utils.get_rank()
|
225 |
+
torch.manual_seed(seed)
|
226 |
+
np.random.seed(seed)
|
227 |
+
random.seed(seed)
|
228 |
+
cudnn.benchmark = True
|
229 |
+
|
230 |
+
#### Dataset ####
|
231 |
+
print("Creating retrieval dataset")
|
232 |
+
train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)
|
233 |
+
|
234 |
+
if args.distributed:
|
235 |
+
num_tasks = utils.get_world_size()
|
236 |
+
global_rank = utils.get_rank()
|
237 |
+
samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
|
238 |
+
else:
|
239 |
+
samplers = [None, None, None]
|
240 |
+
|
241 |
+
train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
|
242 |
+
batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
|
243 |
+
num_workers=[4,4,4],
|
244 |
+
is_trains=[True, False, False],
|
245 |
+
collate_fns=[None,None,None])
|
246 |
+
|
247 |
+
|
248 |
+
#### Model ####
|
249 |
+
print("Creating model")
|
250 |
+
model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
|
251 |
+
vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
|
252 |
+
queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
|
253 |
+
|
254 |
+
model = model.to(device)
|
255 |
+
|
256 |
+
model_without_ddp = model
|
257 |
+
if args.distributed:
|
258 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
259 |
+
model_without_ddp = model.module
|
260 |
+
|
261 |
+
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
|
262 |
+
|
263 |
+
best = 0
|
264 |
+
best_epoch = 0
|
265 |
+
|
266 |
+
print("Start training")
|
267 |
+
start_time = time.time()
|
268 |
+
|
269 |
+
for epoch in range(0, config['max_epoch']):
|
270 |
+
if not args.evaluate:
|
271 |
+
if args.distributed:
|
272 |
+
train_loader.sampler.set_epoch(epoch)
|
273 |
+
|
274 |
+
cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
|
275 |
+
|
276 |
+
train_stats = train(model, train_loader, optimizer, epoch, device, config)
|
277 |
+
|
278 |
+
score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config)
|
279 |
+
score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
|
280 |
+
|
281 |
+
if utils.is_main_process():
|
282 |
+
|
283 |
+
val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
|
284 |
+
print(val_result)
|
285 |
+
|
286 |
+
if val_result['r_mean']>best:
|
287 |
+
save_obj = {
|
288 |
+
'model': model_without_ddp.state_dict(),
|
289 |
+
'optimizer': optimizer.state_dict(),
|
290 |
+
'config': config,
|
291 |
+
'epoch': epoch,
|
292 |
+
}
|
293 |
+
torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
|
294 |
+
best = val_result['r_mean']
|
295 |
+
best_epoch = epoch
|
296 |
+
|
297 |
+
test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
|
298 |
+
print(test_result)
|
299 |
+
|
300 |
+
if args.evaluate:
|
301 |
+
log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
|
302 |
+
**{f'test_{k}': v for k, v in test_result.items()},
|
303 |
+
}
|
304 |
+
with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
|
305 |
+
f.write(json.dumps(log_stats) + "\n")
|
306 |
+
else:
|
307 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
308 |
+
**{f'val_{k}': v for k, v in val_result.items()},
|
309 |
+
**{f'test_{k}': v for k, v in test_result.items()},
|
310 |
+
'epoch': epoch,
|
311 |
+
'best_epoch': best_epoch,
|
312 |
+
}
|
313 |
+
with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
|
314 |
+
f.write(json.dumps(log_stats) + "\n")
|
315 |
+
|
316 |
+
if args.evaluate:
|
317 |
+
break
|
318 |
+
|
319 |
+
dist.barrier()
|
320 |
+
torch.cuda.empty_cache()
|
321 |
+
|
322 |
+
total_time = time.time() - start_time
|
323 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
324 |
+
print('Training time {}'.format(total_time_str))
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == '__main__':
|
328 |
+
parser = argparse.ArgumentParser()
|
329 |
+
parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
|
330 |
+
parser.add_argument('--output_dir', default='output/Retrieval_flickr')
|
331 |
+
parser.add_argument('--evaluate', action='store_true')
|
332 |
+
parser.add_argument('--device', default='cuda')
|
333 |
+
parser.add_argument('--seed', default=42, type=int)
|
334 |
+
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
|
335 |
+
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
336 |
+
parser.add_argument('--distributed', default=True, type=bool)
|
337 |
+
args = parser.parse_args()
|
338 |
+
|
339 |
+
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
|
340 |
+
|
341 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
342 |
+
|
343 |
+
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
|
344 |
+
|
345 |
+
main(args, config)
|
train_vqa.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.utils.data import DataLoader
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
import torch.distributed as dist
|
24 |
+
|
25 |
+
from models.blip_vqa import blip_vqa
|
26 |
+
import utils
|
27 |
+
from utils import cosine_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
from data.vqa_dataset import vqa_collate_fn
|
30 |
+
from data.utils import save_result
|
31 |
+
|
32 |
+
|
33 |
+
def train(model, data_loader, optimizer, epoch, device):
|
34 |
+
# train
|
35 |
+
model.train()
|
36 |
+
|
37 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
38 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
39 |
+
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
|
40 |
+
|
41 |
+
header = 'Train Epoch: [{}]'.format(epoch)
|
42 |
+
print_freq = 50
|
43 |
+
|
44 |
+
for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
45 |
+
image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True)
|
46 |
+
|
47 |
+
loss = model(image, question, answer, train=True, n=n, weights=weights)
|
48 |
+
|
49 |
+
optimizer.zero_grad()
|
50 |
+
loss.backward()
|
51 |
+
optimizer.step()
|
52 |
+
|
53 |
+
metric_logger.update(loss=loss.item())
|
54 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
55 |
+
|
56 |
+
# gather the stats from all processes
|
57 |
+
metric_logger.synchronize_between_processes()
|
58 |
+
print("Averaged stats:", metric_logger.global_avg())
|
59 |
+
return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
60 |
+
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
def evaluation(model, data_loader, device, config) :
|
64 |
+
# test
|
65 |
+
model.eval()
|
66 |
+
|
67 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
68 |
+
header = 'Generate VQA test result:'
|
69 |
+
print_freq = 50
|
70 |
+
|
71 |
+
result = []
|
72 |
+
|
73 |
+
if config['inference']=='rank':
|
74 |
+
answer_list = data_loader.dataset.answer_list
|
75 |
+
answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
|
76 |
+
answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id
|
77 |
+
|
78 |
+
for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
79 |
+
image = image.to(device,non_blocking=True)
|
80 |
+
|
81 |
+
if config['inference']=='generate':
|
82 |
+
answers = model(image, question, train=False, inference='generate')
|
83 |
+
|
84 |
+
for answer, ques_id in zip(answers, question_id):
|
85 |
+
ques_id = int(ques_id.item())
|
86 |
+
result.append({"question_id":ques_id, "answer":answer})
|
87 |
+
|
88 |
+
elif config['inference']=='rank':
|
89 |
+
answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])
|
90 |
+
|
91 |
+
for ques_id, answer_id in zip(question_id, answer_ids):
|
92 |
+
result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]})
|
93 |
+
|
94 |
+
return result
|
95 |
+
|
96 |
+
|
97 |
+
def main(args, config):
|
98 |
+
utils.init_distributed_mode(args)
|
99 |
+
|
100 |
+
device = torch.device(args.device)
|
101 |
+
|
102 |
+
# fix the seed for reproducibility
|
103 |
+
seed = args.seed + utils.get_rank()
|
104 |
+
torch.manual_seed(seed)
|
105 |
+
np.random.seed(seed)
|
106 |
+
random.seed(seed)
|
107 |
+
cudnn.benchmark = True
|
108 |
+
|
109 |
+
#### Dataset ####
|
110 |
+
print("Creating vqa datasets")
|
111 |
+
datasets = create_dataset('vqa', config)
|
112 |
+
|
113 |
+
if args.distributed:
|
114 |
+
num_tasks = utils.get_world_size()
|
115 |
+
global_rank = utils.get_rank()
|
116 |
+
samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
|
117 |
+
else:
|
118 |
+
samplers = [None, None]
|
119 |
+
|
120 |
+
train_loader, test_loader = create_loader(datasets,samplers,
|
121 |
+
batch_size=[config['batch_size_train'],config['batch_size_test']],
|
122 |
+
num_workers=[4,4],is_trains=[True, False],
|
123 |
+
collate_fns=[vqa_collate_fn,None])
|
124 |
+
#### Model ####
|
125 |
+
print("Creating model")
|
126 |
+
model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
|
127 |
+
vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
|
128 |
+
|
129 |
+
model = model.to(device)
|
130 |
+
|
131 |
+
model_without_ddp = model
|
132 |
+
if args.distributed:
|
133 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
134 |
+
model_without_ddp = model.module
|
135 |
+
|
136 |
+
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
|
137 |
+
|
138 |
+
best = 0
|
139 |
+
best_epoch = 0
|
140 |
+
|
141 |
+
print("Start training")
|
142 |
+
start_time = time.time()
|
143 |
+
for epoch in range(0, config['max_epoch']):
|
144 |
+
if not args.evaluate:
|
145 |
+
if args.distributed:
|
146 |
+
train_loader.sampler.set_epoch(epoch)
|
147 |
+
|
148 |
+
cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
|
149 |
+
|
150 |
+
train_stats = train(model, train_loader, optimizer, epoch, device)
|
151 |
+
|
152 |
+
else:
|
153 |
+
break
|
154 |
+
|
155 |
+
if utils.is_main_process():
|
156 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
157 |
+
'epoch': epoch,
|
158 |
+
}
|
159 |
+
with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
|
160 |
+
f.write(json.dumps(log_stats) + "\n")
|
161 |
+
|
162 |
+
save_obj = {
|
163 |
+
'model': model_without_ddp.state_dict(),
|
164 |
+
'optimizer': optimizer.state_dict(),
|
165 |
+
'config': config,
|
166 |
+
'epoch': epoch,
|
167 |
+
}
|
168 |
+
torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
|
169 |
+
|
170 |
+
dist.barrier()
|
171 |
+
|
172 |
+
vqa_result = evaluation(model_without_ddp, test_loader, device, config)
|
173 |
+
result_file = save_result(vqa_result, args.result_dir, 'vqa_result')
|
174 |
+
|
175 |
+
total_time = time.time() - start_time
|
176 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
177 |
+
print('Training time {}'.format(total_time_str))
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
if __name__ == '__main__':
|
182 |
+
parser = argparse.ArgumentParser()
|
183 |
+
parser.add_argument('--config', default='./configs/vqa.yaml')
|
184 |
+
parser.add_argument('--output_dir', default='output/VQA')
|
185 |
+
parser.add_argument('--evaluate', action='store_true')
|
186 |
+
parser.add_argument('--device', default='cuda')
|
187 |
+
parser.add_argument('--seed', default=42, type=int)
|
188 |
+
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
|
189 |
+
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
190 |
+
parser.add_argument('--distributed', default=True, type=bool)
|
191 |
+
args = parser.parse_args()
|
192 |
+
|
193 |
+
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
|
194 |
+
|
195 |
+
args.result_dir = os.path.join(args.output_dir, 'result')
|
196 |
+
|
197 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
198 |
+
Path(args.result_dir).mkdir(parents=True, exist_ok=True)
|
199 |
+
|
200 |
+
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
|
201 |
+
|
202 |
+
main(args, config)
|
transform/randaugment.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
## aug functions
|
6 |
+
def identity_func(img):
|
7 |
+
return img
|
8 |
+
|
9 |
+
|
10 |
+
def autocontrast_func(img, cutoff=0):
|
11 |
+
'''
|
12 |
+
same output as PIL.ImageOps.autocontrast
|
13 |
+
'''
|
14 |
+
n_bins = 256
|
15 |
+
|
16 |
+
def tune_channel(ch):
|
17 |
+
n = ch.size
|
18 |
+
cut = cutoff * n // 100
|
19 |
+
if cut == 0:
|
20 |
+
high, low = ch.max(), ch.min()
|
21 |
+
else:
|
22 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
23 |
+
low = np.argwhere(np.cumsum(hist) > cut)
|
24 |
+
low = 0 if low.shape[0] == 0 else low[0]
|
25 |
+
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
|
26 |
+
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
|
27 |
+
if high <= low:
|
28 |
+
table = np.arange(n_bins)
|
29 |
+
else:
|
30 |
+
scale = (n_bins - 1) / (high - low)
|
31 |
+
offset = -low * scale
|
32 |
+
table = np.arange(n_bins) * scale + offset
|
33 |
+
table[table < 0] = 0
|
34 |
+
table[table > n_bins - 1] = n_bins - 1
|
35 |
+
table = table.clip(0, 255).astype(np.uint8)
|
36 |
+
return table[ch]
|
37 |
+
|
38 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
39 |
+
out = cv2.merge(channels)
|
40 |
+
return out
|
41 |
+
|
42 |
+
|
43 |
+
def equalize_func(img):
|
44 |
+
'''
|
45 |
+
same output as PIL.ImageOps.equalize
|
46 |
+
PIL's implementation is different from cv2.equalize
|
47 |
+
'''
|
48 |
+
n_bins = 256
|
49 |
+
|
50 |
+
def tune_channel(ch):
|
51 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
52 |
+
non_zero_hist = hist[hist != 0].reshape(-1)
|
53 |
+
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
|
54 |
+
if step == 0: return ch
|
55 |
+
n = np.empty_like(hist)
|
56 |
+
n[0] = step // 2
|
57 |
+
n[1:] = hist[:-1]
|
58 |
+
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
|
59 |
+
return table[ch]
|
60 |
+
|
61 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
62 |
+
out = cv2.merge(channels)
|
63 |
+
return out
|
64 |
+
|
65 |
+
|
66 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
|
67 |
+
'''
|
68 |
+
like PIL, rotate by degree, not radians
|
69 |
+
'''
|
70 |
+
H, W = img.shape[0], img.shape[1]
|
71 |
+
center = W / 2, H / 2
|
72 |
+
M = cv2.getRotationMatrix2D(center, degree, 1)
|
73 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
|
74 |
+
return out
|
75 |
+
|
76 |
+
|
77 |
+
def solarize_func(img, thresh=128):
|
78 |
+
'''
|
79 |
+
same output as PIL.ImageOps.posterize
|
80 |
+
'''
|
81 |
+
table = np.array([el if el < thresh else 255 - el for el in range(256)])
|
82 |
+
table = table.clip(0, 255).astype(np.uint8)
|
83 |
+
out = table[img]
|
84 |
+
return out
|
85 |
+
|
86 |
+
|
87 |
+
def color_func(img, factor):
|
88 |
+
'''
|
89 |
+
same output as PIL.ImageEnhance.Color
|
90 |
+
'''
|
91 |
+
## implementation according to PIL definition, quite slow
|
92 |
+
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
|
93 |
+
# out = blend(degenerate, img, factor)
|
94 |
+
# M = (
|
95 |
+
# np.eye(3) * factor
|
96 |
+
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
|
97 |
+
# )[np.newaxis, np.newaxis, :]
|
98 |
+
M = (
|
99 |
+
np.float32([
|
100 |
+
[0.886, -0.114, -0.114],
|
101 |
+
[-0.587, 0.413, -0.587],
|
102 |
+
[-0.299, -0.299, 0.701]]) * factor
|
103 |
+
+ np.float32([[0.114], [0.587], [0.299]])
|
104 |
+
)
|
105 |
+
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
|
106 |
+
return out
|
107 |
+
|
108 |
+
|
109 |
+
def contrast_func(img, factor):
|
110 |
+
"""
|
111 |
+
same output as PIL.ImageEnhance.Contrast
|
112 |
+
"""
|
113 |
+
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
|
114 |
+
table = np.array([(
|
115 |
+
el - mean) * factor + mean
|
116 |
+
for el in range(256)
|
117 |
+
]).clip(0, 255).astype(np.uint8)
|
118 |
+
out = table[img]
|
119 |
+
return out
|
120 |
+
|
121 |
+
|
122 |
+
def brightness_func(img, factor):
|
123 |
+
'''
|
124 |
+
same output as PIL.ImageEnhance.Contrast
|
125 |
+
'''
|
126 |
+
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
|
127 |
+
out = table[img]
|
128 |
+
return out
|
129 |
+
|
130 |
+
|
131 |
+
def sharpness_func(img, factor):
|
132 |
+
'''
|
133 |
+
The differences the this result and PIL are all on the 4 boundaries, the center
|
134 |
+
areas are same
|
135 |
+
'''
|
136 |
+
kernel = np.ones((3, 3), dtype=np.float32)
|
137 |
+
kernel[1][1] = 5
|
138 |
+
kernel /= 13
|
139 |
+
degenerate = cv2.filter2D(img, -1, kernel)
|
140 |
+
if factor == 0.0:
|
141 |
+
out = degenerate
|
142 |
+
elif factor == 1.0:
|
143 |
+
out = img
|
144 |
+
else:
|
145 |
+
out = img.astype(np.float32)
|
146 |
+
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
|
147 |
+
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
|
148 |
+
out = out.astype(np.uint8)
|
149 |
+
return out
|
150 |
+
|
151 |
+
|
152 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
|
153 |
+
H, W = img.shape[0], img.shape[1]
|
154 |
+
M = np.float32([[1, factor, 0], [0, 1, 0]])
|
155 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
156 |
+
return out
|
157 |
+
|
158 |
+
|
159 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
|
160 |
+
'''
|
161 |
+
same output as PIL.Image.transform
|
162 |
+
'''
|
163 |
+
H, W = img.shape[0], img.shape[1]
|
164 |
+
M = np.float32([[1, 0, -offset], [0, 1, 0]])
|
165 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
166 |
+
return out
|
167 |
+
|
168 |
+
|
169 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
|
170 |
+
'''
|
171 |
+
same output as PIL.Image.transform
|
172 |
+
'''
|
173 |
+
H, W = img.shape[0], img.shape[1]
|
174 |
+
M = np.float32([[1, 0, 0], [0, 1, -offset]])
|
175 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
176 |
+
return out
|
177 |
+
|
178 |
+
|
179 |
+
def posterize_func(img, bits):
|
180 |
+
'''
|
181 |
+
same output as PIL.ImageOps.posterize
|
182 |
+
'''
|
183 |
+
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
|
184 |
+
return out
|
185 |
+
|
186 |
+
|
187 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
|
188 |
+
H, W = img.shape[0], img.shape[1]
|
189 |
+
M = np.float32([[1, 0, 0], [factor, 1, 0]])
|
190 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
|
191 |
+
return out
|
192 |
+
|
193 |
+
|
194 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
|
195 |
+
replace = np.array(replace, dtype=np.uint8)
|
196 |
+
H, W = img.shape[0], img.shape[1]
|
197 |
+
rh, rw = np.random.random(2)
|
198 |
+
pad_size = pad_size // 2
|
199 |
+
ch, cw = int(rh * H), int(rw * W)
|
200 |
+
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
|
201 |
+
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
|
202 |
+
out = img.copy()
|
203 |
+
out[x1:x2, y1:y2, :] = replace
|
204 |
+
return out
|
205 |
+
|
206 |
+
|
207 |
+
### level to args
|
208 |
+
def enhance_level_to_args(MAX_LEVEL):
|
209 |
+
def level_to_args(level):
|
210 |
+
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
211 |
+
return level_to_args
|
212 |
+
|
213 |
+
|
214 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
|
215 |
+
def level_to_args(level):
|
216 |
+
level = (level / MAX_LEVEL) * 0.3
|
217 |
+
if np.random.random() > 0.5: level = -level
|
218 |
+
return (level, replace_value)
|
219 |
+
|
220 |
+
return level_to_args
|
221 |
+
|
222 |
+
|
223 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
|
224 |
+
def level_to_args(level):
|
225 |
+
level = (level / MAX_LEVEL) * float(translate_const)
|
226 |
+
if np.random.random() > 0.5: level = -level
|
227 |
+
return (level, replace_value)
|
228 |
+
|
229 |
+
return level_to_args
|
230 |
+
|
231 |
+
|
232 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
|
233 |
+
def level_to_args(level):
|
234 |
+
level = int((level / MAX_LEVEL) * cutout_const)
|
235 |
+
return (level, replace_value)
|
236 |
+
|
237 |
+
return level_to_args
|
238 |
+
|
239 |
+
|
240 |
+
def solarize_level_to_args(MAX_LEVEL):
|
241 |
+
def level_to_args(level):
|
242 |
+
level = int((level / MAX_LEVEL) * 256)
|
243 |
+
return (level, )
|
244 |
+
return level_to_args
|
245 |
+
|
246 |
+
|
247 |
+
def none_level_to_args(level):
|
248 |
+
return ()
|
249 |
+
|
250 |
+
|
251 |
+
def posterize_level_to_args(MAX_LEVEL):
|
252 |
+
def level_to_args(level):
|
253 |
+
level = int((level / MAX_LEVEL) * 4)
|
254 |
+
return (level, )
|
255 |
+
return level_to_args
|
256 |
+
|
257 |
+
|
258 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
|
259 |
+
def level_to_args(level):
|
260 |
+
level = (level / MAX_LEVEL) * 30
|
261 |
+
if np.random.random() < 0.5:
|
262 |
+
level = -level
|
263 |
+
return (level, replace_value)
|
264 |
+
|
265 |
+
return level_to_args
|
266 |
+
|
267 |
+
|
268 |
+
func_dict = {
|
269 |
+
'Identity': identity_func,
|
270 |
+
'AutoContrast': autocontrast_func,
|
271 |
+
'Equalize': equalize_func,
|
272 |
+
'Rotate': rotate_func,
|
273 |
+
'Solarize': solarize_func,
|
274 |
+
'Color': color_func,
|
275 |
+
'Contrast': contrast_func,
|
276 |
+
'Brightness': brightness_func,
|
277 |
+
'Sharpness': sharpness_func,
|
278 |
+
'ShearX': shear_x_func,
|
279 |
+
'TranslateX': translate_x_func,
|
280 |
+
'TranslateY': translate_y_func,
|
281 |
+
'Posterize': posterize_func,
|
282 |
+
'ShearY': shear_y_func,
|
283 |
+
}
|
284 |
+
|
285 |
+
translate_const = 10
|
286 |
+
MAX_LEVEL = 10
|
287 |
+
replace_value = (128, 128, 128)
|
288 |
+
arg_dict = {
|
289 |
+
'Identity': none_level_to_args,
|
290 |
+
'AutoContrast': none_level_to_args,
|
291 |
+
'Equalize': none_level_to_args,
|
292 |
+
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
|
293 |
+
'Solarize': solarize_level_to_args(MAX_LEVEL),
|
294 |
+
'Color': enhance_level_to_args(MAX_LEVEL),
|
295 |
+
'Contrast': enhance_level_to_args(MAX_LEVEL),
|
296 |
+
'Brightness': enhance_level_to_args(MAX_LEVEL),
|
297 |
+
'Sharpness': enhance_level_to_args(MAX_LEVEL),
|
298 |
+
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
|
299 |
+
'TranslateX': translate_level_to_args(
|
300 |
+
translate_const, MAX_LEVEL, replace_value
|
301 |
+
),
|
302 |
+
'TranslateY': translate_level_to_args(
|
303 |
+
translate_const, MAX_LEVEL, replace_value
|
304 |
+
),
|
305 |
+
'Posterize': posterize_level_to_args(MAX_LEVEL),
|
306 |
+
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
|
307 |
+
}
|
308 |
+
|
309 |
+
|
310 |
+
class RandomAugment(object):
|
311 |
+
|
312 |
+
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
|
313 |
+
self.N = N
|
314 |
+
self.M = M
|
315 |
+
self.isPIL = isPIL
|
316 |
+
if augs:
|
317 |
+
self.augs = augs
|
318 |
+
else:
|
319 |
+
self.augs = list(arg_dict.keys())
|
320 |
+
|
321 |
+
def get_random_ops(self):
|
322 |
+
sampled_ops = np.random.choice(self.augs, self.N)
|
323 |
+
return [(op, 0.5, self.M) for op in sampled_ops]
|
324 |
+
|
325 |
+
def __call__(self, img):
|
326 |
+
if self.isPIL:
|
327 |
+
img = np.array(img)
|
328 |
+
ops = self.get_random_ops()
|
329 |
+
for name, prob, level in ops:
|
330 |
+
if np.random.random() > prob:
|
331 |
+
continue
|
332 |
+
args = arg_dict[name](level)
|
333 |
+
img = func_dict[name](img, *args)
|
334 |
+
return img
|
335 |
+
|
336 |
+
|
337 |
+
if __name__ == '__main__':
|
338 |
+
a = RandomAugment()
|
339 |
+
img = np.random.randn(32, 32, 3)
|
340 |
+
a(img)
|
utils.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
3 |
+
"""Decay the learning rate"""
|
4 |
+
lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
|
5 |
+
for param_group in optimizer.param_groups:
|
6 |
+
param_group['lr'] = lr
|
7 |
+
|
8 |
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
9 |
+
"""Warmup the learning rate"""
|
10 |
+
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
|
11 |
+
for param_group in optimizer.param_groups:
|
12 |
+
param_group['lr'] = lr
|
13 |
+
|
14 |
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
15 |
+
"""Decay the learning rate"""
|
16 |
+
lr = max(min_lr, init_lr * (decay_rate**epoch))
|
17 |
+
for param_group in optimizer.param_groups:
|
18 |
+
param_group['lr'] = lr
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import io
|
22 |
+
import os
|
23 |
+
import time
|
24 |
+
from collections import defaultdict, deque
|
25 |
+
import datetime
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.distributed as dist
|
29 |
+
|
30 |
+
class SmoothedValue(object):
|
31 |
+
"""Track a series of values and provide access to smoothed values over a
|
32 |
+
window or the global series average.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, window_size=20, fmt=None):
|
36 |
+
if fmt is None:
|
37 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
38 |
+
self.deque = deque(maxlen=window_size)
|
39 |
+
self.total = 0.0
|
40 |
+
self.count = 0
|
41 |
+
self.fmt = fmt
|
42 |
+
|
43 |
+
def update(self, value, n=1):
|
44 |
+
self.deque.append(value)
|
45 |
+
self.count += n
|
46 |
+
self.total += value * n
|
47 |
+
|
48 |
+
def synchronize_between_processes(self):
|
49 |
+
"""
|
50 |
+
Warning: does not synchronize the deque!
|
51 |
+
"""
|
52 |
+
if not is_dist_avail_and_initialized():
|
53 |
+
return
|
54 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
55 |
+
dist.barrier()
|
56 |
+
dist.all_reduce(t)
|
57 |
+
t = t.tolist()
|
58 |
+
self.count = int(t[0])
|
59 |
+
self.total = t[1]
|
60 |
+
|
61 |
+
@property
|
62 |
+
def median(self):
|
63 |
+
d = torch.tensor(list(self.deque))
|
64 |
+
return d.median().item()
|
65 |
+
|
66 |
+
@property
|
67 |
+
def avg(self):
|
68 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
69 |
+
return d.mean().item()
|
70 |
+
|
71 |
+
@property
|
72 |
+
def global_avg(self):
|
73 |
+
return self.total / self.count
|
74 |
+
|
75 |
+
@property
|
76 |
+
def max(self):
|
77 |
+
return max(self.deque)
|
78 |
+
|
79 |
+
@property
|
80 |
+
def value(self):
|
81 |
+
return self.deque[-1]
|
82 |
+
|
83 |
+
def __str__(self):
|
84 |
+
return self.fmt.format(
|
85 |
+
median=self.median,
|
86 |
+
avg=self.avg,
|
87 |
+
global_avg=self.global_avg,
|
88 |
+
max=self.max,
|
89 |
+
value=self.value)
|
90 |
+
|
91 |
+
|
92 |
+
class MetricLogger(object):
|
93 |
+
def __init__(self, delimiter="\t"):
|
94 |
+
self.meters = defaultdict(SmoothedValue)
|
95 |
+
self.delimiter = delimiter
|
96 |
+
|
97 |
+
def update(self, **kwargs):
|
98 |
+
for k, v in kwargs.items():
|
99 |
+
if isinstance(v, torch.Tensor):
|
100 |
+
v = v.item()
|
101 |
+
assert isinstance(v, (float, int))
|
102 |
+
self.meters[k].update(v)
|
103 |
+
|
104 |
+
def __getattr__(self, attr):
|
105 |
+
if attr in self.meters:
|
106 |
+
return self.meters[attr]
|
107 |
+
if attr in self.__dict__:
|
108 |
+
return self.__dict__[attr]
|
109 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
110 |
+
type(self).__name__, attr))
|
111 |
+
|
112 |
+
def __str__(self):
|
113 |
+
loss_str = []
|
114 |
+
for name, meter in self.meters.items():
|
115 |
+
loss_str.append(
|
116 |
+
"{}: {}".format(name, str(meter))
|
117 |
+
)
|
118 |
+
return self.delimiter.join(loss_str)
|
119 |
+
|
120 |
+
def global_avg(self):
|
121 |
+
loss_str = []
|
122 |
+
for name, meter in self.meters.items():
|
123 |
+
loss_str.append(
|
124 |
+
"{}: {:.4f}".format(name, meter.global_avg)
|
125 |
+
)
|
126 |
+
return self.delimiter.join(loss_str)
|
127 |
+
|
128 |
+
def synchronize_between_processes(self):
|
129 |
+
for meter in self.meters.values():
|
130 |
+
meter.synchronize_between_processes()
|
131 |
+
|
132 |
+
def add_meter(self, name, meter):
|
133 |
+
self.meters[name] = meter
|
134 |
+
|
135 |
+
def log_every(self, iterable, print_freq, header=None):
|
136 |
+
i = 0
|
137 |
+
if not header:
|
138 |
+
header = ''
|
139 |
+
start_time = time.time()
|
140 |
+
end = time.time()
|
141 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
142 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
143 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
144 |
+
log_msg = [
|
145 |
+
header,
|
146 |
+
'[{0' + space_fmt + '}/{1}]',
|
147 |
+
'eta: {eta}',
|
148 |
+
'{meters}',
|
149 |
+
'time: {time}',
|
150 |
+
'data: {data}'
|
151 |
+
]
|
152 |
+
if torch.cuda.is_available():
|
153 |
+
log_msg.append('max mem: {memory:.0f}')
|
154 |
+
log_msg = self.delimiter.join(log_msg)
|
155 |
+
MB = 1024.0 * 1024.0
|
156 |
+
for obj in iterable:
|
157 |
+
data_time.update(time.time() - end)
|
158 |
+
yield obj
|
159 |
+
iter_time.update(time.time() - end)
|
160 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
161 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
162 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
163 |
+
if torch.cuda.is_available():
|
164 |
+
print(log_msg.format(
|
165 |
+
i, len(iterable), eta=eta_string,
|
166 |
+
meters=str(self),
|
167 |
+
time=str(iter_time), data=str(data_time),
|
168 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
169 |
+
else:
|
170 |
+
print(log_msg.format(
|
171 |
+
i, len(iterable), eta=eta_string,
|
172 |
+
meters=str(self),
|
173 |
+
time=str(iter_time), data=str(data_time)))
|
174 |
+
i += 1
|
175 |
+
end = time.time()
|
176 |
+
total_time = time.time() - start_time
|
177 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
178 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
179 |
+
header, total_time_str, total_time / len(iterable)))
|
180 |
+
|
181 |
+
|
182 |
+
class AttrDict(dict):
|
183 |
+
def __init__(self, *args, **kwargs):
|
184 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
185 |
+
self.__dict__ = self
|
186 |
+
|
187 |
+
|
188 |
+
def compute_acc(logits, label, reduction='mean'):
|
189 |
+
ret = (torch.argmax(logits, dim=1) == label).float()
|
190 |
+
if reduction == 'none':
|
191 |
+
return ret.detach()
|
192 |
+
elif reduction == 'mean':
|
193 |
+
return ret.mean().item()
|
194 |
+
|
195 |
+
def compute_n_params(model, return_str=True):
|
196 |
+
tot = 0
|
197 |
+
for p in model.parameters():
|
198 |
+
w = 1
|
199 |
+
for x in p.shape:
|
200 |
+
w *= x
|
201 |
+
tot += w
|
202 |
+
if return_str:
|
203 |
+
if tot >= 1e6:
|
204 |
+
return '{:.1f}M'.format(tot / 1e6)
|
205 |
+
else:
|
206 |
+
return '{:.1f}K'.format(tot / 1e3)
|
207 |
+
else:
|
208 |
+
return tot
|
209 |
+
|
210 |
+
def setup_for_distributed(is_master):
|
211 |
+
"""
|
212 |
+
This function disables printing when not in master process
|
213 |
+
"""
|
214 |
+
import builtins as __builtin__
|
215 |
+
builtin_print = __builtin__.print
|
216 |
+
|
217 |
+
def print(*args, **kwargs):
|
218 |
+
force = kwargs.pop('force', False)
|
219 |
+
if is_master or force:
|
220 |
+
builtin_print(*args, **kwargs)
|
221 |
+
|
222 |
+
__builtin__.print = print
|
223 |
+
|
224 |
+
|
225 |
+
def is_dist_avail_and_initialized():
|
226 |
+
if not dist.is_available():
|
227 |
+
return False
|
228 |
+
if not dist.is_initialized():
|
229 |
+
return False
|
230 |
+
return True
|
231 |
+
|
232 |
+
|
233 |
+
def get_world_size():
|
234 |
+
if not is_dist_avail_and_initialized():
|
235 |
+
return 1
|
236 |
+
return dist.get_world_size()
|
237 |
+
|
238 |
+
|
239 |
+
def get_rank():
|
240 |
+
if not is_dist_avail_and_initialized():
|
241 |
+
return 0
|
242 |
+
return dist.get_rank()
|
243 |
+
|
244 |
+
|
245 |
+
def is_main_process():
|
246 |
+
return get_rank() == 0
|
247 |
+
|
248 |
+
|
249 |
+
def save_on_master(*args, **kwargs):
|
250 |
+
if is_main_process():
|
251 |
+
torch.save(*args, **kwargs)
|
252 |
+
|
253 |
+
|
254 |
+
def init_distributed_mode(args):
|
255 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
256 |
+
args.rank = int(os.environ["RANK"])
|
257 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
258 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
259 |
+
elif 'SLURM_PROCID' in os.environ:
|
260 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
261 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
262 |
+
else:
|
263 |
+
print('Not using distributed mode')
|
264 |
+
args.distributed = False
|
265 |
+
return
|
266 |
+
|
267 |
+
args.distributed = True
|
268 |
+
|
269 |
+
torch.cuda.set_device(args.gpu)
|
270 |
+
args.dist_backend = 'nccl'
|
271 |
+
print('| distributed init (rank {}, word {}): {}'.format(
|
272 |
+
args.rank, args.world_size, args.dist_url), flush=True)
|
273 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
274 |
+
world_size=args.world_size, rank=args.rank)
|
275 |
+
torch.distributed.barrier()
|
276 |
+
setup_for_distributed(args.rank == 0)
|
277 |
+
|
278 |
+
|