import pylab
from lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.src.vqa_utils as utils
from lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests

# Local copies of the object/attribute vocabularies and the VQA answer vocabulary.
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"


class ModelUsage:
    """Wraps the Faster R-CNN feature extractor and the LXMERT VQA model."""

    def __init__(self, use_lrp=False):
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cpu"

        self.frcnn = GeneralizedRCNN.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        # The LRP variant exposes the relevance-propagation hooks used by the explanation generators.
        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        # self.vqa_dataset = vqa_data.VQADataset(splits="valid")

    def forward(self, item):
        image_file_path, question = item
        self.image_file_path = image_file_path

        # run frcnn to extract region features and bounding boxes
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt",
        )

        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")

        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output


model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)


def save_image_vis(image_file_path, question):
    # Generate text-to-text (R_t_t) and text-to-image (R_t_i) relevance maps.
    R_t_t, R_t_i = lrp.generate_ours(
        (image_file_path, question),
        use_lrp=False,
        normalize_self_attention=True,
        method_name="ours",
    )
    image_scores = R_t_i[0]
    text_scores = R_t_t[0]
    # bbox_scores = image_scores
    _, top_bboxes_indices = image_scores.topk(k=1, dim=-1)

    # Build a per-pixel mask from the per-box relevance scores, keeping the max score per pixel.
    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index in range(len(image_scores)):
        [x1, y1, x2, y2] = model_lrp.bboxes[0][index]  # box corners in pixel coordinates
        curr_score_tensor = mask[int(y1):int(y2), int(x1):int(x2)]
        new_score_tensor = torch.ones_like(curr_score_tensor) * image_scores[index].item()
        mask[int(y1):int(y2), int(x1):int(x2)] = torch.max(new_score_tensor, curr_score_tensor)

    # Normalize the mask to [0, 1] and use it to modulate image brightness.
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.unsqueeze(-1)
    mask = mask.expand(img.shape)
    img = img * mask.cpu().data.numpy()
    # img = Image.fromarray(np.uint8(img)).convert('RGB')
    cv2.imwrite('lxmert/lxmert/experiments/paper/new.jpg', img)
    img = Image.open('lxmert/lxmert/experiments/paper/new.jpg')

    # Normalize the token relevance scores and render them with Captum's text visualizer.
    text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
    vis_data_records = [
        visualization.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, model_lrp.question_tokens, 1)
    ]
    html1 = visualization.visualize_text(vis_data_records)

    answer = vqa_answers[model_lrp.output.question_answering_score.argmax()]
    return img, html1.data, answer
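

# Minimal usage sketch for save_image_vis. The image path and question below are
# hypothetical placeholders; point them at any local image and question. Assumes the
# 'lxmert/lxmert/experiments/paper/' directory used above exists for the outputs.
if __name__ == "__main__":
    example_image_path = "path/to/image.jpg"        # hypothetical path
    example_question = "What is the person holding?"  # hypothetical question

    relevance_img, text_html, answer = save_image_vis(example_image_path, example_question)
    print("Predicted answer:", answer)

    # Display the relevance-weighted image and save the token relevance HTML.
    plt.imshow(relevance_img)
    plt.axis("off")
    plt.show()
    with open("lxmert/lxmert/experiments/paper/text_relevance.html", "w") as f:
        f.write(text_html)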