|
import os |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from typing import Tuple, Optional, Dict, List, Dict |
|
import glob |
|
from collections import defaultdict |
|
|
|
from transformers import (AutoImageProcessor, |
|
ResNetForImageClassification) |
|
|
|
from labelmap import DR_LABELMAP |
|
|
|
|
|
class App:
    """ Demonstration of the Diabetic Retinopathy model as a Gradio app.

    The constructor loads the image processor and the ResNet classifier
    from a local checkpoint directory, collects the example images found on
    disk and assembles the Gradio UI. Call `launch` to start serving it.
    """

    def __init__(self) -> None:
        """ Load the model and build the user interface. """

        ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
        path = f"release_ckpts/{ckpt_name}/inference/"

        self.image_processor = AutoImageProcessor.from_pretrained(path)
        self.model = ResNetForImageClassification.from_pretrained(path)

        # Actually place the model on the GPU when one is present, so that
        # the "Running on ..." banner reports where inference really runs.
        use_cuda = torch.cuda.is_available()
        self.model.to(torch.device("cuda" if use_cuda else "cpu"))
        device = 'GPU' if use_cuda else 'CPU'

        example_lists = self._load_example_lists()

        css = ".output-image, .input-image, .image-preview {height: 600px !important}"

        with gr.Blocks(css=css) as ui:
            # Top row: controls, prediction output and the logo.
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        predict_btn = gr.Button("Predict", size="lg")
                    with gr.Row():
                        gr.Markdown(f"Running on {device}")
                with gr.Column(scale=4):
                    output = gr.Label(num_top_classes=len(DR_LABELMAP),
                                      label="Retinopathy level prediction")
                with gr.Column(scale=4):
                    gr.Markdown("![](https://media.githubusercontent.com/media/SDAIA-KAUST-AI/diabetic-retinopathy-detection/main/media/logo1.png)")
            # Bottom row: the input image and one tab of examples per class.
            with gr.Row():
                with gr.Column(scale=9, min_width=100):
                    image = gr.Image(label="Retina scan")
                with gr.Column(scale=1, min_width=150):
                    # Iterate over the class IDs that were actually found on
                    # disk; they are not guaranteed to be contiguous from 0.
                    for cls_id in sorted(example_lists):
                        label = DR_LABELMAP[cls_id]
                        with gr.Tab(f"{cls_id} : {label}"):
                            gr.Examples(
                                example_lists[cls_id],
                                inputs=[image],
                                outputs=[output],
                                fn=self.predict,
                                examples_per_page=10,
                                run_on_click=True)

            predict_btn.click(
                fn=self.predict,
                inputs=image,
                outputs=output,
                api_name="predict")

        self.ui = ui

    def launch(self) -> None:
        """ Launch the application, blocking.

        The UI is started with request queueing enabled and `share=True`.
        """
        self.ui.queue().launch(share=True)

    def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
        """ Gradio callback for processing of an image.

        Args:
            image (Optional[np.ndarray]): Provided image, or None when the
                image component is empty.

        Returns:
            Dict[str, float]: Label-compatible dict mapping
                "<class id> - <class name>" to the class probability.
                Empty when no image was provided.
        """

        if image is None:
            return {}
        cls_name, prob, probs = self._infer(image)
        print(f"Predicted class={cls_name}, prob={prob:.3f}")
        return {f"{i} - {DR_LABELMAP[i]}": float(v)
                for i, v in enumerate(probs)}

    def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """ Low-level method to perform neural network inference.

        Args:
            image_chw (np.ndarray): Provided image.
                NOTE(review): the value comes straight from `gr.Image`,
                which presumably yields height x width x channel arrays
                despite the parameter name - confirm.

        Returns:
            Tuple[str, float, np.ndarray]:
                - Most probable class name.
                - Probability of the most probable class name.
                - Probabilities of all classes in the order of
                  being listed in the label map.
        """

        # Narrows the type of `self.model` for static type checkers.
        assert isinstance(self.model, ResNetForImageClassification)

        inputs = self.image_processor(image_chw, return_tensors="pt")
        # Keep the input tensors on the same device as the model weights.
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            logits_batch = self.model(**inputs).logits

        # Exactly one image is processed per call: expect shape (1, classes).
        assert len(logits_batch.shape) == 2
        assert logits_batch.shape[0] == 1

        probs = torch.softmax(logits_batch[0], dim=-1)
        predicted_label = int(probs.argmax(-1).item())
        prob = probs[predicted_label].item()
        cls_name = self.model.config.id2label[predicted_label]
        # `.numpy()` only works on CPU tensors, hence the explicit `.cpu()`.
        return cls_name, prob, probs.cpu().numpy()

    @staticmethod
    def _load_example_lists() -> Dict[int, List[str]]:
        """ Load example retina images from disk.

        Images are looked up as `demo_data/demo/<cls_id>/<name>.jpeg`
        relative to the current working directory. Files whose parent
        directory name is not an integer are reported and skipped.

        Returns:
            Dict[int, List[str]]: Dictionary of cls_id -> list of images paths.
                Keys are in ascending order and each list is sorted.
        """

        example_lists: Dict[int, List[str]] = defaultdict(list)

        # `glob` is called without `recursive=True`, so `**` matches exactly
        # one directory level here. Its result order is arbitrary, hence the
        # sort to keep the examples shown in the UI deterministic.
        for path in sorted(glob.glob("demo_data/demo/**/*.jpeg")):
            subdir = os.path.basename(os.path.dirname(path))
            try:
                cls_id = int(subdir)
            except ValueError:
                print(f"Cannot parse path {path}")
                continue
            example_lists[cls_id].append(path)

        # Hand back a plain dict so that looking up a missing class ID fails
        # loudly instead of silently creating an empty entry.
        return {cls_id: example_lists[cls_id]
                for cls_id in sorted(example_lists)}
|
|
|
|
|
def main():
    """ App entry point: build the demo application and serve it. """
    App().launch()
|
|
|
|
|
# Start the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|