jhj0517 commited on
Commit
2e0064d
1 Parent(s): dfc0c1a

complete auto mask feature

Browse files
Files changed (1) hide show
  1. app.py +63 -40
app.py CHANGED
@@ -6,43 +6,66 @@ from modules.model_downloader import DEFAULT_MODEL_TYPE
6
  from modules.paths import OUTPUT_DIR
7
  from modules.utils import open_folder
8
 
9
- sam_inf = SamInference()
10
-
11
- with gr.Blocks() as app:
12
- with gr.Row():
13
- with gr.Column(scale=5):
14
- img_input = gr.Image(label="Input image here")
15
- with gr.Column(scale=5):
16
- dd_models = gr.Dropdown(label="Model", value=DEFAULT_MODEL_TYPE, choices=sam_inf.available_models)
17
- nb_points_per_side = gr.Number(label="points_per_side ", value=64)
18
- nb_points_per_batch = gr.Number(label="points_per_batch ", value=128)
19
- sld_pred_iou_thresh = gr.Slider(label="pred_iou_thresh ", value=0.7, minimum=0, maximum=1)
20
- sld_stability_score_thresh = gr.Slider(label="stability_score_thresh ", value=0.92, minimum=0,
21
- maximum=1)
22
- sld_stability_score_offset = gr.Slider(label="stability_score_offset ", value=0.7, minimum=0,
23
- maximum=1)
24
- nb_crop_n_layers = gr.Number(label="crop_n_layers ", value=1)
25
- sld_box_nms_thresh = gr.Slider(label="box_nms_thresh ", value=0.7, minimum=0,
26
- maximum=1)
27
- nb_crop_n_points_downscale_factor = gr.Number(label="crop_n_points_downscale_factor ", value=2)
28
- nb_min_mask_region_area = gr.Number(label="min_mask_region_area ", value=25)
29
- cb_use_m2m = gr.Checkbox(label="use_m2m ", value=True)
30
-
31
- with gr.Row():
32
- btn_generate = gr.Button("GENERATE", variant="primary")
33
- with gr.Row():
34
- gallery_output = gr.Gallery(label="Output images will be shown here")
35
- with gr.Column():
36
- output_file = gr.File(label="Generated psd file", scale=9)
37
- btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
38
-
39
- params = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh, sld_stability_score_thresh,
40
- sld_stability_score_offset,
41
- nb_crop_n_layers, sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
42
- cb_use_m2m]
43
- btn_generate.click(fn=sam_inf.generate_mask_app,
44
- inputs=[img_input, dd_models] + params, outputs=[gallery_output, output_file])
45
- btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
46
- inputs=None, outputs=None)
47
-
48
- app.queue().launch(inbrowser=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from modules.paths import OUTPUT_DIR
7
  from modules.utils import open_folder
8
 
9
+
10
+ class App:
11
+ def __init__(self,
12
+ args = None):
13
+ self.app = gr.Blocks()
14
+ self.args = args
15
+ self.sam_inf = SamInference()
16
+
17
+ def launch(self):
18
+ with self.app:
19
+ with gr.Row():
20
+ with gr.Column(scale=5):
21
+ with gr.Tabs() as tabs_sources:
22
+ with gr.TabItem("Image Input"):
23
+ img_input = gr.Image(label="Input image here")
24
+ with gr.TabItem("Video Input"):
25
+ vid_input = gr.Image(label="Input video here")
26
+
27
+ with gr.Column(scale=5):
28
+ dd_models = gr.Dropdown(label="Model", value=DEFAULT_MODEL_TYPE, choices=self.sam_inf.available_models)
29
+
30
+ with gr.Accordion("Mask Parameters", open=False) as mask_hparams:
31
+ nb_points_per_side = gr.Number(label="points_per_side ", value=64, interactive=True)
32
+ nb_points_per_batch = gr.Number(label="points_per_batch ", value=128, interactive=True)
33
+ sld_pred_iou_thresh = gr.Slider(label="pred_iou_thresh ", value=0.7, minimum=0, maximum=1,
34
+ interactive=True)
35
+ sld_stability_score_thresh = gr.Slider(label="stability_score_thresh ", value=0.92, minimum=0,
36
+ maximum=1, interactive=True)
37
+ sld_stability_score_offset = gr.Slider(label="stability_score_offset ", value=0.7, minimum=0,
38
+ maximum=1)
39
+ nb_crop_n_layers = gr.Number(label="crop_n_layers ", value=1)
40
+ sld_box_nms_thresh = gr.Slider(label="box_nms_thresh ", value=0.7, minimum=0,
41
+ maximum=1)
42
+ nb_crop_n_points_downscale_factor = gr.Number(label="crop_n_points_downscale_factor ", value=2)
43
+ nb_min_mask_region_area = gr.Number(label="min_mask_region_area ", value=25)
44
+ cb_use_m2m = gr.Checkbox(label="use_m2m ", value=True)
45
+
46
+ with gr.Row():
47
+ btn_generate = gr.Button("GENERATE", variant="primary")
48
+ with gr.Row():
49
+ gallery_output = gr.Gallery(label="Output images will be shown here")
50
+ with gr.Column():
51
+ output_file = gr.File(label="Generated psd file", scale=9)
52
+ btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
53
+
54
+ sources = [img_input or vid_input]
55
+ model_params = [dd_models]
56
+ auto_mask_hparams = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh,
57
+ sld_stability_score_thresh, sld_stability_score_offset, nb_crop_n_layers,
58
+ sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
59
+ cb_use_m2m]
60
+
61
+ btn_generate.click(fn=self.sam_inf.generate_mask_app,
62
+ inputs=sources + model_params + auto_mask_hparams, outputs=[gallery_output, output_file])
63
+ btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
64
+ inputs=None, outputs=None)
65
+
66
+ self.app.queue().launch(inbrowser=True)
67
+
68
+
69
+ if __name__ == "__main__":
70
+ app = App()
71
+ app.launch()