explain-LXMERT / generic.py
WwYc's picture
Update generic.py
a1dc8e3 verified
raw
history blame
5.27 kB
import pylab
from lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.src.vqa_utils as utils
from lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests
# Local copies of resources that were originally fetched from
# raw.githubusercontent.com (the file names encode the original URLs).
# Despite the *_URL suffix, these are filesystem paths relative to the
# current working directory, not network URLs.
# Presumably the object-class vocabulary of the bottom-up-attention Visual
# Genome detector; not referenced in the code visible here.
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
# Presumably the attribute vocabulary of the same detector; not referenced in
# the code visible here.
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
# VQA label -> answer mapping; loaded with utils.get_data and indexed by the
# arg-max of the model's answer scores to obtain the predicted answer.
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"
class ModelUsage:
    """Bundle the Faster R-CNN feature extractor with an LXMERT VQA model.

    ``forward`` stores several by-products of the last pass on the instance:
    ``image_file_path``, ``question_tokens``, ``text_len``,
    ``image_boxes_len``, ``bboxes`` and ``output``. ``save_image_vis`` reads
    ``bboxes``, ``question_tokens`` and ``output`` back after a pass. Instances
    are also handed to the ExplanationGenerator classes, which presumably read
    these attributes too (verify there). Treat the attribute names as part of
    this class's contract.
    """

    def __init__(self, use_lrp=False):
        """Load the answer vocabulary, the detector, the tokenizer and LXMERT.

        Args:
            use_lrp: if True, load the LRP-instrumented LXMERT implementation
                from ``lxmert.src.lxmert_lrp``; otherwise load the one from
                ``lxmert.src.huggingface_lxmert``.
        """
        self.vqa_answers = utils.get_data(VQA_URL)

        # Faster R-CNN detector (config and weights share one directory),
        # forced to run on CPU.
        frcnn_dir = "./lxmert/unc-nlp/frcnn-vg-finetuned"
        self.frcnn_cfg = utils.Config.from_pretrained(frcnn_dir)
        self.frcnn_cfg.MODEL.DEVICE = "cpu"
        self.frcnn = GeneralizedRCNN.from_pretrained(frcnn_dir, config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        vqa_model_cls = LxmertForQuestionAnsweringLRP if use_lrp else LxmertForQuestionAnswering
        self.lxmert_vqa = vqa_model_cls.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")
        self.lxmert_vqa.eval()
        # Same object, exposed under a generic attribute name as well.
        self.model = self.lxmert_vqa

    def forward(self, item):
        """Run detection and LXMERT question answering on one example.

        Args:
            item: a 2-tuple ``(image_path, question)``. The path is passed
                straight to ``self.image_preprocess``.

        Returns:
            The LXMERT output object (requested with ``return_dict=True``).
            The same object is also stored on ``self.output``.
        """
        image_path, question = item
        self.image_file_path = image_path

        # Image -> padded region features and boxes from Faster R-CNN.
        images, sizes, scales_yx = self.image_preprocess(image_path)
        detections = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt",
        )

        # Question -> token ids (with special tokens), mask and segment ids.
        encoding = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        token_ids = encoding.input_ids
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(token_ids.flatten())
        self.text_len = len(self.question_tokens)

        roi_features = detections.get("roi_features")
        self.image_boxes_len = roi_features.shape[1]
        # Pixel-space boxes are kept for visualisation; the model itself must
        # receive the *normalized* boxes (the original author flags this as
        # very important).
        self.bboxes = detections.get("boxes")
        self.output = self.lxmert_vqa(
            input_ids=token_ids,
            attention_mask=encoding.attention_mask,
            visual_feats=roi_features,
            visual_pos=detections.get("normalized_boxes"),
            token_type_ids=encoding.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output
# Module-level instances built at import time. model_lrp, lrp and vqa_answers
# are used by save_image_vis; baselines is not referenced in the code visible
# here.
model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
# ModelUsage.__init__ has already loaded the answer vocabulary from VQA_URL;
# reuse it instead of reading and parsing the same file a second time.
vqa_answers = model_lrp.vqa_answers
def _min_max_normalize(scores):
    """Linearly rescale the tensor ``scores`` into [0, 1].

    A constant tensor (max == min) would otherwise divide by zero and turn
    every element into NaN; it is mapped to all zeros instead.
    """
    low = scores.min()
    span = scores.max() - low
    if span == 0:
        return torch.zeros_like(scores)
    return (scores - low) / span


def save_image_vis(image_file_path, question, *, output_path="lxmert/lxmert/experiments/paper/new.jpg"):
    """Explain one VQA prediction as a masked image, highlighted text and the answer.

    Runs the module-level ``lrp`` generator on ``(image_file_path, question)``.
    It then paints every detected box with its relevance score, keeping the
    maximum where boxes overlap. The image is multiplied by the min-max
    normalized mask, and the token relevances are rendered with captum's text
    visualizer.

    Args:
        image_file_path: path of the input image. It is read with
            ``cv2.imread`` and also passed to the explanation generator.
        question: the natural-language question.
        output_path: where the masked image is written before being reloaded
            with PIL. Defaults to the previously hard-coded location.

    Returns:
        ``(img, html, answer)``:
            - ``img``: the masked image as an in-memory PIL image.
            - ``html``: the ``.data`` of the object returned by
              ``visualization.visualize_text``.
            - ``answer``: the ``vqa_answers`` entry at the arg-max of the
              model's answer scores.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image_file_path``.
        OSError: if the masked image cannot be written to ``output_path``.
    """
    # No explicit forward call follows, so generate_ours presumably runs
    # model_lrp.forward itself; model_lrp's bboxes/tokens/output are read
    # below as if they belong to this pass.
    R_t_t, R_t_i = lrp.generate_ours(
        (image_file_path, question),
        use_lrp=False,
        normalize_self_attention=True,
        method_name="ours",
    )
    image_scores = R_t_i[0]
    text_scores = R_t_t[0]

    img = cv2.imread(image_file_path)
    if img is None:
        # cv2.imread reports an unreadable/missing file by returning None.
        raise FileNotFoundError(f"could not read image: {image_file_path}")

    height, width = img.shape[:2]
    mask = torch.zeros(height, width)
    # Each box is used as (x0, y0, x1, y1) slice bounds, i.e. two corners.
    for box, score in zip(model_lrp.bboxes[0], image_scores):
        x0, y0, x1, y1 = (int(coord) for coord in box)
        # A negative lower bound would wrap around to the far edge of the
        # image via negative indexing; clamp it to the image border.
        x0, y0 = max(x0, 0), max(y0, 0)
        region = mask[y0:y1, x0:x1]
        # Overlapping boxes keep the highest score seen for each pixel.
        mask[y0:y1, x0:x1] = torch.clamp(region, min=score.item())

    mask = _min_max_normalize(mask)
    # Broadcast the single-channel mask over the image's channel axis.
    mask = mask.unsqueeze(-1).expand(img.shape)
    img = img * mask.numpy()

    if not cv2.imwrite(output_path, img):
        # Without this check a failed write would silently return whatever
        # image a previous call left at output_path.
        raise OSError(f"could not write masked image to: {output_path}")
    # Image.open is lazy and keeps the file open; copy() loads the pixels so
    # the handle is closed and a later call overwriting output_path cannot
    # affect the image returned here.
    with Image.open(output_path) as saved:
        img = saved.copy()

    text_scores = _min_max_normalize(text_scores)
    vis_data_records = [
        visualization.VisualizationDataRecord(
            text_scores, 0, 0, 0, 0, 0, model_lrp.question_tokens, 1
        )
    ]
    html = visualization.visualize_text(vis_data_records)

    answer = vqa_answers[model_lrp.output.question_answering_score.argmax().item()]
    return img, html.data, answer