import json
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import streamlit as st
import detectron2.data.transforms as T
from packaging import version
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import Metadata
from detectron2.structures.boxes import Boxes
from detectron2.structures import Instances
from plots.plot_pca_point import plot_pca_point
from plots.plot_histogram_dist import plot_histogram_dist
from plots.plot_gradcam import plot_gradcam
def extract_features(model, img, box):
    """Pool the backbone features under a single box into a 256-d vector."""
    # img is a CHW tensor, so shape[1:3] is (height, width).
    height, width = img.shape[1:3]
    inputs = [{"image": img, "height": height, "width": width}]
    with torch.no_grad():
        images = model.preprocess_image(inputs)
        features = model.backbone(images.tensor)
        features_ = [features[f] for f in model.roi_heads.box_in_features]
        box_features = model.roi_heads.box_pooler(features_, [box])
        # Average the 7x7 ROI-aligned maps down to one 256-d feature per box.
        output_features = F.avg_pool2d(box_features, [7, 7])
        output_features = output_features.view(-1, 256)
    return output_features
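
# Hypothetical usage sketch (not part of the app's call graph): `box` is
# assumed to be a detectron2 Boxes object in the preprocessed-image
# coordinate space, and `image_tensor` a float CHW tensor built the same
# way as in forward_model_full below.
#
#   roi = Boxes(torch.tensor([[50.0, 60.0, 210.0, 240.0]]))
#   vec = extract_features(model, image_tensor, roi)  # -> tensor of shape (1, 256)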
def forward_model_full(model, cfg, cv_img):
    """Run the full detection pipeline, attaching per-instance class
    probabilities and pooled ROI features to the returned Instances."""
    height, width = cv_img.shape[:2]
    transform_gen = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    image = transform_gen.get_transform(cv_img).apply_image(cv_img)
    # HWC uint8 -> CHW float32 tensor, mirroring what DefaultPredictor does.
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    inputs = [{"image": image, "height": height, "width": width}]
    with torch.no_grad():
        images = model.preprocess_image(inputs)
        features = model.backbone(images.tensor)
        proposals, _ = model.proposal_generator(images, features, None)
        features_ = [features[f] for f in model.roi_heads.box_in_features]
        box_features = model.roi_heads.box_pooler(features_, [x.proposal_boxes for x in proposals])
        box_head = model.roi_heads.box_head(box_features)
        predictions = model.roi_heads.box_predictor(box_head)
        # Keep the pooled 256-d feature of every proposal so it can be
        # matched to the kept detections via pred_inds below.
        output_features = F.avg_pool2d(box_features, [7, 7])
        output_features = output_features.view(-1, 256)
        probs = model.roi_heads.box_predictor.predict_probs(predictions, proposals)
        pred_instances, pred_inds = model.roi_heads.box_predictor.inference(predictions, proposals)
        pred_instances = model.roi_heads.forward_with_given_boxes(features, pred_instances)
        pred_instances = model._postprocess(pred_instances, inputs, images.image_sizes)
    instances = pred_instances[0]["instances"]
    instances.set("probs", probs[0][pred_inds])
    instances.set("features", output_features[pred_inds])
    return instances, cv_img
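
# Note on the "probs" field: detectron2's box predictor returns a softmax
# over K+1 classes with the background class in the last column. explain()
# below relies on that layout when it reads the last column as the
# "healthy" probability.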
def load_model():
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
    cfg.MODEL.WEIGHTS = MODEL
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TH
    cfg.MODEL.DEVICE = "cpu"
    metadata = Metadata()
    metadata.set(
        evaluator_type="coco",
        thing_classes=["neoplastic", "aphthous", "traumatic"],
        thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2},
    )
    predictor = DefaultPredictor(cfg)
    model = predictor.model
    return dict(
        predictor=predictor,
        model=model,
        metadata=metadata,
        cfg=cfg,
    )
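
# Possible improvement (an assumption, not in the original app): on
# Streamlit >= 1.18 the expensive predictor build could be cached across
# reruns so "Process" does not reload the weights every click, e.g.
#   load_model = st.cache_resource(load_model)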
def draw_box(img, box, pred_class, model, resize_input=False):
    """Draw a single predicted box (no mask) on top of the input image."""
    height, width, channels = img.shape
    pred_v = Visualizer(img[:, :, ::-1], model["metadata"], scale=1)
    instances = Instances(
        (height, width),
        pred_boxes=Boxes(torch.as_tensor(box).unsqueeze(0)),
        pred_classes=torch.tensor([pred_class]),
    )
    pred_v = pred_v.draw_instance_predictions(instances)
    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))
    return pred
def explain(img, model):
    state.write("Loading features...")
    with open(FEATURES_DATABASE) as f:
        database = json.load(f)
    state.write("Computing logits...")
    instances, input_img = forward_model_full(model["model"], model["cfg"], img)
    instances.remove("pred_masks")
    pred_v = Visualizer(cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB), model["metadata"], scale=1)
    pred_v = pred_v.draw_instance_predictions(instances.to("cpu"))
    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))
    pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
    # st.tabs only exists from Streamlit 1.11.0 on; fall back to stacked
    # containers on older versions.
    if version.parse(st.__version__) >= version.parse("1.11.0"):
        tabs = st.tabs(["Result", "Detection"] + [f"Lesion #{i}" for i in range(len(instances))])
        lesion_tabs = tabs[2:]
        detection_tab = tabs[1]
        with tabs[0]:
            st.header("Image processed")
            st.success("Use the tabs on the right to see the detected lesions and detailed explanations for each lesion")
    else:
        tabs = [st.container() for _ in range(len(instances) + 1)]
        lesion_tabs = tabs[1:]
        detection_tab = tabs[0]
    state.write("Populating first tab...")
    with detection_tab:
        st.header("Detected lesions")
        st.image(pred)
    for i, (tab, box, pred_class, scores, features) in enumerate(
        zip(lesion_tabs, instances.pred_boxes, instances.pred_classes, instances.probs, instances.features)
    ):
        state.write(f"Populating tab for lesion #{i}...")
        # The last softmax column is the background class, which the app
        # presents as "healthy".
        healthy_prob = scores[-1].item()
        scores = scores[:-1]
        features = features.tolist()
        with tab:
            st.header(f"Lesion #{i}")
            state.write(f"Populating classes for lesion #{i}...")
            lesion_img = draw_box(img, box.cpu(), pred_class, model)
            lesion_img = cv2.cvtColor(lesion_img, cv2.COLOR_BGR2RGB)
            classes = ["healthy", "neoplastic", "aphthous", "traumatic"]
            y_pos = np.arange(len(classes))
            probs = [healthy_prob] + scores.cpu().numpy().tolist()
            probs_fig = plt.figure()
            plt.bar(y_pos, probs, align="center")
            plt.xticks(y_pos, classes)
            plt.ylabel("Probability")
            plt.title("Class")
            st.subheader("Classification")
            col1, col2 = st.columns(2)
            col1.image(lesion_img)
            col2.pyplot(probs_fig)
            st.subheader("Feature space")
            col1, col2 = st.columns(2)
            state.write(f"Populating PCA for lesion #{i}...")
            fig = plot_pca_point(point=features, features_database=FEATURES_DATABASE, pca_model=PCA_MODEL, fig_h=800, fig_w=600, fig_dpi=100)
            col1.pyplot(fig)
            state.write(f"Populating histogram for lesion #{i}...")
            fig = plot_histogram_dist(point=features, features_database=FEATURES_DATABASE, fig_h=800, fig_w=600, fig_dpi=100)
            col2.pyplot(fig)
            state.write(f"Populating Gradcam++ for lesion #{i}...")
            st.subheader("Gradcam++")
            fig = plot_gradcam(model=MODEL, file=FILE, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
            st.pyplot(fig)
    state.write("All done...")
FILE = "./test.jpg" | |
MODEL = "./models/model.pth" | |
PCA_MODEL = "./models/pca.pkl" | |
FEATURES_DATABASE = "./assets/features/features.json" | |
st.header("Explainable oral lesion detection") | |
st.markdown("""Demo for the paper [Explainable diagnosis of oral cancer via deep learning and case-based reasoning](https://mlpi.ing.unipi.it/doctoralai/) | |
Upload an image using the form below and click on "Process" | |
""") | |
FILE = st.file_uploader("Image", type=["jpg", "jpeg", "png"]) | |
TH = st.slider("Threshold", min_value=0.0, max_value=1.0, value=0.5) | |
process = st.button("Process") | |
state = st.empty() | |
if process: | |
state.write("Loading model...") | |
model = load_model() | |
nparr = np.fromstring(FILE.getvalue(), np.uint8) | |
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | |
#img = cv2.imread(FILE) | |
img = cv2.resize(img, (800, 800)) | |
explain(img, model) | |