from typing import Optional, List
import os
import json
import requests
import functools
from io import BytesIO
from pathlib import Path
from urllib3 import disable_warnings
from urllib3.exceptions import InsecureRequestWarning
import torch
import torchvision
from torch import Tensor
from torch.nn.modules import Module
from torch.utils.data import Dataset, Subset, DataLoader
from torchtext.datasets import IMDB  # required by IMDBDataset below
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import ViltForQuestionAnswering, ViltProcessor
from tqdm import tqdm
from PIL import Image
# datasets
class ImageNetDataset(Dataset):
    """ImageNet samples stored as `root_dir/samples/` with class names in
    `root_dir/imagenet_class_index.json`. Assumes the sorted filenames line up
    one-to-one with class indices, so the sample index doubles as the label."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.img_dir = os.path.join(self.root_dir, 'samples/')
        self.label_dir = os.path.join(self.root_dir, 'imagenet_class_index.json')
        with open(self.label_dir) as json_data:
            self.idx_to_labels = json.load(json_data)
        self.img_names = os.listdir(self.img_dir)
        self.img_names.sort()
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert('RGB')
        label = idx  # sorted filename order maps the index to the class label
        if self.transform:
            image = self.transform(image)
        return image, label

    def idx_to_label(self, idx):
        return self.idx_to_labels[str(idx)][1]

def get_imagenet_dataset(
    transform,
    subset_size: int = 100,  # ignored if indices is not None
    root_dir="./data/ImageNet",
    indices: Optional[List[int]] = None,
):
    os.chdir(Path(__file__).parent)  # resolve the relative root_dir against this file's directory
    dataset = ImageNetDataset(root_dir=root_dir, transform=transform)
    if indices is not None:
        return Subset(dataset, indices=indices)
    indices = list(range(len(dataset)))[:subset_size]
    return Subset(dataset, indices=indices)

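# Example usage (a sketch; assumes `./data/ImageNet/` is populated as described
# in the ImageNetDataset docstring, and uses get_torchvision_model from below):
#
#   _, transform = get_torchvision_model('resnet18')
#   dataset = get_imagenet_dataset(transform, subset_size=8)
#   image, label = dataset[0]
#   print(image.shape, dataset.dataset.idx_to_label(label))
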
class IMDBDataset(Dataset):
    def __init__(self, split='test'):
        super().__init__()
        data_iter = IMDB(split=split)
        # torchtext's IMDB yields (label, line) with labels in {1, 2};
        # shift to {0, 1} for binary classification.
        self.annotations = [(line, label - 1) for label, line in tqdm(data_iter)]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        return self.annotations[idx]


def get_imdb_dataset(split='test'):
    return IMDBDataset(split=split)

# The VQA images below are fetched with verify=False, so silence the
# resulting InsecureRequestWarning up front.
disable_warnings(InsecureRequestWarning)

class VQADataset(Dataset):
    def __init__(self):
        super().__init__()
        res = requests.get('https://visualqa.org/balanced_data.json')
        self.annotations = json.loads(res.text)  # parse the JSON instead of eval-ing remote text

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        data = self.annotations[idx]
        if isinstance(data['original_image'], str):
            # Lazily download the image on first access and cache it in place.
            print(f"Requesting {data['original_image']}...")
            res = requests.get(data['original_image'], verify=False)
            img = Image.open(BytesIO(res.content)).convert('RGB')
            data['original_image'] = img
        return data['original_image'], data['question'], data['original_answer']


def get_vqa_dataset():
    return VQADataset()

# models
def get_torchvision_model(model_name):
    weights = torchvision.models.get_model_weights(model_name).DEFAULT
    model = torchvision.models.get_model(model_name, weights=weights).eval()
    transform = weights.transforms()
    return model, transform

class Bert(BertForSequenceClassification):
    # Accept positional inputs and return the raw logits Tensor instead of the
    # Hugging Face output object, so downstream code can treat this like a
    # plain Tensor-in, Tensor-out module.
    def forward(self, input_ids, token_type_ids, attention_mask):
        return super().forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).logits


def get_bert_model(model_name, num_labels):
    return Bert.from_pretrained(model_name, num_labels=num_labels)

class Vilt(ViltForQuestionAnswering):
    # Same idea as Bert above: positional inputs, raw logits out. The parameter
    # order matches the tuple returned by vilt_collate_fn below.
    def forward(
        self,
        pixel_values,
        input_ids,
        token_type_ids,
        attention_mask,
        pixel_mask,
    ):
        return super().forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        ).logits


def get_vilt_model(model_name):
    return Vilt.from_pretrained(model_name)

# utils
def img_to_np(img):
    """Convert a CHW tensor to an HWC numpy array."""
    return img.permute(1, 2, 0).detach().numpy()


def denormalize_image(inputs, mean, std):
    # Invert a per-channel Normalize(mean, std) on a CHW tensor and return
    # an HWC numpy array suitable for plotting.
    return img_to_np(
        inputs
        * Tensor(std)[:, None, None]
        + Tensor(mean)[:, None, None]
    )

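# Example (a sketch): undo the standard ImageNet normalization before plotting.
# The mean/std values are the usual ImageNet statistics, not something this
# module defines.
#
#   import matplotlib.pyplot as plt
#   img_np = denormalize_image(
#       image,  # a normalized CHW tensor from the dataloader
#       mean=[0.485, 0.456, 0.406],
#       std=[0.229, 0.224, 0.225],
#   )
#   plt.imshow(img_np)
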
def bert_collate_fn(batch, tokenizer=None):
    inputs = tokenizer(
        [d[0] for d in batch],
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    labels = torch.tensor([d[1] for d in batch])
    # Relies on the BatchEncoding's insertion order
    # (input_ids, token_type_ids, attention_mask), matching Bert.forward above.
    return tuple(inputs.values()), labels

def get_bert_tokenizer(model_name):
    return BertTokenizer.from_pretrained(model_name)


def get_vilt_processor(model_name):
    return ViltProcessor.from_pretrained(model_name)

def vilt_collate_fn(batch, processor=None, label2id=None):
    imgs = [d[0] for d in batch]
    qsts = [d[1] for d in batch]
    inputs = processor(
        images=imgs,
        text=qsts,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    # Map answer strings to class ids via the model's label2id vocabulary.
    labels = torch.tensor([label2id[d[2]] for d in batch])
    return (
        inputs['pixel_values'],
        inputs['input_ids'],
        inputs['token_type_ids'],
        inputs['attention_mask'],
        inputs['pixel_mask'],
        labels,
    )

def load_model_and_dataloader_for_tutorial(modality, device):
    if modality == 'image':
        model, transform = get_torchvision_model('resnet18')
        model = model.to(device)
        model.eval()
        dataset = get_imagenet_dataset(transform)
        loader = DataLoader(dataset, batch_size=8, shuffle=False)
        return model, loader, transform
    elif modality == 'text':
        model = get_bert_model('fabriceyhc/bert-base-uncased-imdb', num_labels=2)
        model = model.to(device)
        model.eval()
        dataset = get_imdb_dataset(split='test')
        tokenizer = get_bert_tokenizer('fabriceyhc/bert-base-uncased-imdb')
        loader = DataLoader(
            dataset,
            batch_size=8,
            shuffle=False,
            collate_fn=functools.partial(bert_collate_fn, tokenizer=tokenizer),
        )
        return model, loader, tokenizer
    elif modality == ('image', 'text'):  # multimodal: pass a tuple, not a string
        model = get_vilt_model('dandelin/vilt-b32-finetuned-vqa')
        model.to(device)
        model.eval()
        dataset = get_vqa_dataset()
        processor = get_vilt_processor('dandelin/vilt-b32-finetuned-vqa')
        loader = DataLoader(
            dataset,
            batch_size=2,
            shuffle=False,
            collate_fn=functools.partial(
                vilt_collate_fn,
                processor=processor,
                label2id=model.config.label2id,
            ),
        )
        return model, loader, processor
    raise ValueError(f"Unsupported modality: {modality!r}")
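

# Minimal smoke test, a sketch rather than part of the tutorial API. It assumes
# the ImageNet samples described above are present under ./data/ImageNet.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, loader, transform = load_model_and_dataloader_for_tutorial('image', device)
    images, labels = next(iter(loader))
    with torch.no_grad():
        preds = model(images.to(device)).argmax(dim=-1)
    print('predicted class indices:', preds.tolist())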