Spaces:
Running
Running
File size: 2,307 Bytes
3da4879 877a841 3da4879 ae753cc 3da4879 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gradio as gr
import os
from peft import PeftModel
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
# Hub identifiers: the base ViT checkpoint and the PEFT adapter layered on it.
model_name = 'google/vit-large-patch16-224'
adapter = 'monsoon-nlp/eyegazer-vit-binary'

# Take normalization statistics and the target resolution from the
# checkpoint's own image processor.
image_processor = AutoImageProcessor.from_pretrained(model_name)
normalize = Normalize(
    mean=image_processor.image_mean,
    std=image_processor.image_std,
)
_crop_size = image_processor.size["height"]

# Augmenting pipeline (random crop + flip). `query` in this file uses
# `val_transforms`, not this one.
train_transforms = Compose([
    RandomResizedCrop(_crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Deterministic pipeline: resize, center-crop, tensorize, normalize.
val_transforms = Compose([
    Resize(_crop_size),
    CenterCrop(_crop_size),
    ToTensor(),
    normalize,
])

# Load the base model with a 2-label head. `ignore_mismatched_sizes=True`
# is passed so a checkpoint head of a different shape does not abort loading.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

# Wrap the base model with the adapter weights.
# NOTE(review): the variable name suggests a LoRA adapter - confirm against
# the adapter's config on the Hub.
lora_model = PeftModel.from_pretrained(model, adapter)
def query(img):
    """Classify a fundus image and return a human-readable prediction.

    Args:
        img: A PIL image supplied by the Gradio input component, or ``None``
            when the form is submitted without an image.

    Returns:
        A short message string: the predicted class, or a prompt to provide
        an image when ``img`` is ``None``.
    """
    # Gradio passes None if the user submits an empty image component;
    # answer with a message instead of failing on `None.convert`.
    if img is None:
        return "No image provided. Please upload or paste a fundus image."

    pixels = val_transforms(img.convert("RGB"))
    # The model consumes a batch, so add a leading batch dimension of 1.
    batch = pixels.unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = lora_model(batch)

    scores = output.logits.tolist()[0]
    # Index 0 is reported as "unaffected"; ties fall through to "affected".
    if scores[0] > scores[1]:
        return "Predicted unaffected"
    return "Predicted affected to some degree"
# Example images are stored in an "images" directory next to this file.
_EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "images")
_EXAMPLE_FILES = (
    "0a09aa7356c0.png",
    "0a4e1a29ffff.png",
    "0c43c79e8cfb.png",
    "0c7e82daf5a0.png",
)

# Single-image-in, Markdown-text-out demo UI around `query`.
iface = gr.Interface(
    fn=query,
    examples=[os.path.join(_EXAMPLES_DIR, name) for name in _EXAMPLE_FILES],
    inputs=[
        gr.Image(
            image_mode='RGB',
            sources=['upload', 'clipboard'],
            type='pil',  # `query` calls PIL's .convert() on its input
            label='Input Fundus Camera Image',
            show_label=True,
        ),
    ],
    outputs=[
        gr.Markdown(value="", label="Predicted label"),
    ],
    title="ViT retinopathy model",
    description=(
        "Diabetic retinopathy model trained on APTOS 2019 dataset; "
        "demonstration, not medical advice"
    ),
    # NOTE(review): newer Gradio releases appear to rename this argument to
    # `flagging_mode` - confirm against the Gradio version pinned for this app.
    allow_flagging="never",
)

# Start the web server as soon as the script is executed.
iface.launch()
|