from typing import Optional, List
import os
import json
import requests
import functools
from io import BytesIO
from pathlib import Path
from urllib3 import disable_warnings
from urllib3.exceptions import InsecureRequestWarning
import torch
import torchvision
from torch import Tensor
from torch.nn.modules import Module
from torch.utils.data import Dataset, Subset, DataLoader
# torchtext (for the IMDB dataset) is imported lazily inside IMDBDataset below
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import ViltForQuestionAnswering, ViltProcessor
from tqdm import tqdm
from PIL import Image
# datasets
class ImageNetDataset(Dataset):
    # Small ImageNet sample set: images live in samples/, class names in imagenet_class_index.json.
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.img_dir = os.path.join(self.root_dir, 'samples/')
        self.label_dir = os.path.join(self.root_dir, 'imagenet_class_index.json')
        with open(self.label_dir) as json_data:
            self.idx_to_labels = json.load(json_data)
        self.img_names = os.listdir(self.img_dir)
        self.img_names.sort()
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert('RGB')
        label = idx  # the sorted sample index doubles as the label
        if self.transform:
            image = self.transform(image)
        return image, label

    def idx_to_label(self, idx):
        return self.idx_to_labels[str(idx)][1]
def get_imagenet_dataset(
    transform,
    subset_size: int=100,  # ignored if indices is not None
    root_dir="./data/ImageNet",
    indices: Optional[List[int]]=None,
):
    os.chdir(Path(__file__).parent)  # ensure the relative root_dir resolves next to this file
    dataset = ImageNetDataset(root_dir=root_dir, transform=transform)
    if indices is not None:
        return Subset(dataset, indices=indices)
    indices = list(range(len(dataset)))
    subset = Subset(dataset, indices=indices[:subset_size])
    return subset
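
# Usage sketch (not called anywhere): shows how to build the ImageNet subset with a
# hand-rolled preprocessing pipeline. The transform below and subset_size=10 are
# illustrative choices, and ./data/ImageNet must contain the samples/ folder and
# imagenet_class_index.json expected by ImageNetDataset.
def _example_imagenet_usage():
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    dataset = get_imagenet_dataset(transform, subset_size=10)
    image, label = dataset[0]  # image: [3, 224, 224] tensor, label: sample index
    return image, label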
class IMDBDataset(Dataset):
    def __init__(self, split='test'):
        super().__init__()
        # torchtext is needed only for the text modality, so import it lazily here
        from torchtext.datasets import IMDB
        data_iter = IMDB(split=split)
        # IMDB yields (label, text) pairs with labels 1/2; shift them to 0/1
        self.annotations = [(line, label-1) for label, line in tqdm(data_iter)]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        return self.annotations[idx]

def get_imdb_dataset(split='test'):
    return IMDBDataset(split=split)
disable_warnings(InsecureRequestWarning)

class VQADataset(Dataset):
    def __init__(self):
        super().__init__()
        # annotations are fetched from the VQA website; images are downloaded lazily in __getitem__
        res = requests.get('https://visualqa.org/balanced_data.json')
        self.annotations = json.loads(res.text)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        data = self.annotations[idx]
        if isinstance(data['original_image'], str):
            print(f"Requesting {data['original_image']}...")
            res = requests.get(data['original_image'], verify=False)
            img = Image.open(BytesIO(res.content)).convert('RGB')
            data['original_image'] = img  # cache the decoded image in place
        return data['original_image'], data['question'], data['original_answer']

def get_vqa_dataset():
    return VQADataset()
# models
def get_torchvision_model(model_name):
    # DEFAULT weights ship with their matching preprocessing transform
    weights = torchvision.models.get_model_weights(model_name).DEFAULT
    model = torchvision.models.get_model(model_name, weights=weights).eval()
    transform = weights.transforms()
    return model, transform
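
# Usage sketch (not called anywhere): the model name and batch size are illustrative;
# the transform returned by get_torchvision_model can be passed straight to
# get_imagenet_dataset since it accepts PIL images.
def _example_torchvision_usage():
    model, transform = get_torchvision_model('resnet18')
    dataset = get_imagenet_dataset(transform, subset_size=8)
    loader = DataLoader(dataset, batch_size=8, shuffle=False)
    images, labels = next(iter(loader))
    with torch.no_grad():
        logits = model(images)  # [8, 1000] class scores
    return logits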
class Bert(BertForSequenceClassification):
    # forward returns the logits tensor directly instead of a SequenceClassifierOutput
    def forward(self, input_ids, token_type_ids, attention_mask):
        return super().forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).logits

def get_bert_model(model_name, num_labels):
    return Bert.from_pretrained(model_name, num_labels=num_labels)
class Vilt(ViltForQuestionAnswering):
    def forward(
        self,
        pixel_values,
        input_ids,
        token_type_ids,
        attention_mask,
        pixel_mask,
    ):
        return super().forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        ).logits

def get_vilt_model(model_name):
    return Vilt.from_pretrained(model_name)
# utils
img_to_np = lambda img: img.permute(1, 2, 0).detach().numpy()  # CHW tensor -> HWC numpy array

def denormalize_image(inputs, mean, std):
    # reverses transforms.Normalize(mean, std) and converts to an HWC numpy array
    return img_to_np(
        inputs
        * Tensor(std)[:, None, None]
        + Tensor(mean)[:, None, None]
    )
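
# Usage sketch (not called anywhere): assumes the input image was normalized with the
# standard ImageNet statistics used by torchvision's pretrained weights; the values
# below are only that common default, not something this module enforces.
def _example_denormalize_usage(normalized_image):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # normalized_image: a single [3, H, W] tensor; result is an HWC numpy array for plotting
    return denormalize_image(normalized_image, mean=mean, std=std)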
def bert_collate_fn(batch, tokenizer=None):
    # batch is a list of (text, label) pairs from IMDBDataset
    inputs = tokenizer(
        [d[0] for d in batch],
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    labels = torch.tensor([d[1] for d in batch])
    # (input_ids, token_type_ids, attention_mask), labels
    return tuple(inputs.values()), labels

def get_bert_tokenizer(model_name):
    return BertTokenizer.from_pretrained(model_name)

def get_vilt_processor(model_name):
    return ViltProcessor.from_pretrained(model_name)
def vilt_collate_fn(batch, processor=None, label2id=None):
    # batch is a list of (image, question, answer) triples from VQADataset
    imgs = [d[0] for d in batch]
    qsts = [d[1] for d in batch]
    inputs = processor(
        images=imgs,
        text=qsts,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    labels = torch.tensor([label2id[d[2]] for d in batch])
    return (
        inputs['pixel_values'],
        inputs['input_ids'],
        inputs['token_type_ids'],
        inputs['attention_mask'],
        inputs['pixel_mask'],
        labels,
    )
def load_model_and_dataloader_for_tutorial(modality, device):
    if modality == 'image':
        model, transform = get_torchvision_model('resnet18')
        model = model.to(device)
        model.eval()
        dataset = get_imagenet_dataset(transform)
        loader = DataLoader(dataset, batch_size=8, shuffle=False)
        return model, loader, transform
    elif modality == 'text':
        model = get_bert_model('fabriceyhc/bert-base-uncased-imdb', num_labels=2)
        model = model.to(device)
        model.eval()
        dataset = get_imdb_dataset(split='test')
        tokenizer = get_bert_tokenizer('fabriceyhc/bert-base-uncased-imdb')
        loader = DataLoader(
            dataset,
            batch_size=8,
            shuffle=False,
            collate_fn=functools.partial(bert_collate_fn, tokenizer=tokenizer),
        )
        return model, loader, tokenizer
    elif modality == ('image', 'text'):  # the multimodal case passes modality as a tuple
        model = get_vilt_model('dandelin/vilt-b32-finetuned-vqa')
        model.to(device)
        model.eval()
        dataset = get_vqa_dataset()
        processor = get_vilt_processor('dandelin/vilt-b32-finetuned-vqa')
        loader = DataLoader(
            dataset,
            batch_size=2,
            shuffle=False,
            collate_fn=functools.partial(
                vilt_collate_fn,
                processor=processor,
                label2id=model.config.label2id,
            ),
        )
        return model, loader, processor
    else:
        raise ValueError(f"Unsupported modality: {modality!r}")
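
# Usage sketch: loads the image tutorial setup and prints predicted class names for one
# batch. The 'image' modality and CPU fallback are illustrative, and running it requires
# the local ./data/ImageNet assets described above.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, loader, transform = load_model_and_dataloader_for_tutorial('image', device)
    images, labels = next(iter(loader))
    with torch.no_grad():
        preds = model(images.to(device)).argmax(dim=-1)
    # the loader wraps a Subset, so idx_to_label lives on loader.dataset.dataset
    print([loader.dataset.dataset.idx_to_label(p.item()) for p in preds])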