Anonumous committed on
Commit dc7dc01
1 Parent(s): 1085df7

Create app.py

Files changed (1)
app.py +238 -0
app.py ADDED
@@ -0,0 +1,238 @@
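+ # Dependencies implied by the imports below: torch, transformers, gradio,
+ # opencv-python, Pillow, and OpenAI's CLIP package
+ # (pip install git+https://github.com/openai/CLIP.git).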
+ from typing import Tuple, Optional
+
+ import clip
+ import cv2
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ from torch.nn import functional as nnf
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+ gpt_model_name = 'sberbank-ai/rugpt3medium_based_on_gpt2'
+
+ class MLP(nn.Module):
+     """MLP that maps a CLIP image embedding to a sequence of GPT prefix embeddings."""
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.model(x)
+
+
+ def freeze(
+     model,
+     freeze_emb=False,
+     freeze_ln=False,
+     freeze_attn=True,
+     freeze_ff=True,
+     freeze_other=False,
+ ):
+     """Selectively freeze parameter groups of a GPT-2 model by parameter name."""
+     # With the defaults, attention and feed-forward blocks are frozen while
+     # layernorms and embeddings stay trainable.
+     for name, p in model.named_parameters():
+         name = name.lower()
+         if 'ln' in name or 'norm' in name:
+             p.requires_grad = not freeze_ln
+         elif 'embeddings' in name:
+             p.requires_grad = not freeze_emb
+         elif 'mlp' in name:
+             p.requires_grad = not freeze_ff
+         elif 'attn' in name:
+             p.requires_grad = not freeze_attn
+         else:
+             p.requires_grad = not freeze_other
+
+     return model
+
+
+ class ClipCaptionModel(nn.Module):
+     def __init__(self, prefix_length: int, prefix_size: int = 768):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         # Russian GPT-2 (ruGPT-3 medium) used as the caption decoder.
+         self.gpt = GPT2LMHeadModel.from_pretrained(gpt_model_name)
+
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         # Projects a CLIP embedding to `prefix_length` GPT token embeddings.
+         self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
+                                  self.gpt_embedding_size * prefix_length))
+
+     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
+                 labels: Optional[torch.Tensor] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix.float()).view(-1, self.prefix_length, self.gpt_embedding_size)
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+     """Variant that trains only the prefix-projection MLP; GPT-2 stays frozen."""
+
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
+ def filter_ngrams(output_text):
+     """Truncate the generation at a second ' Ответ:' ("Answer:") marker, if present."""
+     a_pos = output_text.find(' Ответ:')
+     sec_a_pos = output_text.find(' Ответ:', a_pos + 1)
+     if sec_a_pos == -1:
+         return output_text
+     return output_text[:sec_a_pos]
+
+
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt='',
+     embed=None,
+     entry_count=1,
+     entry_length=67,  # maximum number of generated tokens
+     top_p=0.98,
+     temperature=1.,
+     stop_token='.',
+ ):
+     """Decode a caption token by token, conditioned on the projected CLIP prefix."""
+     model.eval()
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+         for entry_idx in range(entry_count):
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+
+             emb_tokens = model.gpt.transformer.wte(tokens)
+             if embed is not None:
+                 generated = torch.cat((embed, emb_tokens), dim=1)
+             else:
+                 generated = emb_tokens
+
+             for i in range(entry_length):
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+
+                 # Nucleus (top-p) filtering: mask everything outside the smallest
+                 # set of tokens whose cumulative probability exceeds top_p.
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+
+                 # Greedy pick among the tokens that survive the filter.
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             output_text = filter_ngrams(output_text)
+             generated_list.append(output_text)
+
+     return generated_list[0]
+
+
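+ # Worked example of the top-p step above: for sorted probabilities
+ # (0.9, 0.06, 0.03, 0.01) and top_p=0.98, the cumulative mass is
+ # (0.90, 0.96, 0.99, 1.00), which first exceeds 0.98 at the third token;
+ # after the keep-first shift, only the last token is masked, and the
+ # argmax then picks greedily among the three survivors.
+
+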
+ def read_image(path):
+     """Load an image from disk and shrink it to at most 196x196, keeping aspect ratio."""
+     image = cv2.imread(path)
+     size = 196, 196
+     image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+     image.thumbnail(size, Image.Resampling.LANCZOS)
+     return image
+
+
+ def create_emb(image):
+     """Encode a PIL image with CLIP and pair it with the Russian QA prompt."""
+     # The prompt reads: "Question: what is happening in the image? Answer: "
+     text = "Вопрос: что происходит на изображении? Ответ: "
+     image = preprocess(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+     return prefix, text
+
+
+ def get_caption(prefix, prompt=''):
+     """Project the CLIP embedding to a GPT prefix and generate a caption."""
+     prefix = prefix.to(device)
+     with torch.no_grad():
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+         if prompt:
+             generated_text_prefix = generate2(model, tokenizer, prompt=prompt, embed=prefix_embed)
+         else:
+             generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+     return generated_text_prefix.replace('\n', ' ')
+
+
+ def get_ans(clip_emb, prompt):
+     """Generate text for the prompt and strip the prompt itself from the output."""
+     output = get_caption(clip_emb, prompt=prompt)
+     ans = output[len(prompt):].strip()
+     return ans
+
+
+ # Load CLIP, the tokenizer, and the trained prefix-projection checkpoint
+ # (the .pt file is expected next to app.py); everything runs on CPU.
+ device = 'cpu'
+ clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
+ tokenizer = GPT2Tokenizer.from_pretrained(gpt_model_name)
+ prefix_length = 30
+ model_path = 'prefix_small_latest_gpt2_medium.pt'
+ model = ClipCaptionPrefix(prefix_length)
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
+ model.to(device)
+ model.eval()
+
+
+ def classify_image(inp):
+     """Gradio handler: take a numpy image, return the generated answer string."""
+     inp = Image.fromarray(inp)
+     prefix, text = create_emb(inp)
+     ans = get_ans(prefix, text)
+     return ans
+
+
+ image = gr.inputs.Image(shape=(256, 256))
+
+ iface = gr.Interface(
+     fn=classify_image,
+     description="https://github.com/AlexWortega/ruImageCaptioning RuImage Captioning, trained for an image2text task to predict a caption for an image, by https://t.me/lovedeathtransformers Alex Wortega",
+     inputs=image,
+     outputs="text",
+ )
+ iface.launch()
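
For a quick sanity check, the pipeline can be exercised without the web UI. A minimal sketch, assuming the definitions above have already been executed in the same interpreter (before iface.launch() blocks) and that sample.jpg is a stand-in test image, not part of the commit:

import numpy as np
from PIL import Image

# Pass classify_image the same numpy array Gradio would hand it.
arr = np.asarray(Image.open('sample.jpg').convert('RGB'))
print(classify_image(arr))  # prints the generated Russian answer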