# File size: 4,759 Bytes
# 7569ad7 7755c1c 7569ad7 7755c1c 7569ad7 7755c1c 7569ad7
import os
from pprint import pprint
from configs.config import parser
from dataset.data_module import DataModule
from models.R2GenGPT import R2GenGPT
import torch
from transformers import BertTokenizer, AutoImageProcessor
from PIL import Image
import numpy as np
import streamlit as st
from lightning.pytorch import seed_everything
# Initialize the app
# st.title("Chest X-ray Report Generator")
# Function to load the model
def load_model(args):
    """Build an ``R2GenGPT`` from *args* and prepare it for inference.

    The freshly constructed model has ``eval()`` and then ``freeze()``
    called on it before it is returned to the caller.
    """
    network = R2GenGPT(args)
    network.eval()
    network.freeze()
    return network
# Function to parse image
def _parse_image(vit_feature_extractor, img):
pixel_values = vit_feature_extractor(img, return_tensors="pt").pixel_values
return pixel_values[0]
# Function to generate predictions
def generate_predictions(image_path, vit_feature_extractor, model):
    """Generate report text for the image stored at *image_path*.

    The image is loaded with PIL (and converted to RGB unless it is already
    a 3-dimensional array whose last axis has length 3), preprocessed with
    ``vit_feature_extractor``, encoded with ``model.encode_img``, wrapped
    with the model's prompt, prefixed with a BOS embedding and finally
    passed to ``model.llama_model.generate``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts.
        vit_feature_extractor: Callable image processor; invoked through
            ``_parse_image``.
        model: Object exposing ``llama_tokenizer``, ``encode_img``,
            ``layer_norm``, ``prompt_wrap``, ``embed_tokens``,
            ``llama_model`` and ``decode``.

    Returns:
        A list with one decoded entry per generated sequence, or ``[]`` when
        image encoding or generation raises (the error is reported through
        ``st.error``).
    """
    model.llama_tokenizer.padding_side = "right"

    with Image.open(image_path) as pil:
        array = np.array(pil, dtype=np.uint8)
        # Check the rank first so a low-rank array never reaches shape[-1].
        if array.ndim != 3 or array.shape[-1] != 3:
            array = np.array(pil.convert("RGB"), dtype=np.uint8)

    # Add a leading batch axis of size 1.
    image = _parse_image(vit_feature_extractor, array).unsqueeze(0)
    # NOTE(review): the target device is hard-coded; confirm the model lives
    # on the same device before relying on this in other environments.
    image = image.to(device='cuda:0')

    # Encode exactly once. Previously the encoder was also run a second time
    # outside the try block purely for logging, and a failure here fell
    # through to a NameError on ``img_embeds`` further down.
    try:
        img_embeds, atts_img = model.encode_img(image)
    except Exception as e:
        st.error(e)
        print("Except block for Image embeddings \n")
        return []
    print("Model Encoding for Image: ", (img_embeds, atts_img))
    print("Image embeddings in try blk", img_embeds)
    print("Try block for Image Embeddings \n")

    img_embeds = model.layer_norm(img_embeds)
    img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img)
    print("Image embeddings: ", img_embeds)

    batch_size = img_embeds.shape[0]
    print("Batch size printed: ", batch_size)

    # One BOS token id per batch row, created with the attention mask's
    # dtype and device.
    bos = torch.ones(
        [batch_size, 1],
        dtype=atts_img.dtype,
        device=atts_img.device,
    ) * model.llama_tokenizer.bos_token_id
    bos_embeds = model.embed_tokens(bos)
    atts_bos = atts_img[:, :1]
    print("Attention: ", atts_bos)

    inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
    print("Shape of Input emb", inputs_embeds)
    inputs_embeds = inputs_embeds.type(torch.float16)

    # NOTE(review): attention_mask is built and logged but is not passed to
    # generate() below -- confirm whether that is intentional.
    attention_mask = torch.cat([atts_bos, atts_img], dim=1)
    print("Shape of Attention mask: ", attention_mask)

    try:
        with torch.no_grad():
            outputs = model.llama_model.generate(inputs_embeds=inputs_embeds)
    except Exception as e:
        st.error(e)
        return []
    print("output", outputs)

    hypo = [model.decode(i) for i in outputs]
    print("Generated Report :", hypo)
    return hypo
# Function to perform inference
def inference(args, uploaded_file):
    """Load the model and generate predictions for one image.

    Args:
        args: Parsed configuration; passed to ``load_model`` and read for
            ``args.vision_model``.
        uploaded_file: Either a filesystem path (``str`` / ``os.PathLike``)
            to an existing image, or an object exposing ``getbuffer()``
            whose bytes are the image content.

    Returns:
        Whatever ``generate_predictions`` returns for the image.
    """
    # Local import so this function does not depend on the file-level
    # import block being changed.
    import tempfile

    model = load_model(args)
    vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)

    # A path already points at a file on disk: use it directly and leave it
    # in place (it is the caller's file, not a scratch copy).
    if isinstance(uploaded_file, (str, os.PathLike)):
        predictions = generate_predictions(uploaded_file, vit_feature_extractor, model)
        print("Predictions: ", predictions)
        return predictions

    # Write the upload to a unique scratch file instead of a fixed path so
    # concurrent calls cannot overwrite each other's image.
    scratch = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    try:
        with scratch:
            scratch.write(uploaded_file.getbuffer())
        predictions = generate_predictions(scratch.name, vit_feature_extractor, model)
        print("Predictions: ", predictions)
    finally:
        # Always clean up, even when prediction raises.
        os.remove(scratch.name)
    return predictions
# Main function
def main():
    """Parse the configuration and run report generation on a sample image.

    The sample image is processed twice: once by calling
    ``generate_predictions`` directly with a model loaded here, and once
    through ``inference`` (which loads its own model). Both results are
    printed.
    """
    # Local import: only needed to wrap the sample image for inference().
    import io

    #parser = argparse.ArgumentParser()
    # other arguments
    #parser.add_argument('--file', type=open, action=LoadFromFile)
    args = parser.parse_args()
    pprint(vars(args))
    seed_everything(42, workers=True)

    sample_image = "/workspace/p10_p10046166_s57379357_6e511483-c7e1601c-76890b2f-b0c6b55d-e53bcbf6.jpg"

    model = load_model(args)
    vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)
    predictions = generate_predictions(sample_image, vit_feature_extractor, model)
    print("Predictions: ", predictions)

    # inference() reads its second argument through getbuffer(); a plain
    # path string has no such method, so hand it the image bytes wrapped in
    # a BytesIO instead of the path itself.
    with open(sample_image, "rb") as image_file:
        uploaded = io.BytesIO(image_file.read())
    print("Inference: ", inference(args, uploaded))

    # Streamlit UI (currently disabled):
    #uploaded_file = st.file_uploader("Choose a chest X-ray image...", type="jpg")
    #if uploaded_file is not None:
    #    st.image(uploaded_file, caption='Uploaded Image.', use_column_width=True)
    #    st.write("")
    #    st.write("Generating report...")
    #    predictions = inference(args, uploaded_file)
    #    if predictions:
    #        st.write("Generated Report:")
    #        for pred in predictions:
    #            print("Generated Report", pred)
    #            st.write(pred)
    #    else:
    #        st.write("Failed to generate report.")
# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|