|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
# Hugging Face image-classification pipeline backed by the ViT pneumonia
# checkpoint; built once at module import so every request reuses it.
pipe = pipeline(task="image-classification", model="trpakov/vit-pneumonia")
|
|
|
|
|
def classify_image(image):
    """Classify a chest X-ray image with the module-level ViT pipeline.

    Args:
        image: The value from the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` when the user has not provided an image.

    Returns:
        A ``{label: score}`` dict, with keys in alphabetical order, suitable
        for a ``gr.Label`` output. Returns ``None`` (which clears the label)
        when no image was given.
    """
    # gr.Image yields None if "Classify!" is pressed before anything is
    # uploaded; the pipeline cannot process None, so clear the output instead.
    if image is None:
        return None

    predictions = pipe(image)
    # Sort by label so the returned dict has a stable key order regardless
    # of which class scored higher.
    return {
        pred["label"]: pred["score"]
        for pred in sorted(predictions, key=lambda pred: pred["label"])
    }
|
|
|
|
|
# Build the UI: description, image input and button on the left, the
# predicted class probabilities on the right, and sample images below.
with gr.Blocks(title="ViT Chest X-ray Classification") as demo:
    gr.Markdown(value="# ViT Chest X-ray Pneumonia Classification")

    with gr.Row():
        # Left column: instructions, the image to classify, and the trigger.
        with gr.Column():
            gr.Markdown(
                value="Classify chest x-ray scans as either having or not having pneumonia"
            )
            input_image = gr.Image(type="pil")
            classify_button = gr.Button(value="Classify!")

        # Right column: per-class scores produced by classify_image.
        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="Probabilities")

    with gr.Row():
        # Example images loaded from the ./samples directory; outputs are
        # cached via cache_examples=True and shown when an example is clicked.
        gr.Examples(
            examples="./samples",
            inputs=input_image,
            outputs=output_label,
            fn=classify_image,
            cache_examples=True,
            run_on_click=True,
        )

    # Run the classifier on the current image when the button is pressed.
    classify_button.click(classify_image, input_image, output_label)
|
|
|
|
|
# Enable request queuing via Blocks.queue(): the `enable_queue` argument to
# launch() was deprecated in Gradio 3 and removed in Gradio 4 (TypeError there).
# debug=True keeps the process attached and prints errors to the console.
demo.queue().launch(debug=True)
|
|