File size: 4,430 Bytes
a166479 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import sys
import torch.utils.data as data
import torch
from torchvision import transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import random
from bert.tokenization_bert import BertTokenizer
import h5py
from refer.refer import REFER
from args import get_parser
# Dataset configuration initialization.
# NOTE(review): this runs at import time, so merely importing this module
# builds the project parser and parses arguments (presumably argparse reading
# sys.argv -- confirm against args.get_parser). In the code visible here,
# ReferDataset does not read this module-level `args`; its constructor takes
# its own `args` parameter, which shadows this one.
parser = get_parser()
args = parser.parse_args()
class ReferDataset(data.Dataset):
    """Dataset of (image, mask, tokenized sentence) samples, one per reference.

    Each item corresponds to one ref id returned by
    ``REFER.getRefIds(split=...)``. All sentences attached to every reference
    are tokenized once in the constructor and cached as fixed-length tensors;
    images and masks are loaded lazily in ``__getitem__``.

    Args:
        args: Configuration object; the constructor reads ``refer_data_root``,
            ``dataset``, ``splitBy`` and ``bert_tokenizer`` from it.
        image_transforms: Optional callable invoked as
            ``image_transforms(img, annot)`` that must return ``(img, target)``.
            When ``None`` the PIL image and PIL mask are returned untransformed.
        target_transforms: Stored as ``self.target_transform``; not used by
            this class itself.
        split: Split name forwarded to ``REFER.getRefIds``.
        eval_mode: If true, ``__getitem__`` returns every sentence of the
            reference (testing); otherwise it randomly samples one sentence
            (training / validation during training, for efficiency).
    """

    def __init__(self,
                 args,
                 image_transforms=None,
                 target_transforms=None,
                 split='train',
                 eval_mode=False):
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)

        # Every sentence is truncated / zero-padded to exactly this many ids.
        self.max_tokens = 20

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)

        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]
        self.ref_ids = ref_ids

        # Parallel lists indexed like ``self.ref_ids``; each entry is itself a
        # list with one element per sentence of that reference.
        self.input_ids = []
        self.attention_masks = []
        self.raw_sentences = []
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        self.eval_mode = eval_mode

        for r in ref_ids:
            ref = self.refer.Refs[r]

            sentences_for_ref = []
            attentions_for_ref = []
            raw_sentences_for_ref = []

            # zip() is kept so sentences stay paired with their sent_ids
            # exactly as before, even though the ids themselves are unused.
            for el, _sent_id in zip(ref['sentences'], ref['sent_ids']):
                sentence_raw = el['raw']
                padded_input_ids, attention_mask = self._encode_sentence(sentence_raw)
                sentences_for_ref.append(padded_input_ids)
                attentions_for_ref.append(attention_mask)
                raw_sentences_for_ref.append(sentence_raw)

            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)
            self.raw_sentences.append(raw_sentences_for_ref)

    def _encode_sentence(self, sentence_raw):
        """Tokenize one sentence into fixed-length id and attention-mask tensors.

        Returns:
            ``(input_ids, attention_mask)``, both of shape ``(1, max_tokens)``.
            Ids beyond ``max_tokens`` are dropped; unused positions hold 0 in
            both tensors, and the mask is 1 wherever a real token id is stored.
        """
        input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)

        # truncation of tokens
        input_ids = list(input_ids[:self.max_tokens])
        num_tokens = len(input_ids)
        num_padding = self.max_tokens - num_tokens

        padded_input_ids = input_ids + [0] * num_padding
        attention_mask = [1] * num_tokens + [0] * num_padding

        return (torch.tensor(padded_input_ids).unsqueeze(0),
                torch.tensor(attention_mask).unsqueeze(0))

    def _language_inputs(self, index):
        """Select the language inputs for the reference at ``index``.

        Returns:
            ``(tensor_embeddings, attention_mask, raw_sentence)``.

            * ``eval_mode`` true: all sentences of the reference, stacked on a
              new last axis, i.e. tensors of shape
              ``(1, max_tokens, num_sentences)``; ``raw_sentence`` is the list
              of all raw sentence strings.
            * ``eval_mode`` false: one randomly sampled sentence, tensors of
              shape ``(1, max_tokens)``; ``raw_sentence`` is the raw string of
              that same sentence.
        """
        sentence_ids = self.input_ids[index]
        sentence_masks = self.attention_masks[index]

        if self.eval_mode:
            # Equivalent to unsqueeze(-1) on each tensor followed by cat(dim=-1).
            tensor_embeddings = torch.stack(sentence_ids, dim=-1)
            attention_mask = torch.stack(sentence_masks, dim=-1)
            raw_sentence = self.raw_sentences[index]
        else:
            choice_sent = np.random.choice(len(sentence_ids))
            tensor_embeddings = sentence_ids[choice_sent]
            attention_mask = sentence_masks[choice_sent]
            # Previously left unassigned in this branch, which made the
            # return statement of __getitem__ raise UnboundLocalError.
            raw_sentence = self.raw_sentences[index][choice_sent]

        return tensor_embeddings, attention_mask, raw_sentence

    def get_classes(self):
        """Return ``self.classes`` (initialized to an empty list)."""
        return self.classes

    def __len__(self):
        """Return the number of references in the selected split."""
        return len(self.ref_ids)

    def __getitem__(self, index):
        """Load the sample for the reference at position ``index``.

        Returns:
            A 7-tuple ``(img, target, tensor_embeddings, attention_mask,
            raw_sentence, file_name, orig_img)`` where

            * ``img`` / ``target`` are the outputs of ``image_transforms``, or
              the RGB PIL image and mode-"P" PIL mask if no transform is set;
            * ``tensor_embeddings`` / ``attention_mask`` / ``raw_sentence``
              come from ``_language_inputs``;
            * ``file_name`` is the image's ``'file_name'`` entry;
            * ``orig_img`` is the untransformed RGB image as a NumPy array.
        """
        this_ref_id = self.ref_ids[index]
        this_img_id = self.refer.getImgIds(this_ref_id)
        this_img = self.refer.Imgs[this_img_id[0]]

        img_path = os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        orig_img = np.array(img)

        ref = self.refer.loadRefs(this_ref_id)

        # Binary mask: 1 where the reference mask equals 1, 0 elsewhere.
        ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
        annot = np.zeros(ref_mask.shape)
        annot[ref_mask == 1] = 1

        annot = Image.fromarray(annot.astype(np.uint8), mode="P")

        if self.image_transforms is not None:
            # resize, from PIL to tensor, and mean and std normalization
            img, target = self.image_transforms(img, annot)
        else:
            # Without transforms `target` used to be undefined at the return.
            target = annot

        tensor_embeddings, attention_mask, raw_sentence = self._language_inputs(index)

        return (img, target, tensor_embeddings, attention_mask, raw_sentence,
                this_img['file_name'], orig_img)
|