fundus_img / app.py
dongsheng's picture
Upload app.py
7bd876a
raw
history blame
2.2 kB
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
def img2label(left, right):
left_img = Image.open(left).convert('RGB')
right_img = Image.open(right).convert('RGB')
# 将右眼底镜像反转
r2l = transforms.RandomHorizontalFlip(p=1)
right_img = r2l(right_img)
# 调整图片
left_img = my_transforms(left_img).to(device)
right_img = my_transforms(right_img).to(device)
# 读取模型
model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)
with torch.no_grad():
output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))
output = torch.sigmoid(output.squeeze(0))
pred = output.cpu().numpy().tolist()
return {LABELS[i]: pred[i] for i in range(len(pred))}
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 标题
title = "基于眼底图像的智能健康诊断分析系统"
# 标题下的描述,支持md格式
description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \
"包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类"
# transforms设置
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
my_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
LABELS = {0: '正常',
1: '糖尿病',
2: '青光眼',
3: '白内障',
4: '年龄性黄斑变性',
5: '高血压',
6: '病理性近视',
7: '其他疾病'}
# left_img_dir = 'left.jpg'
# right_img_dir = 'right.jpg'
# r = img2label(left_img_dir, right_img_dir)
demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs='label',
title=title, description=description)
demo.launch(share=True)