# Committed by: ankitkupadhyay
# Last change: "Update app.py" (commit 0b14df6, verified)
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr
# Model definition and setup
class VisionLanguageModel(nn.Module):
    """Late-fusion image + text classifier.

    A ViT encoder embeds the image and a BERT encoder embeds the clinical
    text. Their pooled outputs are concatenated and fed to one linear layer
    that emits two logits (benign or malignant).
    """

    def __init__(self):
        super().__init__()
        # Registration order (vision, language, classifier) is kept as-is so
        # the state_dict layout matches the saved checkpoint.
        self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.language_model = BertModel.from_pretrained('bert-base-uncased')
        fused_width = (
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size
        )
        # Number of classes: benign or malignant.
        self.classifier = nn.Linear(fused_width, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Compute class logits for a batch of image/text pairs.

        Args:
            input_ids: BERT token ids for the clinical text.
            attention_mask: BERT attention mask matching ``input_ids``.
            pixel_values: Preprocessed image tensor for the ViT encoder.

        Returns:
            The classifier's logits. Pooled features from both encoders are
            joined along dim 1 (presumably ``(batch, hidden_size)`` each —
            confirm against the transformers version in use).
        """
        image_vec = self.vision_model(pixel_values=pixel_values).pooler_output
        text_vec = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output
        fused = torch.cat([image_vec, text_vec], dim=1)
        return self.classifier(fused)
model = VisionLanguageModel()
# Load the fine-tuned checkpoint onto the CPU. weights_only=True limits what
# torch.load will unpickle, which is the safer choice for a file on disk.
model.load_state_dict(
    torch.load('best_model.pth', map_location='cpu', weights_only=True)
)
model.eval()  # inference only: switch off training-mode behavior (e.g. dropout)

# Preprocessors matching the two pretrained backbones used by the model.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
def predict(image, text_input):
    """Classify one skin-lesion image together with its clinical text.

    Args:
        image: The lesion image in a form ``feature_extractor`` accepts (the
            Gradio UI passes a PIL image). PIL images are converted to RGB
            first.
        text_input: Free-text clinical details. ``None`` is treated as an
            empty string.

    Returns:
        int: 1 for Malignant, 0 for Benign (index of the larger logit).

    Raises:
        ValueError: If ``image`` is None (e.g. nothing was uploaded).
    """
    if image is None:
        # Fail early with a clear message instead of an opaque error from
        # inside the image processor.
        raise ValueError("An image is required to make a prediction.")
    if text_input is None:
        text_input = ""
    # NOTE(review): uploads may arrive as RGBA, grayscale or palette images;
    # the ViT backbone is an RGB model, so normalise to three channels when
    # the input is a PIL image (numpy arrays have no ``convert``).
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        logits = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values,
        )
    # argmax breaks ties on the first maximum, same as torch.max(..., dim=1).
    return logits.argmax(dim=1).item()  # 1 for Malignant, 0 for Benign
# Enhanced UI with black text
with gr.Blocks(css="""
body {
color: black;
}
.benign, .malignant {
background-color: white;
border: 1px solid lightgray;
padding: 10px;
border-radius: 5px;
color: black;
}
.benign.correct, .malignant.correct {
background-color: lightgreen;
color: black;
}
""") as demo:
    gr.Markdown(
        """
# 🩺 SKIN LESION CLASSIFICATION
Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
        with gr.Column(scale=1):
            gr.Markdown("## PREDICTION RESULTS")
            # Both labels start un-highlighted; display_prediction adds the
            # "correct" CSS class to the one the model picks.
            benign_output = gr.HTML("<div class='benign'>Benign</div>")
            malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")

    def display_prediction(image, clinical_text):
        """Run ``predict`` and return HTML for the (benign, malignant) boxes.

        The box matching the prediction (0 = benign, 1 = malignant) gets the
        extra ``correct`` class, which the CSS above renders light green.
        The parameter is named ``clinical_text`` so it does not shadow the
        ``text_input`` Textbox component.
        """
        if image is None:
            # Show a readable message in the UI rather than a generic error.
            raise gr.Error("Please upload a skin lesion image before requesting a prediction.")
        prediction = predict(image, clinical_text)
        benign_cls = " correct" if prediction == 0 else ""
        malignant_cls = " correct" if prediction == 1 else ""
        return (
            f"<div class='benign{benign_cls}'>Benign</div>",
            f"<div class='malignant{malignant_cls}'>Malignant</div>",
        )

    # Submit button and prediction outputs
    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(
        display_prediction,
        inputs=[image_input, text_input],
        outputs=[benign_output, malignant_output],
    )

demo.launch()