import gradio as gr import torch from PIL import Image import torchvision.transforms as transforms def img2label(left, right): left = Image.fromarray(left.astype('uint8'), 'RGB') right = Image.fromarray(right.astype('uint8'), 'RGB') # 将右眼底镜像反转 r2l = transforms.RandomHorizontalFlip(p=1) right = r2l(right) # 调整图片 left_img = my_transforms(left).to(device) right_img = my_transforms(right).to(device) # 读取模型 model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device) # 模型推理 with torch.no_grad(): output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0)) # 结果处理 output = torch.sigmoid(output.squeeze(0)) output_ = output.cpu().numpy().tolist() res_dict = {LABELS[i]: output_[i] for i in range(len(output_))} pred = torch.nonzero(output > 0.4).view(-1) pred = pred.cpu().numpy().tolist() if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0): return res_dict, LABELS[0] res = '' for i in pred: if i == 0: continue res += ', ' + LABELS[i] return res_dict, '目前的身体状态:' + res[2:] if __name__ == '__main__': device = torch.device("cpu") # 标题 title = "基于眼底图像的智能健康诊断分析系统" # 标题下的描述,支持md格式 description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \ "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类" # transforms设置 norm_mean = [0.485, 0.456, 0.406] norm_std = [0.229, 0.224, 0.225] my_transforms = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std) ]) # 标签设置 LABELS = {0: '正常', 1: '糖尿病', 2: '青光眼', 3: '白内障', 4: '年龄性黄斑变性', 5: '高血压', 6: '病理性近视', 7: '其他疾病'} left_img_dir = 'left.jpg' right_img_dir = 'right.jpg' examples = [[left_img_dir, right_img_dir]] # r = img2label(left_img_dir, right_img_dir) demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs=["label", "text"], examples=examples, title=title, description=description) demo.launch()