dongsheng committed
Commit: 3b16162
Parent: 1aa91f1

Upload app.py

Files changed (1):
  app.py (+23 -11)
app.py CHANGED
@@ -5,15 +5,15 @@ import torchvision.transforms as transforms
 
 
 def img2label(left, right):
-    left_img = Image.open(left).convert('RGB')
-    right_img = Image.open(right).convert('RGB')
+    left = Image.fromarray(left.astype('uint8'), 'RGB')
+    right = Image.fromarray(right.astype('uint8'), 'RGB')
     # horizontally mirror the right fundus image
     r2l = transforms.RandomHorizontalFlip(p=1)
-    right_img = r2l(right_img)
+    right = r2l(right)
 
     # preprocess the images
-    left_img = my_transforms(left_img).to(device)
-    right_img = my_transforms(right_img).to(device)
+    left_img = my_transforms(left).to(device)
+    right_img = my_transforms(right).to(device)
 
     # load the model
     model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)
@@ -22,9 +22,19 @@ def img2label(left, right):
     output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))
 
     output = torch.sigmoid(output.squeeze(0))
-    pred = output.cpu().numpy().tolist()
-
-    return {LABELS[i]: pred[i] for i in range(len(pred))}
+    # pred = output.cpu().numpy().tolist()
+    # return {LABELS[i]: pred[i] for i in range(len(pred))}
+    pred = torch.nonzero(output > 0.4).view(-1)
+    pred = pred.cpu().numpy().tolist()
+
+    if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
+        return LABELS[0]
+    res = ''
+    for i in pred:
+        if i == 0:
+            continue
+        res += ', ' + LABELS[i]
+    return '目前的身体状态:' + res[2:]
 
 
 if __name__ == '__main__':
@@ -54,9 +64,11 @@ if __name__ == '__main__':
               6: '病理性近视',
               7: '其他疾病'}
 
-    # left_img_dir = 'left.jpg'
-    # right_img_dir = 'right.jpg'
+    left_img_dir = 'left.jpg'
+    right_img_dir = 'right.jpg'
+    examples = [[left_img_dir, right_img_dir]]
    # r = img2label(left_img_dir, right_img_dir)
-    demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs='label',
+    demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()],
+                        outputs="text", examples=examples,
                         title=title, description=description)
-    demo.launch()
+    demo.launch(share=True)
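
Two behavioural changes stand out in this commit. First, the inputs: Image.open on a file path is replaced by Image.fromarray, consistent with Gradio's Image input handing the wrapped function a numpy array (its default type) rather than a path. Second, the output: instead of returning the raw per-class sigmoid scores as a label dict, every class whose score exceeds 0.4 is reported in a text string prefixed with '目前的身体状态:' ("current condition:"), falling back to LABELS[0] when nothing beyond class 0 fires. Below is a minimal, self-contained sketch of that thresholding step; the label names for indices 0-5 and the probability values are placeholders invented for illustration (only entries 6 and 7 of LABELS appear in this diff).

```python
# Minimal sketch of the thresholding step added in this commit, isolated from
# the model and the Gradio app. LABELS entries and the probabilities below are
# placeholders; only the 0.4 threshold and the class-0 special case come from
# the committed code.
import torch

LABELS = {0: 'normal', 1: 'disease_1', 2: 'disease_2', 3: 'disease_3',
          4: 'disease_4', 5: 'disease_5', 6: 'disease_6', 7: 'disease_7'}

def summarize(output: torch.Tensor) -> str:
    # indices of every class whose sigmoid score exceeds the 0.4 threshold
    pred = torch.nonzero(output > 0.4).view(-1)
    pred = pred.cpu().numpy().tolist()
    # nothing fired, or only the "normal" class: report normal
    if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
        return LABELS[0]
    res = ''
    for i in pred:
        if i == 0:
            continue  # skip the "normal" class in a mixed prediction
        res += ', ' + LABELS[i]
    return 'Predicted findings: ' + res[2:]

# Example: classes 1 and 6 exceed the threshold, class 0 does not.
probs = torch.tensor([0.10, 0.55, 0.05, 0.20, 0.30, 0.12, 0.71, 0.08])
print(summarize(probs))  # -> Predicted findings: disease_1, disease_6
```

The 0.4 cut-off and the handling of class 0 mirror the committed img2label; the surrounding scaffolding (function name, English label strings, sample tensor) is illustrative only.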