Spaces:
Running
Running
import gradio as gr | |
import os | |
from peft import PeftModel | |
from PIL import Image | |
import torch | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
from torchvision.transforms import ( | |
CenterCrop, | |
Compose, | |
Normalize, | |
RandomHorizontalFlip, | |
RandomResizedCrop, | |
Resize, | |
ToTensor, | |
) | |
# Hub identifiers: the base ViT checkpoint and the PEFT adapter applied to it.
model_name = 'google/vit-large-patch16-224'
adapter = 'monsoon-nlp/eyegazer-vit-binary'

# The processor supplies the normalization statistics and the input size
# used by the transforms below.
image_processor = AutoImageProcessor.from_pretrained(model_name)
normalize = Normalize(
    mean=image_processor.image_mean,
    std=image_processor.image_std,
)
_crop_size = image_processor.size["height"]

# Augmenting pipeline (random crop + flip); not referenced by the rest of
# the visible code, kept so the module-level name stays available.
train_transforms = Compose([
    RandomResizedCrop(_crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Deterministic pipeline applied to images at prediction time.
val_transforms = Compose([
    Resize(_crop_size),
    CenterCrop(_crop_size),
    ToTensor(),
    normalize,
])

# Base classifier with a two-label head; ignore_mismatched_sizes lets the
# checkpoint load even though its original head has a different size.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

# Wrap the base model with the adapter weights.
lora_model = PeftModel.from_pretrained(model, adapter)
def query(img):
    """Classify an input image as unaffected or affected by retinopathy.

    Args:
        img: PIL image coming from the Gradio input component, or ``None``
            when no image was provided.

    Returns:
        A human-readable prediction string, or a short prompt asking for an
        image when ``img`` is ``None``.
    """
    if img is None:
        # Without an image there is nothing to convert; answer with a hint
        # instead of failing on ``None.convert``.
        return "Please provide an image first"

    pixel_values = val_transforms(img.convert("RGB"))
    batch = pixel_values.unsqueeze(0)  # the model is called on a batch of one

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        output = lora_model(batch)

    scores = output.logits.tolist()[0]
    # Index 0 is treated as the "unaffected" logit, index 1 as "affected";
    # a tie falls through to "affected".
    if scores[0] > scores[1]:
        return "Predicted unaffected"
    return "Predicted affected to some degree"
# Example images offered in the UI; paths are relative to this file's
# directory and the files are expected to exist under ``images/``.
_APP_DIR = os.path.dirname(__file__)
_EXAMPLE_IMAGES = (
    "images/0a09aa7356c0.png",
    "images/0a4e1a29ffff.png",
    "images/0c43c79e8cfb.png",
    "images/0c7e82daf5a0.png",
)

# Single-image-in, markdown-text-out interface around ``query``.
iface = gr.Interface(
    fn=query,
    examples=[os.path.join(_APP_DIR, rel_path) for rel_path in _EXAMPLE_IMAGES],
    inputs=[
        gr.Image(
            image_mode='RGB',
            sources=['upload', 'clipboard'],
            type='pil',  # hand ``query`` a PIL image
            label='Input Fundus Camera Image',
            show_label=True,
        ),
    ],
    outputs=[
        gr.Markdown(value="", label="Predicted label"),
    ],
    title="ViT retinopathy model",
    # Fixed typo in the user-facing disclaimer ("dvice" -> "advice").
    description="Diabetic retinopathy model trained on APTOS 2019 dataset; demonstration, not medical advice",
    allow_flagging="never",
)

iface.launch()