dongsheng commited on
Commit
9ed6175
1 Parent(s): 529c378

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -22,19 +22,20 @@ def img2label(left, right):
22
  output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))
23
 
24
  output = torch.sigmoid(output.squeeze(0))
25
- # pred = output.cpu().numpy().tolist()
26
- # return {LABELS[i]: pred[i] for i in range(len(pred))}
 
27
  pred = torch.nonzero(output > 0.4).view(-1)
28
  pred = pred.cpu().numpy().tolist()
29
 
30
  if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
31
- return LABELS[0]
32
  res = ''
33
  for i in pred:
34
  if i == 0:
35
  continue
36
  res += ', ' + LABELS[i]
37
- return '目前的身体状态:' + res[2:]
38
 
39
 
40
  if __name__ == '__main__':
@@ -68,6 +69,8 @@ if __name__ == '__main__':
68
  right_img_dir = 'right.jpg'
69
  examples = [[left_img_dir, right_img_dir]]
70
  # r = img2label(left_img_dir, right_img_dir)
71
- demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()],
72
- outputs="text", examples=examples, title=title, description=description)
 
 
73
  demo.launch()
 
22
  output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))
23
 
24
  output = torch.sigmoid(output.squeeze(0))
25
+ output_ = output.cpu().numpy().tolist()
26
+ res_dict = {LABELS[i]: output_[i] for i in range(len(output_))}
27
+
28
  pred = torch.nonzero(output > 0.4).view(-1)
29
  pred = pred.cpu().numpy().tolist()
30
 
31
  if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
32
+ return LABELS[0], res_dict
33
  res = ''
34
  for i in pred:
35
  if i == 0:
36
  continue
37
  res += ', ' + LABELS[i]
38
+ return '目前的身体状态:' + res[2:], res_dict
39
 
40
 
41
  if __name__ == '__main__':
 
69
  right_img_dir = 'right.jpg'
70
  examples = [[left_img_dir, right_img_dir]]
71
  # r = img2label(left_img_dir, right_img_dir)
72
+ demo = gr.Interface(fn=img2label,
73
+ inputs=[gr.inputs.Image(), gr.inputs.Image()],
74
+ outputs=["text", "label"],
75
+ examples=examples, title=title, description=description)
76
  demo.launch()