hsshin98
commited on
Commit
·
aff8d56
1
Parent(s):
f98e690
prompt
Browse files- app.py +13 -6
- cat_seg/modeling/transformer/cat_seg_predictor.py +10 -10
- demo/predictor.py +3 -2
app.py
CHANGED
@@ -41,7 +41,6 @@ def setup_cfg(args):
|
|
41 |
add_cat_seg_config(cfg)
|
42 |
cfg.merge_from_file(args.config_file)
|
43 |
cfg.merge_from_list(args.opts)
|
44 |
-
cfg.MODEL.DEVICE = "cpu"
|
45 |
cfg.freeze()
|
46 |
return cfg
|
47 |
|
@@ -67,7 +66,10 @@ def get_parser():
|
|
67 |
"MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
|
68 |
"MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
|
69 |
"TEST.SLIDING_WINDOW", "True",
|
70 |
-
"MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]"
|
|
|
|
|
|
|
71 |
nargs=argparse.REMAINDER,
|
72 |
)
|
73 |
return parser
|
@@ -75,7 +77,7 @@ def get_parser():
|
|
75 |
def save_masks(preds, text):
|
76 |
preds = preds['sem_seg'].argmax(dim=0).cpu().numpy() # C H W
|
77 |
for i, t in enumerate(text):
|
78 |
-
dir = f"
|
79 |
mask = preds == i
|
80 |
cv2.imwrite(dir, mask * 255)
|
81 |
|
@@ -84,7 +86,7 @@ def predict(image, text):
|
|
84 |
cfg = setup_cfg(args)
|
85 |
demo = VisualizationDemo(cfg, text=text)
|
86 |
predictions, visualized_output = demo.run_on_image(image)
|
87 |
-
#
|
88 |
canvas = fc(visualized_output.fig)
|
89 |
canvas.draw()
|
90 |
out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,))
|
@@ -97,7 +99,12 @@ if __name__ == "__main__":
|
|
97 |
|
98 |
iface = gr.Interface(
|
99 |
fn=predict,
|
100 |
-
inputs=[gr.Image(), gr.Textbox(placeholder=
|
101 |
outputs="image",
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
103 |
iface.launch()
|
|
|
41 |
add_cat_seg_config(cfg)
|
42 |
cfg.merge_from_file(args.config_file)
|
43 |
cfg.merge_from_list(args.opts)
|
|
|
44 |
cfg.freeze()
|
45 |
return cfg
|
46 |
|
|
|
66 |
"MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
|
67 |
"MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
|
68 |
"TEST.SLIDING_WINDOW", "True",
|
69 |
+
"MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
|
70 |
+
"MODEL.DEVICE", "cpu",
|
71 |
+
"MODEL.PROMPT_ENSEMBLE_TYPE", "single"
|
72 |
+
],
|
73 |
nargs=argparse.REMAINDER,
|
74 |
)
|
75 |
return parser
|
|
|
77 |
def save_masks(preds, text):
|
78 |
preds = preds['sem_seg'].argmax(dim=0).cpu().numpy() # C H W
|
79 |
for i, t in enumerate(text):
|
80 |
+
dir = f"mask_{t}.png"
|
81 |
mask = preds == i
|
82 |
cv2.imwrite(dir, mask * 255)
|
83 |
|
|
|
86 |
cfg = setup_cfg(args)
|
87 |
demo = VisualizationDemo(cfg, text=text)
|
88 |
predictions, visualized_output = demo.run_on_image(image)
|
89 |
+
#save_masks(predictions, text.split(','))
|
90 |
canvas = fc(visualized_output.fig)
|
91 |
canvas.draw()
|
92 |
out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,))
|
|
|
99 |
|
100 |
iface = gr.Interface(
|
101 |
fn=predict,
|
102 |
+
inputs=[gr.Image(), gr.Textbox(placeholder='cat, person, background')],
|
103 |
outputs="image",
|
104 |
+
description="""## CAT-Seg Demo
|
105 |
+
Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation.
|
106 |
+
|
107 |
+
Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.
|
108 |
+
|
109 |
+
To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
|
110 |
iface.launch()
|
cat_seg/modeling/transformer/cat_seg_predictor.py
CHANGED
@@ -50,13 +50,13 @@ class CATSegPredictor(nn.Module):
|
|
50 |
|
51 |
import json
|
52 |
# use class_texts in train_forward, and test_class_texts in test_forward
|
53 |
-
with open(train_class_json, 'r') as f_in:
|
54 |
-
|
55 |
-
with open(test_class_json, 'r') as f_in:
|
56 |
-
|
57 |
-
assert self.class_texts != None
|
58 |
-
if self.test_class_texts == None:
|
59 |
-
|
60 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
61 |
self.device = device
|
62 |
self.tokenizer = None
|
@@ -84,12 +84,12 @@ class CATSegPredictor(nn.Module):
|
|
84 |
prompt_templates = ['A photo of a {} in the scene',]
|
85 |
else:
|
86 |
raise NotImplementedError
|
|
|
|
|
|
|
87 |
|
88 |
self.clip_model = clip_model.float()
|
89 |
self.clip_preprocess = clip_preprocess
|
90 |
-
|
91 |
-
self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
|
92 |
-
self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
|
93 |
|
94 |
transformer = Aggregator(
|
95 |
text_guidance_dim=text_guidance_dim,
|
|
|
50 |
|
51 |
import json
|
52 |
# use class_texts in train_forward, and test_class_texts in test_forward
|
53 |
+
#with open(train_class_json, 'r') as f_in:
|
54 |
+
# self.class_texts = json.load(f_in)
|
55 |
+
#with open(test_class_json, 'r') as f_in:
|
56 |
+
# self.test_class_texts = json.load(f_in)
|
57 |
+
#assert self.class_texts != None
|
58 |
+
#if self.test_class_texts == None:
|
59 |
+
# self.test_class_texts = self.class_texts
|
60 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
61 |
self.device = device
|
62 |
self.tokenizer = None
|
|
|
84 |
prompt_templates = ['A photo of a {} in the scene',]
|
85 |
else:
|
86 |
raise NotImplementedError
|
87 |
+
|
88 |
+
#self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
|
89 |
+
#self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
|
90 |
|
91 |
self.clip_model = clip_model.float()
|
92 |
self.clip_preprocess = clip_preprocess
|
|
|
|
|
|
|
93 |
|
94 |
transformer = Aggregator(
|
95 |
text_guidance_dim=text_guidance_dim,
|
demo/predictor.py
CHANGED
@@ -43,8 +43,9 @@ class VisualizationDemo(object):
|
|
43 |
pred = self.predictor.model.sem_seg_head.predictor
|
44 |
pred.test_class_texts = text.split(',')
|
45 |
pred.text_features_test = pred.class_embeddings(pred.test_class_texts,
|
46 |
-
imagenet_templates.IMAGENET_TEMPLATES,
|
47 |
-
|
|
|
48 |
self.metadata = ns()
|
49 |
self.metadata.stuff_classes = pred.test_class_texts
|
50 |
|
|
|
43 |
pred = self.predictor.model.sem_seg_head.predictor
|
44 |
pred.test_class_texts = text.split(',')
|
45 |
pred.text_features_test = pred.class_embeddings(pred.test_class_texts,
|
46 |
+
#imagenet_templates.IMAGENET_TEMPLATES,
|
47 |
+
['A photo of a {} in the scene',],
|
48 |
+
pred.clip_model).permute(1, 0, 2).float().repeat(1, 80, 1)
|
49 |
self.metadata = ns()
|
50 |
self.metadata.stuff_classes = pred.test_class_texts
|
51 |
|