import gradio as gr import torch from PIL import Image import torchvision.transforms as transforms def img2label(left, right): left_img = Image.open(left).convert('RGB') right_img = Image.open(right).convert('RGB') # 将右眼底镜像反转 r2l = transforms.RandomHorizontalFlip(p=1) right_img = r2l(right_img) # 调整图片 left_img = my_transforms(left_img).to(device) right_img = my_transforms(right_img).to(device) # 读取模型 model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device) with torch.no_grad(): output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0)) output = torch.sigmoid(output.squeeze(0)) pred = output.cpu().numpy().tolist() return {LABELS[i]: pred[i] for i in range(len(pred))} if __name__ == '__main__': device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 标题 title = "基于眼底图像的智能健康诊断分析系统" # 标题下的描述,支持md格式 description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \ "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类" # transforms设置 norm_mean = [0.485, 0.456, 0.406] norm_std = [0.229, 0.224, 0.225] my_transforms = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std) ]) LABELS = {0: '正常', 1: '糖尿病', 2: '青光眼', 3: '白内障', 4: '年龄性黄斑变性', 5: '高血压', 6: '病理性近视', 7: '其他疾病'} # left_img_dir = 'left.jpg' # right_img_dir = 'right.jpg' # r = img2label(left_img_dir, right_img_dir) demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs='label', title=title, description=description) demo.launch(share=True)