|
import os |
|
from pprint import pprint |
|
from configs.config import parser |
|
from dataset.data_module import DataModule |
|
from models.R2GenGPT import R2GenGPT |
|
import torch |
|
from transformers import BertTokenizer, AutoImageProcessor |
|
from PIL import Image |
|
import numpy as np |
|
import streamlit as st |
|
from lightning.pytorch import seed_everything |
|
|
|
|
|
|
|
|
|
|
|
def load_model(args):
    """Instantiate the report-generation model for inference.

    Args:
        args: Parsed configuration object handed straight to ``R2GenGPT``.

    Returns:
        The ``R2GenGPT`` instance after ``eval()`` and ``freeze()`` have
        been called on it.
    """
    r2gen = R2GenGPT(args)
    # Switch to evaluation mode, then freeze, before handing the model out.
    r2gen.eval()
    r2gen.freeze()
    return r2gen
|
|
|
|
|
def _parse_image(vit_feature_extractor, img):
    """Run ``img`` through the feature extractor and return the first result.

    Args:
        vit_feature_extractor: Callable invoked as
            ``vit_feature_extractor(img, return_tensors="pt")``; its result
            must expose a ``pixel_values`` attribute.
        img: The image data to preprocess (a NumPy array at the call site in
            this file).

    Returns:
        ``pixel_values[0]`` of the extractor output, i.e. the entry for the
        single image that was passed in.
    """
    encoded = vit_feature_extractor(img, return_tensors="pt")
    return encoded.pixel_values[0]
|
|
|
|
|
def _load_rgb_array(image_path):
    """Read ``image_path`` with PIL and return it as a 3-channel ``uint8`` array."""
    with Image.open(image_path) as pil:
        array = np.array(pil, dtype=np.uint8)
        # Grayscale, palette, RGBA, CMYK, ... are normalised through PIL's
        # RGB conversion; arrays that already have 3 channels are kept as-is.
        if array.ndim != 3 or array.shape[-1] != 3:
            array = np.array(pil.convert("RGB"), dtype=np.uint8)
    return array


def generate_predictions(image_path, vit_feature_extractor, model, device="cuda:0"):
    """Generate report text for a single image file.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        vit_feature_extractor: Image processor passed on to ``_parse_image``.
        model: Model object exposing ``llama_tokenizer``, ``encode_img``,
            ``layer_norm``, ``prompt_wrap``, ``embed_tokens``,
            ``llama_model.generate`` and ``decode`` (an ``R2GenGPT`` instance
            elsewhere in this file).
        device: Device the preprocessed image tensor is moved to before
            encoding. Defaults to ``"cuda:0"``.

    Returns:
        A list with one decoded string per generated sequence. If image
        encoding or text generation raises, the error is reported through
        ``st.error`` and an empty list is returned.
    """
    model.llama_tokenizer.padding_side = "right"

    array = _load_rgb_array(image_path)
    # Add a leading batch dimension of size 1 and move to the target device.
    image = _parse_image(vit_feature_extractor, array).unsqueeze(0).to(device=device)

    # Encode exactly once. A failure here must stop the function: everything
    # below depends on img_embeds / atts_img being defined.
    try:
        img_embeds, atts_img = model.encode_img(image)
    except Exception as e:
        st.error(e)
        return []

    img_embeds = model.layer_norm(img_embeds)
    img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img)
    print("Image embeddings: ", img_embeds)

    batch_size = img_embeds.shape[0]
    print("Batch size printed: ", batch_size)

    # One BOS token id per batch row, with the same dtype/device as the mask.
    bos = torch.full(
        (batch_size, 1),
        model.llama_tokenizer.bos_token_id,
        dtype=atts_img.dtype,
        device=atts_img.device,
    )
    bos_embeds = model.embed_tokens(bos)
    atts_bos = atts_img[:, :1]
    print("Attention: ", atts_bos)

    # Sequence layout along dim 1: [BOS embedding, wrapped image embeddings].
    inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1).type(torch.float16)
    print("Shape of Input emb", inputs_embeds.shape)
    attention_mask = torch.cat([atts_bos, atts_img], dim=1)
    print("Shape of Attention mask: ", attention_mask.shape)

    try:
        with torch.no_grad():
            # Pass the mask explicitly so it always matches inputs_embeds.
            outputs = model.llama_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
        print("output", outputs)
    except Exception as e:
        st.error(e)
        return []

    hypo = [model.decode(sequence) for sequence in outputs]
    print("Generated Report :", hypo)
    return hypo
|
|
|
|
|
def inference(args, uploaded_file):
    """Load the model and generate a report for an uploaded image.

    Args:
        args: Parsed configuration; passed to ``load_model`` and read for
            ``args.vision_model``.
        uploaded_file: Either an object exposing ``getbuffer()`` (for example
            a Streamlit upload or ``io.BytesIO``) or a filesystem path
            (``str`` / ``os.PathLike``) to an existing image.

    Returns:
        Whatever ``generate_predictions`` returns for the image.
    """
    import tempfile  # function-scope import keeps this block self-contained

    model = load_model(args)
    vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)

    if isinstance(uploaded_file, (str, os.PathLike)):
        # Already on disk: no temporary copy is needed.
        predictions = generate_predictions(uploaded_file, vit_feature_extractor, model)
    else:
        # Spill the upload to a unique temporary file so concurrent calls
        # cannot overwrite each other's image.
        tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
        try:
            with tmp:
                tmp.write(uploaded_file.getbuffer())
            # The file is closed here, so it can be reopened by path.
            predictions = generate_predictions(tmp.name, vit_feature_extractor, model)
        finally:
            # Clean up even when writing or prediction raises.
            os.remove(tmp.name)

    print("Predictions: ", predictions)
    return predictions
|
|
|
|
|
def main():
    """Command-line entry point: run prediction on one hard-coded sample image.

    Parses the command-line arguments, seeds the RNGs, then generates a
    report twice for the same image: once by calling
    ``generate_predictions`` directly and once through ``inference``.
    """
    import io  # function-scope import: only needed for the in-memory upload

    args = parser.parse_args()
    pprint(vars(args))
    seed_everything(42, workers=True)

    image_path = "/workspace/p10_p10046166_s57379357_6e511483-c7e1601c-76890b2f-b0c6b55d-e53bcbf6.jpg"

    model = load_model(args)
    vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)
    predictions = generate_predictions(image_path, vit_feature_extractor, model)
    print("Predictions: ", predictions)

    # inference() reads its input through .getbuffer(), so hand it an
    # in-memory file object rather than the bare path string.
    # NOTE: inference() builds its own model via load_model(args).
    with open(image_path, "rb") as f:
        uploaded = io.BytesIO(f.read())
    print("Inference: ", inference(args, uploaded))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
|
|