rashmi commited on
Commit
1553854
·
1 Parent(s): e27dc34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -7,14 +7,10 @@ import gradio as gr
7
  import torch, torchvision
8
  print(torch.__version__, torch.cuda.is_available())
9
  assert torch.__version__.startswith("1.9") # please manually install torch 1.9 if Colab changes its default version
10
- # Some basic setup:
11
- # Setup detectron2 logger
12
  import detectron2
13
  from detectron2.utils.logger import setup_logger
14
- # import some common libraries
15
  import numpy as np
16
  import os, json, random
17
- # import some common detectron2 utilities
18
  from detectron2 import model_zoo
19
  from detectron2.engine import DefaultPredictor
20
  from detectron2.config import get_cfg
@@ -25,35 +21,39 @@ from matplotlib import pyplot as plt
25
 
26
  cfg = get_cfg()
27
  cfg.MODEL.DEVICE='cpu'
28
- # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
29
  cfg.INPUT.MASK_FORMAT='bitmask'
30
  cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
31
  cfg.TEST.DETECTIONS_PER_IMAGE = 1000
32
  cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
33
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
34
- # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
35
  cfg.MODEL.WEIGHTS = "model_final.pth"
36
 
37
  predictor = DefaultPredictor(cfg)
38
 
39
 
40
  def inference(img):
 
41
  im = np.asarray(Image.open(img).convert('RGB'))
42
  outputs = predictor(im)
43
-
44
  take = outputs['instances'].scores >= 0.5 #Threshold
45
  pred_masks = outputs['instances'].pred_masks[take].cpu().numpy()
 
46
 
47
  mask = np.stack(pred_masks)
48
  mask = np.any(mask == 1, axis=0)
49
 
50
  p = plt.imshow(im,cmap='gray')
51
- p1 = plt.imshow(mask, alpha=0.4)
 
 
 
52
 
53
  return plt
54
 
55
 
56
 
 
57
  title = "Sartorius Cell Instance Segmentation"
58
  description = "Sartorius Cell Instance Segmentation Demo: Current Kaggle competition - kaggle.com/c/sartorius-cell-instance-segmentation"
59
  article = "<p style='text-align: center'><a href='https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/' target='_blank'>Detectron2: A PyTorch-based modular object detection library</a> | <a href='https://github.com/facebookresearch/detectron2' target='_blank'>Github Repo</a></p>"
 
7
  import torch, torchvision
8
  print(torch.__version__, torch.cuda.is_available())
9
  assert torch.__version__.startswith("1.9") # please manually install torch 1.9 if Colab changes its default version
 
 
10
  import detectron2
11
  from detectron2.utils.logger import setup_logger
 
12
  import numpy as np
13
  import os, json, random
 
14
  from detectron2 import model_zoo
15
  from detectron2.engine import DefaultPredictor
16
  from detectron2.config import get_cfg
 
21
 
22
  cfg = get_cfg()
23
  cfg.MODEL.DEVICE='cpu'
 
24
  cfg.INPUT.MASK_FORMAT='bitmask'
25
  cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
26
  cfg.TEST.DETECTIONS_PER_IMAGE = 1000
27
  cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
28
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
 
29
  cfg.MODEL.WEIGHTS = "model_final.pth"
30
 
31
  predictor = DefaultPredictor(cfg)
32
 
33
 
34
  def inference(img):
35
+ class_names = ['astro', 'cort', 'sh-sy5y']
36
  im = np.asarray(Image.open(img).convert('RGB'))
37
  outputs = predictor(im)
38
+ pred_classes = outputs['instances'].pred_classes.cpu().numpy().tolist()
39
  take = outputs['instances'].scores >= 0.5 #Threshold
40
  pred_masks = outputs['instances'].pred_masks[take].cpu().numpy()
41
+ pred_class = max(set(pred_classes), key=pred_classes.count)
42
 
43
  mask = np.stack(pred_masks)
44
  mask = np.any(mask == 1, axis=0)
45
 
46
  p = plt.imshow(im,cmap='gray')
47
+ p = plt.imshow(mask, alpha=0.4)
48
+ p = plt.xticks(fontsize=8)
49
+ p = plt.yticks(fontsize=8)
50
+ p = plt.title("cell type: " + class_names[pred_class])
51
 
52
  return plt
53
 
54
 
55
 
56
+
57
  title = "Sartorius Cell Instance Segmentation"
58
  description = "Sartorius Cell Instance Segmentation Demo: Current Kaggle competition - kaggle.com/c/sartorius-cell-instance-segmentation"
59
  article = "<p style='text-align: center'><a href='https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/' target='_blank'>Detectron2: A PyTorch-based modular object detection library</a> | <a href='https://github.com/facebookresearch/detectron2' target='_blank'>Github Repo</a></p>"