import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from PIL import Image
import clip
import numpy as np
import cv2
import gradio as gr
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k, nucleus (top-p) and/or threshold filtering.

    Args:
        logits: logits distribution of shape (vocabulary size,)
        top_k: keep only the top-k tokens with highest probability (0 disables top-k filtering)
        top_p: keep the smallest set of tokens whose cumulative probability exceeds top_p (nucleus filtering)
        threshold: a minimal value below which logits are filtered out
    """
    assert logits.dim() == 1  # only works for batch size 1 for now - updating it would obfuscate the code a bit
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the indices to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value
    return logits
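
# Illustrative sketch (not part of the demo, values are made up): with the defaults used
# below (top_k=0, top_p=0.9), nucleus filtering keeps the smallest set of tokens whose
# cumulative probability reaches 0.9 and pushes the rest to -inf, e.g.:
#   toy = torch.tensor([3.0, 2.0, 1.0, 0.1, 0.05])
#   filtered = top_filtering(toy.clone(), top_k=0, top_p=0.9)
#   torch.isinf(filtered)  # -> tensor([False, False, False, True, True])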
class ImageEncoder(nn.Module):

    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder, _ = clip.load("ViT-B/16", device=device)  # loads already in eval mode

    def forward(self, x):
        """
        Expects a tensor of size (batch_size, 3, 224, 224)
        """
        with torch.no_grad():
            x = x.type(self.encoder.visual.conv1.weight.dtype)
            x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            x = torch.cat([self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + self.encoder.visual.positional_embedding.to(x.dtype)
            x = self.encoder.visual.ln_pre(x)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.encoder.visual.transformer(x)
            grid_feats = x.permute(1, 0, 2)  # LND -> NLD (N, 197, 768)
            grid_feats = self.encoder.visual.ln_post(grid_feats[:, 1:])  # drop the CLS token, keep the patch tokens
        return grid_feats.float()
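
# Quick shape check (assumption: ViT-B/16 at 224x224 gives a 14x14 patch grid with width 768,
# so the encoder should return 196 patch features per image):
#   dummy = torch.zeros(1, 3, 224, 224, device=device)
#   ImageEncoder().to(device)(dummy).shape  # expected: torch.Size([1, 196, 768])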
def change_requires_grad(model, req_grad):
    for p in model.parameters():
        p.requires_grad = req_grad
def load_checkpoint(ckpt_path, epoch):
    model_name = 'nle_model_{}'.format(str(epoch))
    tokenizer_name = 'nle_gpt2_tokenizer_0'
    tokenizer = GPT2Tokenizer.from_pretrained(ckpt_path + tokenizer_name)  # load tokenizer
    model = GPT2LMHeadModel.from_pretrained(ckpt_path + model_name).to(device)  # load model with config
    return tokenizer, model
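
# Assumed checkpoint layout (hypothetical illustration; the exact files ship with the released
# ACT-X checkpoint and are read through the standard from_pretrained() directory format):
#   ACTX_p/
#     nle_gpt2_tokenizer_0/   vocab.json, merges.txt, special_tokens_map.json, ...
#     nle_model_5/            config.json, pytorch_model.bin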
def sample_sequences(img, model, input_ids, segment_ids, tokenizer):
    SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    because_token = tokenizer.convert_tokens_to_ids('Ġbecause')
    max_len = 20
    current_output = []
    img_embeddings = image_encoder(img)
    always_exp = False

    with torch.no_grad():
        for step in range(max_len):
            outputs = model(input_ids=input_ids,
                            past_key_values=None,
                            attention_mask=None,
                            token_type_ids=segment_ids,
                            position_ids=None,
                            encoder_hidden_states=img_embeddings,
                            encoder_attention_mask=None,
                            labels=None,
                            use_cache=False,
                            output_attentions=True,
                            return_dict=True)

            lm_logits = outputs.logits
            xa_maps = outputs.cross_attentions
            logits = lm_logits[0, -1, :] / temperature
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)

            if prev.item() in special_tokens_ids:
                break

            # Decide which segment embedding the new token belongs to: tokens before 'because'
            # stay in the <answer> segment; from 'because' onwards everything is tagged as the
            # <explanation> segment.
            if not always_exp:
                if prev.item() != because_token:
                    new_segment = special_tokens_ids[-2]  # answer segment
                else:
                    new_segment = special_tokens_ids[-1]  # explanation segment
                    always_exp = True
            else:
                new_segment = special_tokens_ids[-1]  # explanation segment

            new_segment = torch.LongTensor([new_segment]).to(device)
            current_output.append(prev.item())
            input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
            segment_ids = torch.cat((segment_ids, new_segment.unsqueeze(0)), dim=1)

    decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()
    return decoded_sequences, xa_maps
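
# Decoding sketch: generation starts from the prompt built by get_inputs() below
# ("<|endoftext|> the answer is") and appends up to max_len tokens, stopping at any special
# token. A typical decoded string looks like (made-up example):
#   "brushing teeth because he is holding a toothbrush in his mouth"
# The part before "because" is the predicted action, the rest is the explanation.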
def get_inputs(tokenizer):
    a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids(['<answer>', '<explanation>'])
    tokens = [tokenizer.bos_token] + tokenizer.tokenize("the answer is")
    segment_ids = [a_segment_id] * len(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long)
    return input_ids.unsqueeze(0).to(device), segment_ids.unsqueeze(0).to(device)
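
# Prompt sketch (illustrative): get_inputs() yields a batch of one prompt, roughly
#   tokens:      ['<|endoftext|>', 'the', 'Ġanswer', 'Ġis']
#   segment ids: [<answer>, <answer>, <answer>, <answer>]
# and sample_sequences() extends both tensors token by token during decoding.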
img_size = 224
ckpt_path = 'ACTX_p/'
max_seq_len = 30
load_from_epoch = 5
no_sample = True
top_k = 0
top_p = 0.9
temperature = 1
image_encoder = ImageEncoder().to(device)
change_requires_grad(image_encoder, False)
tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
model.eval()
img_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])
def inference(raw_image):
    oimg = raw_image.convert('RGB').resize((224, 224))
    img = img_transform(oimg).unsqueeze(0).to(device)
    input_ids, segment_ids = get_inputs(tokenizer)
    seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)
    # Average the heads of the last cross-attention layer (from the final decoding step)
    # and take the first sequence position's attention over the 14x14 patch grid.
    last_am = xa_maps[-1].mean(1)[0]
    mask = last_am[0, :].reshape(14, 14).cpu().numpy()
    mask = cv2.resize(mask / mask.max(), oimg.size)[..., np.newaxis]
    attention_map = (mask * oimg).astype("uint8")
    # The generated sequence has the form "<answer> because <explanation>".
    splitted_seq = seq.split("because")
    return splitted_seq[0].strip(), "because " + splitted_seq[-1].strip(), Image.fromarray(attention_map)
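
# Local usage sketch without the Gradio UI (assumes some test image 'example.jpg' exists;
# the outputs shown are made up):
#   action, explanation, attn = inference(Image.open("example.jpg"))
#   print(action)       # e.g. "playing guitar"
#   print(explanation)  # e.g. "because he is holding a guitar and strumming its strings"
#   attn.save("attention_map.png")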
inputs = [gr.Image(type='pil', label="Load the image of your interest")]
outputs = [gr.Textbox(label="What action is this?"),
           gr.Textbox(label="Textual Explanation"),
           gr.Image(type='pil', label="Visual Explanation")]
title = "NLX-GPT: Explanations with Natural Text (Action Recognition Demo)"

gr.Interface(inference, inputs, outputs, title=title).launch()