"""Generate chest X-ray reports from images with a frozen R2GenGPT model.

The script loads the model and the image processor named by
``args.vision_model``, runs report generation on a sample image, and prints
the result.  A Streamlit upload UI is sketched (commented out) in ``main``.
"""
import os
import tempfile
from pprint import pprint

import numpy as np
import streamlit as st
import torch
from lightning.pytorch import seed_everything
from PIL import Image
from transformers import BertTokenizer, AutoImageProcessor

from configs.config import parser
from dataset.data_module import DataModule
from models.R2GenGPT import R2GenGPT

# Initialize the app
# st.title("Chest X-ray Report Generator")

# Sample image used by the command-line run in main().
DEFAULT_IMAGE_PATH = (
    "/workspace/"
    "p10_p10046166_s57379357_6e511483-c7e1601c-76890b2f-b0c6b55d-e53bcbf6.jpg"
)


def load_model(args):
    """Build an R2GenGPT model from *args* and put it in frozen eval mode.

    Args:
        args: Parsed configuration namespace passed straight to ``R2GenGPT``.

    Returns:
        The constructed model after ``eval()`` and ``freeze()``.
    """
    model = R2GenGPT(args)
    model.eval()
    model.freeze()
    return model


def _parse_image(vit_feature_extractor, img):
    """Run *img* through the image processor and return its pixel tensor.

    Args:
        vit_feature_extractor: Callable image processor returning an object
            with a ``pixel_values`` attribute when called with
            ``return_tensors="pt"``.
        img: Image array to preprocess.

    Returns:
        The first (and only) entry of ``pixel_values``, i.e. without the
        leading batch axis added by the processor.
    """
    pixel_values = vit_feature_extractor(img, return_tensors="pt").pixel_values
    return pixel_values[0]


def generate_predictions(image_path, vit_feature_extractor, model,
                         device="cuda:0"):
    """Generate report text for the image stored at *image_path*.

    Args:
        image_path: Path (or file-like object accepted by ``PIL.Image.open``)
            of the image to describe.
        vit_feature_extractor: Image processor used by ``_parse_image``.
        model: Loaded model exposing ``encode_img``, ``layer_norm``,
            ``prompt_wrap``, ``embed_tokens``, ``llama_model``,
            ``llama_tokenizer`` and ``decode``.
        device: Device the image tensor is moved to before encoding.
            Defaults to ``"cuda:0"`` (the previously hard-coded value).
            NOTE(review): assumes *model* already lives on this device —
            confirm against how the model is loaded.

    Returns:
        A list with one decoded report per generated sequence, or an empty
        list when image encoding or text generation fails (the error is
        reported through ``st.error``).
    """
    model.llama_tokenizer.padding_side = "right"

    with Image.open(image_path) as pil:
        array = np.array(pil, dtype=np.uint8)
        # Anything that is not already H x W x 3 is converted to RGB.
        if array.ndim != 3 or array.shape[-1] != 3:
            array = np.array(pil.convert("RGB"), dtype=np.uint8)

    image = _parse_image(vit_feature_extractor, array)
    image = image.unsqueeze(0)  # add the batch axis
    image = image.to(device=device)

    # Encode exactly once; on failure bail out instead of falling through to
    # code that would raise NameError on the undefined embeddings.
    try:
        img_embeds, atts_img = model.encode_img(image)
    except Exception as e:
        st.error(e)
        print("Image encoding failed: ", e)
        return []

    img_embeds = model.layer_norm(img_embeds)
    img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img)

    batch_size = img_embeds.shape[0]
    print("Batch size printed: ", batch_size)

    # Prepend the BOS token embedding to the wrapped image embeddings.
    bos = (
        torch.ones([batch_size, 1], dtype=atts_img.dtype,
                   device=atts_img.device)
        * model.llama_tokenizer.bos_token_id
    )
    bos_embeds = model.embed_tokens(bos)
    atts_bos = atts_img[:, :1]

    inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
    inputs_embeds = inputs_embeds.type(torch.float16)
    attention_mask = torch.cat([atts_bos, atts_img], dim=1)
    print("Shape of Input emb", inputs_embeds.shape)
    print("Shape of Attention mask: ", attention_mask.shape)

    try:
        with torch.no_grad():
            # The mask was previously computed but never handed to generate().
            outputs = model.llama_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
        print("output", outputs)
    except Exception as e:
        st.error(e)
        return []

    hypo = [model.decode(i) for i in outputs]
    print("Generated Report :", hypo)
    return hypo


def inference(args, uploaded_file, model=None, vit_feature_extractor=None):
    """Generate a report for an uploaded image or an image path.

    Args:
        args: Parsed configuration; used to load the model and, via
            ``args.vision_model``, the image processor when they are not
            supplied.
        uploaded_file: Either a filesystem path (``str`` / ``os.PathLike``)
            or an object exposing ``getbuffer()`` whose bytes are the image
            (the commented-out UI passes the result of ``st.file_uploader``).
        model: Optional already-loaded model to reuse instead of reloading.
        vit_feature_extractor: Optional already-loaded image processor.

    Returns:
        The list returned by ``generate_predictions``.
    """
    if model is None:
        model = load_model(args)
    if vit_feature_extractor is None:
        vit_feature_extractor = AutoImageProcessor.from_pretrained(
            args.vision_model
        )

    if isinstance(uploaded_file, (str, os.PathLike)):
        # Already on disk: read it in place and leave the file alone.
        predictions = generate_predictions(
            uploaded_file, vit_feature_extractor, model
        )
    else:
        # Spill the upload to a unique temp file and always clean it up,
        # even when generation raises.
        fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(uploaded_file.getbuffer())
            predictions = generate_predictions(
                tmp_path, vit_feature_extractor, model
            )
        finally:
            os.remove(tmp_path)

    print("Predictions: ", predictions)
    return predictions


def main():
    """Parse the configuration, load the model once and run the sample image."""
    # parser = argparse.ArgumentParser()
    # other arguments
    # parser.add_argument('--file', type=open, action=LoadFromFile)
    args = parser.parse_args()
    pprint(vars(args))
    seed_everything(42, workers=True)

    # Load once and share between both calls below.
    model = load_model(args)
    vit_feature_extractor = AutoImageProcessor.from_pretrained(
        args.vision_model
    )

    predictions = generate_predictions(
        DEFAULT_IMAGE_PATH, vit_feature_extractor, model
    )
    print("Predictions: ", predictions)
    print(
        "Inference: ",
        inference(
            args,
            DEFAULT_IMAGE_PATH,
            model=model,
            vit_feature_extractor=vit_feature_extractor,
        ),
    )

    # File uploader for image
    # uploaded_file = st.file_uploader("Choose a chest X-ray image...", type="jpg")
    # if uploaded_file is not None:
    #     st.image(uploaded_file, caption='Uploaded Image.', use_column_width=True)
    #     st.write("")
    #     st.write("Generating report...")
    #     predictions = inference(args, uploaded_file)
    #     if predictions:
    #         st.write("Generated Report:")
    #         for pred in predictions:
    #             print("Generated Report", pred)
    #             st.write(pred)
    #     else:
    #         st.write("Failed to generate report.")


if __name__ == '__main__':
    main()