Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from PIL import Image | |
import torchvision.transforms as transforms | |
def img2label(left, right): | |
left_img = Image.open(left).convert('RGB') | |
right_img = Image.open(right).convert('RGB') | |
# 将右眼底镜像反转 | |
r2l = transforms.RandomHorizontalFlip(p=1) | |
right_img = r2l(right_img) | |
# 调整图片 | |
left_img = my_transforms(left_img).to(device) | |
right_img = my_transforms(right_img).to(device) | |
# 读取模型 | |
model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device) | |
with torch.no_grad(): | |
output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0)) | |
output = torch.sigmoid(output.squeeze(0)) | |
pred = output.cpu().numpy().tolist() | |
return {LABELS[i]: pred[i] for i in range(len(pred))} | |
if __name__ == '__main__': | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# 标题 | |
title = "基于眼底图像的智能健康诊断分析系统" | |
# 标题下的描述,支持md格式 | |
description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \ | |
"包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类" | |
# transforms设置 | |
norm_mean = [0.485, 0.456, 0.406] | |
norm_std = [0.229, 0.224, 0.225] | |
my_transforms = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(norm_mean, norm_std) | |
]) | |
LABELS = {0: '正常', | |
1: '糖尿病', | |
2: '青光眼', | |
3: '白内障', | |
4: '年龄性黄斑变性', | |
5: '高血压', | |
6: '病理性近视', | |
7: '其他疾病'} | |
# left_img_dir = 'left.jpg' | |
# right_img_dir = 'right.jpg' | |
# r = img2label(left_img_dir, right_img_dir) | |
demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs='label', | |
title=title, description=description) | |
demo.launch(share=True) | |