fundus_img / app.py
dongsheng's picture
Upload app.py
9ed6175
raw
history blame
2.63 kB
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
def img2label(left, right):
left = Image.fromarray(left.astype('uint8'), 'RGB')
right = Image.fromarray(right.astype('uint8'), 'RGB')
# 将右眼底镜像反转
r2l = transforms.RandomHorizontalFlip(p=1)
right = r2l(right)
# 调整图片
left_img = my_transforms(left).to(device)
right_img = my_transforms(right).to(device)
# 读取模型
model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)
with torch.no_grad():
output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))
output = torch.sigmoid(output.squeeze(0))
output_ = output.cpu().numpy().tolist()
res_dict = {LABELS[i]: output_[i] for i in range(len(output_))}
pred = torch.nonzero(output > 0.4).view(-1)
pred = pred.cpu().numpy().tolist()
if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
return LABELS[0], res_dict
res = ''
for i in pred:
if i == 0:
continue
res += ', ' + LABELS[i]
return '目前的身体状态:' + res[2:], res_dict
if __name__ == '__main__':
device = torch.device("cpu")
# 标题
title = "基于眼底图像的智能健康诊断分析系统"
# 标题下的描述,支持md格式
description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \
"包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类"
# transforms设置
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
my_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
LABELS = {0: '正常',
1: '糖尿病',
2: '青光眼',
3: '白内障',
4: '年龄性黄斑变性',
5: '高血压',
6: '病理性近视',
7: '其他疾病'}
left_img_dir = 'left.jpg'
right_img_dir = 'right.jpg'
examples = [[left_img_dir, right_img_dir]]
# r = img2label(left_img_dir, right_img_dir)
demo = gr.Interface(fn=img2label,
inputs=[gr.inputs.Image(), gr.inputs.Image()],
outputs=["text", "label"],
examples=examples, title=title, description=description)
demo.launch()