Anonumous's picture
Update app.py
35b21d7
import torch
import gradio as gr
from PIL import Image
import torch.nn as nn
from torch.nn import functional as nnf
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import cv2
from PIL import Image
from typing import Tuple, Optional, Union
import clip
# Identifier of the pretrained checkpoint that ClipCaptionModel passes to
# GPT2LMHeadModel.from_pretrained when it builds its language model.
gpt_model_name = 'sberbank-ai/rugpt3medium_based_on_gpt2'
class MLP(nn.Module):
    """Stack of linear layers with an activation between consecutive layers.

    Args:
        sizes: Layer widths; each adjacent pair ``(sizes[k], sizes[k + 1])``
            becomes one ``nn.Linear``.
        bias: Forwarded to every ``nn.Linear``.
        act: Zero-argument factory for the activation module that is placed
            after every linear layer except the last one.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        stack = []
        last_hidden = len(sizes) - 2
        for position, (width_in, width_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stack.append(nn.Linear(width_in, width_out, bias=bias))
            # The final linear layer is left without an activation.
            if position < last_hidden:
                stack.append(act())
        self.model = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result."""
        return self.model(x)
def freeze(
    model,
    freeze_emb=False,
    freeze_ln=False,
    freeze_attn=True,
    freeze_ff=True,
    freeze_other=False,
):
    """Set ``requires_grad`` on every parameter of ``model`` by name category.

    A parameter's lower-cased name is matched against the substrings below, in
    this order, and the first match decides which flag applies:

    * ``'ln'`` or ``'norm'`` -> ``freeze_ln``
    * ``'embeddings'``       -> ``freeze_emb``
    * ``'mlp'``              -> ``freeze_ff``
    * ``'attn'``             -> ``freeze_attn``
    * anything else          -> ``freeze_other``

    Args:
        model: Object exposing ``named_parameters()``.
        freeze_emb, freeze_ln, freeze_attn, freeze_ff, freeze_other: When
            truthy, parameters of that category get ``requires_grad = False``;
            otherwise ``True``.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Order matters: a name such as "mlp_norm" is governed by the first rule.
    rules = (
        (('ln', 'norm'), freeze_ln),
        (('embeddings',), freeze_emb),
        (('mlp',), freeze_ff),
        (('attn',), freeze_attn),
    )
    for full_name, param in model.named_parameters():
        key = full_name.lower()
        frozen = next(
            (flag for markers, flag in rules if any(m in key for m in markers)),
            freeze_other,
        )
        param.requires_grad = not frozen
    return model
class ClipCaptionModel(nn.Module):
    """GPT language model driven by a projected prefix vector.

    The prefix is passed through an MLP, reshaped into ``prefix_length``
    embeddings of the GPT embedding width, and prepended to the embedded
    text tokens before the GPT forward pass.

    Args:
        prefix_length: Number of embedding positions produced from one prefix.
        prefix_size: Width of the incoming prefix vector.
    """

    def __init__(self, prefix_length: int, prefix_size: int = 768):
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained(gpt_model_name)
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        # The projection output is later viewed as (prefix_length, embedding_size).
        projected_width = self.gpt_embedding_size * prefix_length
        self.clip_project = MLP((prefix_size, projected_width // 2, projected_width))

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return an int64 zero tensor of shape ``(batch_size, prefix_length)``."""
        shape = (batch_size, self.prefix_length)
        return torch.zeros(shape, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        """Run GPT on the projected prefix followed by the embedded ``tokens``.

        Args:
            tokens: Token ids, embedded with the GPT input embedding table.
            prefix: Prefix vectors; cast to float before projection.
            mask: Passed to GPT as ``attention_mask``.
            labels: Only tested against ``None``. When given, the labels
                actually sent to GPT are zeros for the prefix positions
                followed by ``tokens``.

        Returns:
            Whatever the wrapped GPT model returns.
        """
        text_embeds = self.gpt.transformer.wte(tokens)
        projected = self.clip_project(prefix.float())
        prefix_embeds = projected.view(-1, self.prefix_length, self.gpt_embedding_size)
        gpt_inputs = torch.cat((prefix_embeds, text_embeds), dim=1)
        if labels is not None:
            prefix_pad = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((prefix_pad, tokens), dim=1)
        return self.gpt(inputs_embeds=gpt_inputs, labels=labels, attention_mask=mask)
class ClipCaptionPrefix(ClipCaptionModel):
    """ClipCaptionModel variant that exposes only the prefix projection for training."""

    def parameters(self, recurse: bool = True):
        """Return the parameters of ``clip_project`` only; ``recurse`` is not used."""
        projection = self.clip_project
        return projection.parameters()

    def train(self, mode: bool = True):
        """Switch the module to ``mode`` and then put the GPT sub-model in eval mode."""
        super().train(mode)
        # The language model stays in eval mode whatever ``mode`` was requested.
        self.gpt.eval()
        return self
def filter_ngrams(output_text):
    """Cut ``output_text`` just before the second ``' Ответ:'`` marker.

    If the marker occurs fewer than two times the text is returned unchanged.
    (Previously the ``-1`` from ``str.find`` was used as a slice bound, which
    silently dropped the last character of the text.)

    Args:
        output_text: Decoded model output.

    Returns:
        The text up to, and excluding, the second marker, or the whole text
        when there is no second marker.
    """
    marker = ' Ответ:'
    first = output_text.find(marker)
    if first == -1:
        return output_text
    second = output_text.find(marker, first + 1)
    if second == -1:
        return output_text
    return output_text[:second]
def generate2(
    model,
    tokenizer,
    tokens=None,
    prompt='',
    embed=None,
    entry_count=1,
    entry_length=67,  # maximum number of generated tokens
    top_p=0.98,
    temperature=1.,
    stop_token='.',
):
    """Greedily decode text from ``model.gpt``, optionally after a prefix embedding.

    For each entry the start tokens are embedded, ``embed`` (if given) is
    prepended, and up to ``entry_length`` tokens are appended one at a time.
    Each new token is the argmax of the last-position logits. Decoding of an
    entry stops early once the first token id of ``stop_token`` is produced.
    The decoded text is passed through ``filter_ngrams``.

    Args:
        model: Object exposing ``eval()``, ``parameters()``, ``gpt`` and
            ``gpt.transformer.wte``.
        tokenizer: Object exposing ``encode`` and ``decode``.
        tokens: Optional tensor of start token ids. When ``None`` the encoded
            ``prompt`` is used. Assumed to already have shape (1, seq) -- TODO
            confirm with callers.
        prompt: Text encoded as the start tokens when ``tokens`` is ``None``.
        embed: Optional embeddings concatenated in front of the token
            embeddings along dim 1.
        entry_count: Number of entries to decode; only the first is returned.
        entry_length: Maximum number of tokens generated per entry.
        top_p: Cumulative-probability threshold of the nucleus mask.
        temperature: Logit divisor; values <= 0 are treated as 1.0.
        stop_token: Text whose first token id ends an entry.

    Returns:
        The decoded, filtered text of the first entry.

    Raises:
        IndexError: If ``entry_count`` is 0, because no entry was generated.
    """
    model.eval()
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device
    # Loop-invariant: non-positive temperatures fall back to no scaling.
    scale = temperature if temperature > 0 else 1.0

    with torch.no_grad():
        for _ in range(entry_count):
            # Every entry starts from the caller's tokens/prompt. Testing
            # ``is None`` avoids evaluating the truth value of a tensor.
            if tokens is None:
                entry_tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
            else:
                entry_tokens = tokens.to(device)

            generated = model.gpt.transformer.wte(entry_tokens)
            if embed is not None:
                generated = torch.cat((embed, generated), dim=1)

            for _ in range(entry_length):
                logits = model.gpt(inputs_embeds=generated).logits[:, -1, :] / scale

                # Nucleus mask: drop tokens beyond the top_p cumulative mass.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                # NOTE: the highest-logit token is never masked above, so for a
                # single sequence the argmax is not affected by the mask.
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)

                entry_tokens = torch.cat((entry_tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)

                # ``item()`` requires exactly one sequence in the batch.
                if stop_token_index == next_token.item():
                    break

            output_text = tokenizer.decode(entry_tokens.squeeze(0).tolist())
            generated_list.append(filter_ngrams(output_text))

    return generated_list[0]
def read_image(path):
    """Load an image file as an RGB PIL image no larger than 196x196.

    Args:
        path: Location of the image file, as accepted by ``cv2.imread``.

    Returns:
        A PIL image, shrunk in place with ``Image.thumbnail`` (LANCZOS filter).

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not read the file.
    """
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path!r}")
    size = 196, 196
    image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    image.thumbnail(size, Image.Resampling.LANCZOS)
    return image
def create_emb(image):
    """Encode ``image`` with the CLIP model and pair it with the fixed prompt.

    Args:
        image: Input accepted by the module-level ``preprocess`` callable.

    Returns:
        A ``(prefix, text)`` tuple: the float32 output of
        ``clip_model.encode_image`` on ``device`` and the question prompt.
    """
    text = "Вопрос: что происходит на изображении? Ответ: "
    pixels = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(pixels).to(device, dtype=torch.float32)
    # The prefix projection is not computed here: its result was never used,
    # and get_caption projects the prefix itself.
    return (prefix, text)
def get_caption(prefix, prompt=''):
    """Generate text for a CLIP ``prefix``, optionally continuing ``prompt``.

    Args:
        prefix: Tensor fed to ``model.clip_project`` after moving to ``device``.
        prompt: Text handed to ``generate2``; a falsy value leaves
            ``generate2`` on its own default prompt.

    Returns:
        The generated text with every newline replaced by a space.
    """
    prefix = prefix.to(device)
    with torch.no_grad():
        projected = model.clip_project(prefix)
        prefix_embed = projected.reshape(1, prefix_length, -1)
    # Only forward the prompt when there is one.
    prompt_kwargs = {'prompt': prompt} if prompt else {}
    caption = generate2(model, tokenizer, embed=prefix_embed, **prompt_kwargs)
    return caption.replace('\n', ' ')
def get_ans(clip_emb, prompt):
    """Return the generated text that follows ``prompt``.

    The caption from ``get_caption`` is cut after ``len(prompt)`` characters
    and stripped of surrounding whitespace.
    """
    caption = get_caption(clip_emb, prompt=prompt)
    answer_start = len(prompt)
    return caption[answer_start:].strip()
# All models are placed on this device; the checkpoint is mapped onto it too.
device = 'cpu'
clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
# Use the module-level checkpoint id so the tokenizer and the GPT model built
# inside ClipCaptionModel always refer to the same checkpoint.
tokenizer = GPT2Tokenizer.from_pretrained(gpt_model_name)
prefix_length = 30
model_path = 'prefix_small_latest_gpt2_medium.pt'
model = ClipCaptionPrefix(prefix_length)
# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
def classify_image(inp):
    """Gradio callback: produce a caption for the submitted image.

    Args:
        inp: Array accepted by ``Image.fromarray``, or ``None`` when the form
            was submitted without an image.

    Returns:
        The generated answer text, or an empty string when ``inp`` is ``None``.
    """
    # Submitting without an image would otherwise crash in Image.fromarray.
    if inp is None:
        return ""
    picture = Image.fromarray(inp)
    prefix, text = create_emb(picture)
    return get_ans(prefix, text)
# Input widget: an image component created with shape=(196, 196).
image = gr.inputs.Image(shape=(196, 196))
# Label output component; it is created here but not passed to the Interface
# below, which uses a plain "text" output.
label = gr.outputs.Label(num_top_classes=3)
iface = gr.Interface(
    fn=classify_image,
    description="RuImage Captioning trained for a image2text task to predict caption of image",
    inputs=image,
    outputs="text",
    examples=[["1.png"], ["2.png"], ["3.png"]],
)
iface.launch()