from typing import Optional, Tuple

import cv2
import gradio as gr
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from PIL import Image
from transformers import GPT2Tokenizer, GPT2LMHeadModel

import clip

gpt_model_name = 'sberbank-ai/rugpt3medium_based_on_gpt2'


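# A simple multilayer perceptron: Linear layers with Tanh activations in
# between. Used to project a CLIP image embedding into a flat vector that is
# later reshaped into GPT-2 prefix embeddings.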
class MLP(nn.Module):
    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    # @autocast()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


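# Freezes selected groups of GPT-2 parameters (attention, feed-forward,
# embeddings, layer norms). Kept here as a training-time helper; the inference
# code below does not call it.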
def freeze(
        model,
        freeze_emb=False,
        freeze_ln=False,
        freeze_attn=True,
        freeze_ff=True,
        freeze_other=False,
):
    for name, p in model.named_parameters():
        # freeze all parameters except the layernorm and positional embeddings
        name = name.lower()
        if 'ln' in name or 'norm' in name:
            p.requires_grad = not freeze_ln
        elif 'embeddings' in name:
            p.requires_grad = not freeze_emb
        elif 'mlp' in name:
            p.requires_grad = not freeze_ff
        elif 'attn' in name:
            p.requires_grad = not freeze_attn
        else:
            p.requires_grad = not freeze_other

    return model


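# Caption model: the CLIP image embedding is mapped by an MLP into
# `prefix_length` pseudo-token embeddings, concatenated in front of the text
# token embeddings, and fed through the ruGPT-3 (GPT-2) language model.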
class ClipCaptionModel(nn.Module):
    def __init__(self, prefix_length: int, prefix_size: int = 768):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        """
        ru gpts shit
        """
        self.gpt = GPT2LMHeadModel.from_pretrained(gpt_model_name)

        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                 self.gpt_embedding_size * prefix_length))

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    # @autocast()
    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        embedding_text = self.gpt.transformer.wte(tokens)

        prefix_projections = self.clip_project(prefix.float()).view(-1, self.prefix_length, self.gpt_embedding_size)

        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)

        return out


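# Variant that exposes only the projection MLP as trainable parameters and
# keeps the GPT-2 backbone in eval mode.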
class ClipCaptionPrefix(ClipCaptionModel):
    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self


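# Truncates the generated text at the second " Ответ:" ("Answer:") marker so
# that only the first question/answer pair is returned.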
def filter_ngrams(output_text):
    a_pos = output_text.find(' Ответ:')
    sec_a_pos = output_text.find(' Ответ:', a_pos + 1)
    if sec_a_pos == -1:
        # No second occurrence: return the text unchanged.
        return output_text
    return output_text[:sec_a_pos]


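# Autoregressive decoding conditioned on the projected prefix embeddings.
# Top-p (nucleus) filtering is applied to the logits, after which the next
# token is chosen greedily; generation stops at `stop_token` or after
# `entry_length` tokens.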
def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt='',
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of generated tokens
        top_p=0.98,
        temperature=1.,
        stop_token='.',
):
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():
        for entry_idx in range(entry_count):
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt))
                tokens = tokens.unsqueeze(0).to(device)

            emb_tokens = model.gpt.transformer.wte(tokens)

            if embed is not None:
                generated = torch.cat((embed, emb_tokens), dim=1)
            else:
                generated = emb_tokens

            for i in range(entry_length):
                outputs = model.gpt(inputs_embeds=generated)

                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                # Greedy pick over the nucleus-filtered logits.
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)

                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())

            output_text = tokenizer.decode(output_list)
            output_text = filter_ngrams(output_text)
            generated_list.append(output_text)

    return generated_list[0]


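# Helper: reads an image from disk with OpenCV, converts BGR to RGB and
# shrinks it to at most 196x196.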
def read_image(path):
    image = cv2.imread(path)

    size = 196, 196
    image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    image.thumbnail(size, Image.Resampling.LANCZOS)

    return image


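# Encodes a PIL image with CLIP and returns the raw image embedding together
# with the Russian prompt "Вопрос: что происходит на изображении? Ответ:"
# ("Question: what is happening in the image? Answer:").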
def create_emb(image):
    text = "Вопрос: что происходит на изображении? Ответ:  "
    image = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
    # The projection to prefix embeddings happens later in get_caption.
    return prefix, text


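# Projects the CLIP embedding to prefix embeddings and generates a caption,
# optionally conditioned on a text prompt.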
def get_caption(prefix, prompt=''):
    prefix = prefix.to(device)
    with torch.no_grad():
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
        if prompt:
            generated_text_prefix = generate2(model, tokenizer, prompt=prompt, embed=prefix_embed)
        else:
            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
    return generated_text_prefix.replace('\n', ' ')


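# Generates a caption for the given prompt and returns only the part after
# the prompt, i.e. the model's answer.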
def get_ans(clip_emb, prompt):
    output = get_caption(clip_emb, prompt=prompt)
    ans = output[len(prompt):].strip()
    return ans


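# Inference setup: CLIP ViT-L/14@336px as the image encoder, the ruGPT-3
# tokenizer, and a trained prefix-projection checkpoint loaded on CPU.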
device = 'cpu'
clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained(gpt_model_name)
prefix_length = 30
model_path = 'prefix_small_latest_gpt2_medium.pt'
model = ClipCaptionPrefix(prefix_length)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.to(device)
model.eval()



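# Gradio callback: the uploaded image arrives as a numpy array, is converted
# to PIL, embedded with CLIP, and captioned.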
def classify_image(inp):
    inp = Image.fromarray(inp)
    prefix, text = create_emb(inp)
    return get_ans(prefix, text)


image = gr.inputs.Image(shape=(196, 196))


iface = gr.Interface(fn=classify_image, description="RuImage Captioning: an image-to-text model that predicts a caption for an image", inputs=image, outputs="text", examples=[
    ["1.png"],
    ["2.png"],
    ["3.png"]
])
iface.launch()