sohaibcs1 committed
Commit 22f2dae
1 Parent(s): 06fdca2

Upload app.py

Files changed (1)
  1. app.py +272 -0
app.py ADDED
@@ -0,0 +1,272 @@
+ import os
+ os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT")
+ os.system("gdown https://drive.google.com/uc?id=1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX")
+ import clip
+ from torch import nn
+ import numpy as np
+ import torch
+ import torch.nn.functional as nnf
+ from typing import Tuple, List, Union, Optional
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ from tqdm import tqdm, trange
+ import skimage.io as io
+ import PIL.Image
+ import gradio as gr
+
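+ # Note: the two Google Drive downloads above are assumed to fetch the pretrained
+ # ClipCap checkpoints (coco_weights.pt and conceptual_weights.pt) that
+ # inference() loads below; the actual filenames are determined by the Drive files.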
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+
+ D = torch.device
+ CPU = torch.device('cpu')
+
+
+ def get_device(device_id: int) -> D:
+     if not torch.cuda.is_available():
+         return CPU
+     device_id = min(torch.cuda.device_count() - 1, device_id)
+     return torch.device(f'cuda:{device_id}')
+
+
+ CUDA = get_device
+
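+ # Quick sketch of the device helper: CUDA(0) resolves to torch.device('cuda:0')
+ # when a GPU is visible and falls back to CPU otherwise, so the same call works
+ # on CPU-only hosts (is_gpu is set to False further down for exactly that case).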
+ class MLP(nn.Module):
+
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+
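+ # Sketch of how this MLP is used below: with prefix_size=512, prefix_length=10
+ # and GPT-2's 768-dim embeddings, ClipCaptionModel builds MLP((512, 3840, 7680)),
+ # i.e. Linear(512, 3840) -> Tanh -> Linear(3840, 7680), whose output is reshaped
+ # into prefix_length x 768 prefix token embeddings.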
+ class ClipCaptionModel(nn.Module):
+
+     # @functools.lru_cache  # FIXME
+     def get_dummy_token(self, batch_size: int, device: D) -> T:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+         # print(embedding_text.size())  # torch.Size([5, 67, 768])
+         # print(prefix_projections.size())  # torch.Size([5, 1, 768])
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if prefix_length > 10:  # not enough memory
+             self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+         else:
+             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
+                                      self.gpt_embedding_size * prefix_length))
+
+
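+ # The forward pass prepends the projected CLIP prefix to the caption token
+ # embeddings, so GPT-2 is conditioned on the image purely through its input
+ # embeddings; no architectural change to GPT-2 is required.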
+ class ClipCaptionPrefix(ClipCaptionModel):
+
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
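+ # ClipCaptionPrefix exposes only the mapping network's parameters and keeps GPT-2
+ # in eval mode, i.e. the language model stays frozen; it is a training-time
+ # variant and is not used by the demo below, which loads ClipCaptionModel.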
+ #@title Caption prediction
+
+ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                   entry_length=67, temperature=1., stop_token: str = '.'):
+
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+
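+ # Beam-search bookkeeping above: each beam's score is its summed token
+ # log-probability normalized by its length, seq_lengths stops growing once a
+ # beam emits the stop token, and the returned list is sorted best-first, e.g.
+ #     captions = generate_beam(model, tokenizer, embed=prefix_embed)
+ #     best_caption = captions[0]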
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt=None,
+     embed=None,
+     entry_count=1,
+     entry_length=67,  # maximum number of words
+     top_p=0.8,
+     temperature=1.,
+     stop_token: str = '.',
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+
+     return generated_list[0]
+
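+ # Note on generate2: the top-p (nucleus) mask above removes tokens outside the
+ # 0.8 cumulative-probability nucleus, but the next token is then chosen with
+ # torch.argmax, so this path is effectively greedy decoding; the filtering would
+ # only matter if sampling replaced the argmax.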
+ is_gpu = False
+ device = CUDA(0) if is_gpu else "cpu"
+ clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
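+ # CLIP ViT-B/32 produces 512-dimensional image embeddings, matching the
+ # prefix_size=512 default of ClipCaptionModel, so the encode_image output can be
+ # fed directly to clip_project.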
+ def inference(img, model_name):
+     prefix_length = 10
+
+     model = ClipCaptionModel(prefix_length)
+
+     if model_name == "COCO":
+         model_path = 'coco_weights.pt'
+     else:
+         model_path = 'conceptual_weights.pt'
+     model.load_state_dict(torch.load(model_path, map_location=CPU))
+     model = model.eval()
+     device = CUDA(0) if is_gpu else "cpu"
+     model = model.to(device)
+
+     use_beam_search = False
+     image = io.imread(img.name)
+     pil_image = PIL.Image.fromarray(image)
+     image = preprocess(pil_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+     if use_beam_search:
+         generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+     else:
+         generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+     return generated_text_prefix
+
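+ # Rough local-usage sketch (hypothetical wrapper; Gradio's "file" input passes a
+ # temp-file object exposing a .name path, which io.imread reads above):
+ #     class _UploadedFile:          # stand-in for Gradio's file wrapper
+ #         name = "water.jpeg"
+ #     print(inference(_UploadedFile(), "COCO"))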
+ title = "ImageSummarizer"
+ description = "Gradio demo for Image Summarizer: To use it, simply upload an image or click one of the examples to load it. Read more at the link below."
+ article = "<p style='text-align: center'><a href='https://github.com/sohaibcs1/ImageSummarizer' target='_blank'>GitHub Repo</a></p>"
+
+ examples = [['water.jpeg', "COCO"]]
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="file", label="Input"),
+      gr.inputs.Radio(choices=["COCO", "Conceptual captions"], type="value", default="COCO", label="Model")],
+     gr.outputs.Textbox(label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     enable_queue=True,
+     examples=examples
+ ).launch(debug=True)
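+ # The Interface maps the two inputs (uploaded image file, model choice) to
+ # inference(img, model_name) positionally and shows the returned caption in the
+ # Textbox; note this relies on the legacy gr.inputs / gr.outputs / enable_queue
+ # API of older Gradio releases.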