Spaces: Running on Zero

LuJingyi-John committed
Commit 6678b47 · Parent(s): 75fac02

Add Inpaint4Drag application with all components
- .gitignore +28 -0
- app.py +201 -0
- requirements.txt +21 -0
- utils/drag.py +297 -0
- utils/refine_mask.py +168 -0
- utils/ui_utils.py +271 -0
.gitignore
ADDED
@@ -0,0 +1,28 @@
output/
checkpoints/
drag_data/
webpage/

play.py

__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
app.py
ADDED
@@ -0,0 +1,201 @@
import gradio as gr
from utils.ui_utils import *

CANVAS_SIZE = 400
DEFAULT_GEN_SIZE = 512

def create_interface():
    with gr.Blocks() as app:
        # State variables
        state = {
            'canvas_size': gr.Number(value=CANVAS_SIZE, visible=False, precision=0),
            'gen_size': gr.Number(value=DEFAULT_GEN_SIZE, visible=False, precision=0),
            'points_list': gr.State(value=[]),
            'inpaint_mask': gr.State(value=None)
        }

        with gr.Tab(label='Inpaint4Drag'):
            with gr.Row():
                # Draw Region Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">1. Draw Regions</p>""")
                    canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE)
                    with gr.Row():
                        fit_btn = gr.Button("Resize Image")
                        if_sam_box = gr.Checkbox(label='Refine mask (SAM)')

                # Control Points Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">2. Control Points</p>""")
                    input_img = gr.Image(type="numpy", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE, interactive=True)
                    with gr.Row():
                        undo_btn = gr.Button("Undo Point")
                        clear_btn = gr.Button("Clear Points")

                # Results Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Results</p>""")
                    output_img = gr.Image(type="numpy", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE, interactive=False)
                    with gr.Row():
                        run_btn = gr.Button("Inpaint")
                        reset_btn = gr.Button("Reset All")

            # Output Settings
            with gr.Row("Generation Parameters"):
                sam_ks = gr.Slider(minimum=11, maximum=51, value=21, step=2, label='How much to refine mask with SAM', interactive=True)
                inpaint_ks = gr.Slider(minimum=0, maximum=25, value=5, step=1, label='How much to expand inpainting mask', interactive=True)
                output_path = gr.Textbox(value='output/app', label="Output path")

        setup_events(
            components={
                'canvas': canvas,
                'input_img': input_img,
                'output_img': output_img,
                'output_path': output_path,
                'if_sam_box': if_sam_box,
                'sam_ks': sam_ks,
                'inpaint_ks': inpaint_ks,
            },
            state=state,
            buttons={
                'fit': fit_btn,
                'undo': undo_btn,
                'clear': clear_btn,
                'run': run_btn,
                'reset': reset_btn
            }
        )

    return app

def setup_events(components, state, buttons):
    # Reset and clear events
    def setup_reset_events():
        buttons['reset'].click(
            clear_all,
            [state['canvas_size']],
            [components['canvas'], components['input_img'], components['output_img'],
             state['points_list'], components['sam_ks'], components['inpaint_ks'], components['output_path'], state['inpaint_mask']]
        )

        components['canvas'].clear(
            clear_all,
            [state['canvas_size']],
            [components['canvas'], components['input_img'], components['output_img'],
             state['points_list'], components['sam_ks'], components['inpaint_ks'], components['output_path'], state['inpaint_mask']]
        )

    # Image manipulation events
    def setup_image_events():
        buttons['fit'].click(
            clear_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            resize,
            [components['canvas'], state['gen_size'], state['canvas_size']],
            [components['canvas'], components['input_img'], components['output_img']]
        )

    # Canvas interaction events
    def setup_canvas_events():
        components['canvas'].edit(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        components['if_sam_box'].change(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        components['sam_ks'].change(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        components['inpaint_ks'].change(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

    # Input image events
    def setup_input_events():
        components['input_img'].select(
            add_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

    # Point manipulation events
    def setup_point_events():
        buttons['undo'].click(
            undo_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        buttons['clear'].click(
            clear_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

    # Processing events
    def setup_processing_events():
        buttons['run'].click(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        ).then(
            inpaint,
            [components['output_img'], state['inpaint_mask']],
            [components['output_img']]
        )

    # Setup all events
    setup_reset_events()
    setup_image_events()
    setup_canvas_events()
    setup_input_events()
    setup_point_events()
    setup_processing_events()

def main():
    app = create_interface()
    app.queue().launch(share=True, debug=True)

if __name__ == '__main__':
    main()
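Note: every editing action above is wired through the same two-step chain — redraw the drag overlay, then recompute the warp preview. A minimal self-contained sketch of that Gradio 3.x `.then()` chaining pattern (the `slider`/`report` names and the two step functions are made up for illustration, not part of this commit):

import gradio as gr

def step_one(x):
    return f"value = {x}"

def step_two(text):
    return text.upper()

with gr.Blocks() as demo:
    slider = gr.Slider(0, 10, value=5, label="x")  # hypothetical control
    report = gr.Textbox(label="report")
    # .change() runs step_one; .then() chains step_two after it finishes,
    # mirroring how the app chains visualize_user_drag -> preview_out_image.
    slider.change(step_one, [slider], [report]).then(step_two, [report], [report])

demo.launch()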
requirements.txt
ADDED
@@ -0,0 +1,21 @@
# Core ML Libraries
torch
torchvision
transformers
diffusers
accelerate
peft
xformers

# UI and Image Processing
gradio==3.47.1
opencv-python==4.8.0.76
Pillow
numpy

# Evaluation (Optional)
lpips
gdown

# EfficientViT-SAM (Optional)
git+https://github.com/mit-han-lab/efficientvit.git
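Note: the gradio==3.47.1 pin is load-bearing — app.py relies on gr.Image(tool="sketch"), a Gradio 3.x API that was removed in Gradio 4 (superseded by gr.ImageEditor), so bumping the pin would likely break the drawing canvas. The opencv-python pin presumably just freezes a known-good version.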
utils/drag.py
ADDED
@@ -0,0 +1,297 @@
import numpy as np
import cv2
import torch
from typing import Union

def contour_to_points_and_mask(contour: np.ndarray, image_shape: tuple) -> tuple[np.ndarray, np.ndarray]:
    """Convert a contour to a set of points and binary mask.

    This function takes a contour and creates both a binary mask and a list of points
    that lie within the contour. The points are represented in (x, y) coordinates.

    Args:
        contour (np.ndarray): Input contour of shape (N, 2) or (N, 1, 2) where N is
            the number of points. Each point should be in (x, y) format.
        image_shape (tuple): Shape of the output mask as (height, width).

    Returns:
        tuple:
            - np.ndarray: Array of points in (x, y) format with shape (M, 2),
              where M is the number of points inside the contour.
              Returns empty array of shape (0, 2) if contour is empty.
            - np.ndarray: Binary mask of shape image_shape where pixels inside
              the contour are 255 and outside are 0.
    """
    if len(contour) == 0:
        return np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8)

    # Create empty mask and fill the contour in the mask
    mask = np.zeros(image_shape, dtype=np.uint8)
    cv2.drawContours(mask, [contour.reshape(-1, 1, 2)], -1, 255, cv2.FILLED)

    # Get points inside contour (y, x) and convert to (x, y)
    points = np.column_stack(np.where(mask)).astype(np.int32)[:, [1, 0]]

    # Return empty array if no points found
    if len(points) == 0:
        points = np.zeros((0, 2), dtype=np.int32)

    return points, mask

def find_control_points(
    region_points: torch.Tensor,
    source_control_points: torch.Tensor,
    target_control_points: torch.Tensor,
    distance_threshold: float = 1e-6
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find control points that match points within a region.

    This function identifies which control points lie within or very close to
    the specified region points. It matches source control points to region points
    and returns both source and corresponding target control points that satisfy
    the distance threshold criterion.

    Args:
        region_points (torch.Tensor): Points defining a region, shape (N, 2).
            Each point is in (x, y) format.
        source_control_points (torch.Tensor): Source control points, shape (M, 2).
            Each point is in (x, y) format.
        target_control_points (torch.Tensor): Target control points, shape (M, 2).
            Must have same first dimension as source_control_points.
        distance_threshold (float, optional): Maximum distance for a point to be
            considered matching. Defaults to 1e-6.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - Matched source control points, shape (K, 2) where K ≤ M
            - Corresponding target control points, shape (K, 2)
            If no matches found or inputs empty, returns empty tensors of shape (0, 2)
    """
    # Handle empty input cases
    if len(region_points) == 0 or len(source_control_points) == 0:
        return (
            torch.zeros((0, 2), device=source_control_points.device),
            torch.zeros((0, 2), device=target_control_points.device)
        )

    # Calculate pairwise distances between source control points and region points
    distances = torch.cdist(source_control_points, region_points)

    # Find points that are within threshold distance of any region point
    min_distances = distances.min(dim=1)[0]
    matching_indices = min_distances < distance_threshold

    # Return matched pairs of control points
    return source_control_points[matching_indices], target_control_points[matching_indices]

def interpolate_points_with_weighted_directions(
    points: torch.Tensor,
    reference_points: torch.Tensor,
    direction_vectors: torch.Tensor,
    max_reference_points: int = 100,
    num_nearest_neighbors: int = 4,
    eps: float = 1e-6
) -> torch.Tensor:
    """Interpolate points based on weighted directions from nearest reference points.

    This function moves each point by a weighted combination of direction vectors.
    The weights are determined by the inverse distances to the nearest reference points.
    If there are too many reference points, they are subsampled for efficiency.

    Args:
        points (torch.Tensor): Points to interpolate, shape (N, 2) in (x, y) format
        reference_points (torch.Tensor): Reference point locations, shape (M, 2)
        direction_vectors (torch.Tensor): Direction vectors for each reference point,
            shape (M, 2), must match reference_points first dimension
        max_reference_points (int, optional): Maximum number of reference points to use.
            If exceeded, points are subsampled. Defaults to 100.
        num_nearest_neighbors (int, optional): Number of nearest neighbors to consider
            for interpolation. Defaults to 4.
        eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.

    Returns:
        torch.Tensor: Interpolated points with shape (N, 2). If input points or
            references are empty, returns the input points unchanged.
    """
    # Handle empty input cases
    if len(points) == 0 or len(reference_points) == 0:
        return points

    # Handle single reference point case
    if len(reference_points) == 1:
        return points + direction_vectors

    # Subsample reference points if too many
    if len(reference_points) > max_reference_points:
        indices = torch.linspace(0, len(reference_points)-1, max_reference_points).long()
        reference_points = reference_points[indices]
        direction_vectors = direction_vectors[indices]

    # Calculate distances to all reference points
    distances = torch.cdist(points, reference_points)

    # Find k nearest neighbors (k = min(num_nearest_neighbors, num_references))
    k = min(num_nearest_neighbors, len(reference_points))
    topk_distances, neighbor_indices = torch.topk(
        distances,
        k=k,
        dim=1,
        largest=False
    )

    # Calculate weights based on inverse distances
    weights = 1.0 / (topk_distances + eps)
    weights = weights / weights.sum(dim=1, keepdim=True)

    # Get directions for nearest neighbors and compute weighted average
    neighbor_directions = direction_vectors[neighbor_indices]
    weighted_directions = (weights.unsqueeze(-1) * neighbor_directions).sum(dim=1)

    # Apply weighted directions and round to nearest integer
    interpolated_points = (points + weighted_directions).round().float()

    return interpolated_points

def get_points_within_image_bounds(
    points: torch.Tensor,
    image_shape: tuple[int, int]
) -> torch.Tensor:
    """Create a boolean mask for points that lie within image boundaries.

    Identifies which points from the input tensor fall within valid image coordinates.
    Points are assumed to be in (x, y) format, while image_shape is in (height, width) format.

    Args:
        points (torch.Tensor): Points to check, shape (N, 2) in (x, y) format.
            x coordinates correspond to width/columns
            y coordinates correspond to height/rows
        image_shape (tuple[int, int]): Image dimensions as (height, width).

    Returns:
        torch.Tensor: Boolean mask of shape (N,) where True indicates the point
            is within bounds. Returns empty tensor of shape (0,) if input is empty.
    """
    # Handle empty input case
    if len(points) == 0:
        return torch.zeros(0, dtype=torch.bool, device=points.device)

    # Unpack image dimensions
    height, width = image_shape

    # Check both x and y coordinates are within bounds
    x_in_bounds = (points[:, 0] >= 0) & (points[:, 0] < width)
    y_in_bounds = (points[:, 1] >= 0) & (points[:, 1] < height)

    # Combine conditions
    valid_points_mask = x_in_bounds & y_in_bounds

    return valid_points_mask

def bi_warp(
    region_mask: np.ndarray,
    control_points: Union[np.ndarray, torch.Tensor],
    kernel_size: int = 5
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate corresponding source/target points and inpainting mask for masked regions.

    Args:
        region_mask: Binary mask defining regions of interest (2D array with 0s and 1s)
        control_points: Alternating source and target control points. Shape (N*2, 2)
        kernel_size: Controls dilation kernel size. Must be odd number or 0.
            Contour thickness will be (kernel_size-1)*2 (default: 5).
            Set to 0 for no contour drawing and no dilation.

    Returns:
        tuple containing:
            - Source points (M, 2)
            - Target points (M, 2)
            - Inpainting mask combined with target contour mask
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_shape = region_mask.shape

    # Ensure kernel_size is odd or 0
    kernel_size = max(0, kernel_size)
    if kernel_size > 0 and kernel_size % 2 == 0:
        kernel_size += 1

    # 1. Initialize tensors and masks
    control_points = torch.tensor(control_points, dtype=torch.float32, device=device) if not isinstance(control_points, torch.Tensor) else control_points
    source_control_points = control_points[0:-1:2]
    target_control_points = control_points[1::2]

    combined_source_mask = np.zeros(image_shape, dtype=np.uint8)
    combined_target_mask = np.zeros(image_shape, dtype=np.uint8)
    region_mask_binary = np.where(region_mask > 0, 1, 0).astype(np.uint8)
    contour_mask = np.zeros(image_shape, dtype=np.uint8)

    # 2. Process regions
    contours = cv2.findContours(region_mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    all_source_points = []
    all_target_points = []

    for contour in contours:
        if len(contour) == 0:
            continue

        # 3. Get source region points and mask
        source_contour = torch.from_numpy(contour[:, 0, :]).float().to(device)
        source_region_points, source_mask = contour_to_points_and_mask(contour[:, 0, :], image_shape)
        source_mask = (source_mask > 0).astype(np.uint8)

        if len(source_region_points) == 0:
            continue

        source_region_points = torch.from_numpy(source_region_points).float().to(device)

        # 4. Transform points
        source, target = find_control_points(source_region_points, source_control_points, target_control_points)
        if len(source) == 0:
            continue

        directions = target - source
        target_contour = interpolate_points_with_weighted_directions(source_contour, source, directions)
        interpolated_target = interpolate_points_with_weighted_directions(source_region_points, source, directions)

        # 5. Get target region points and mask
        target_region_points, target_mask = contour_to_points_and_mask(target_contour.cpu().int().numpy(), image_shape)
        target_mask = (target_mask > 0).astype(np.uint8)

        if len(target_region_points) == 0:
            continue

        # Draw target contour
        target_contour_np = target_contour.cpu().int().numpy()
        if kernel_size > 0:
            cv2.drawContours(contour_mask, [target_contour_np], -1, 1, kernel_size)

        target_region = torch.from_numpy(target_region_points).float().to(device)

        # 6. Apply reverse transformation
        back_directions = source_region_points - interpolated_target
        interpolated_source = interpolate_points_with_weighted_directions(target_region, interpolated_target, back_directions)

        # 7. Filter valid points
        valid_mask = get_points_within_image_bounds(interpolated_source, image_shape)
        if valid_mask.any():
            all_source_points.append(interpolated_source[valid_mask])
            all_target_points.append(target_region[valid_mask])
            combined_source_mask = np.logical_or(combined_source_mask, source_mask).astype(np.uint8)
            combined_target_mask = np.logical_or(combined_target_mask, target_mask).astype(np.uint8)

    # 8. Handle empty case
    if not all_source_points:
        return np.zeros((0, 2), dtype=np.int32), np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8)

    # 9. Finalize outputs
    final_source = torch.cat(all_source_points).cpu().numpy().astype(np.int32)
    final_target = torch.cat(all_target_points).cpu().numpy().astype(np.int32)

    # Create and combine masks
    inpaint_mask = np.logical_and(combined_source_mask, np.logical_not(combined_target_mask)).astype(np.uint8)
    if kernel_size > 0:
        kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
        inpaint_mask = cv2.dilate(inpaint_mask, kernel)
    final_mask = np.logical_or(inpaint_mask, contour_mask).astype(np.uint8)

    return final_source, final_target, final_mask
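Note: for reference, a minimal headless sketch of how bi_warp is meant to be driven (the square mask, the single drag pair, and the random image are made-up test data, not from the repo):

import numpy as np
from utils.drag import bi_warp

# Made-up test case: a 64x64 square region dragged 16 px to the right.
mask = np.zeros((128, 128), dtype=np.uint8)
mask[32:96, 32:96] = 1

# Control points alternate source, target in (x, y) order.
control_points = np.array([[64, 64], [80, 64]], dtype=np.float32)

source_pts, target_pts, inpaint_mask = bi_warp(mask, control_points, kernel_size=5)

# Copying pixels source -> target performs the warp; inpaint_mask marks the
# uncovered source area (plus a dilated border) left for the inpainter.
image = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
image[target_pts[:, 1], target_pts[:, 0]] = image[source_pts[:, 1], source_pts[:, 0]]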
utils/refine_mask.py
ADDED
@@ -0,0 +1,168 @@
import os
import urllib.request
from typing import Optional

import cv2
import numpy as np
import torch
import torch.nn as nn


def download_model(checkpoint_path: str, model_name: str = "efficientvit_sam_l0.pt") -> str:
    """
    Download the model checkpoint if not found locally.

    Args:
        checkpoint_path: Local path where model should be saved
        model_name: Name of the model file to download

    Returns:
        str: Path to the downloaded checkpoint
    """
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    base_url = "https://huggingface.co/mit-han-lab/efficientvit-sam/resolve/main"
    model_url = f"{base_url}/{model_name}"

    try:
        print(f"Downloading model from {model_url}...")
        urllib.request.urlretrieve(model_url, checkpoint_path)
        print(f"Model successfully downloaded to {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        raise RuntimeError(f"Failed to download model: {str(e)}")


class SamMaskRefiner(nn.Module):
    CHECKPOINT_DIR = 'checkpoints'
    MODEL_CONFIGS = {
        'l0': 'efficientvit_sam_l0.pt',
        'l1': 'efficientvit_sam_l1.pt',
        'l2': 'efficientvit_sam_l2.pt'
    }

    def __init__(self, model_name: str = 'l0') -> None:
        """
        Initialize SAM predictor with specified model version.

        Args:
            model_name: Model version to use ('l0', 'l1', or 'l2'). Defaults to 'l0'.

        Raises:
            ValueError: If invalid model_name is provided
            RuntimeError: If model loading fails after download attempt
        """
        super().__init__()

        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Invalid model_name. Choose from: {list(self.MODEL_CONFIGS.keys())}")

        model_filename = self.MODEL_CONFIGS[model_name]
        checkpoint_path = os.path.join(self.CHECKPOINT_DIR, model_filename)

        try:
            from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
            from efficientvit.sam_model_zoo import create_efficientvit_sam_model
        except ImportError:
            raise ImportError(
                "Failed to import EfficientViT modules. Please ensure the package is installed:\n"
                "pip install git+https://github.com/mit-han-lab/efficientvit.git"
            )

        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found at {checkpoint_path}. Attempting to download...")
            checkpoint_path = download_model(checkpoint_path, model_filename)

        try:
            model_type = f'efficientvit-sam-{model_name}'
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model = create_efficientvit_sam_model(model_type, True, checkpoint_path).eval()
            self.model = self.model.requires_grad_(False).to(device)
            self.predictor = EfficientViTSamPredictor(self.model)
            print(f"\033[92mEfficientViT-SAM model loaded from: {checkpoint_path}\033[0m")
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}")

    def sample_points_from_mask(self, mask: np.ndarray, max_points: int = 128) -> np.ndarray:
        """
        Sample points uniformly from masked regions.

        Args:
            mask: Binary mask array of shape (H, W) with 0-1 values.
            max_points: Maximum number of points to sample.

        Returns:
            np.ndarray: Array of shape (N, 2) containing [x,y] coordinates.
        """
        y_indices, x_indices = np.where(mask > 0.5)
        total_points = len(y_indices)

        if total_points <= max_points:
            return np.stack([x_indices, y_indices], axis=1)

        y_min, y_max = y_indices.min(), y_indices.max()
        x_min, x_max = x_indices.min(), x_indices.max()

        aspect_ratio = (x_max - x_min) / max(y_max - y_min, 1)
        ny = int(np.sqrt(max_points / aspect_ratio))
        nx = int(ny * aspect_ratio)

        x_bins = np.linspace(x_min, x_max + 1, nx + 1, dtype=np.int32)
        y_bins = np.linspace(y_min, y_max + 1, ny + 1, dtype=np.int32)

        x_dig = np.digitize(x_indices, x_bins) - 1
        y_dig = np.digitize(y_indices, y_bins) - 1
        bin_indices = y_dig * nx + x_dig
        unique_bins = np.unique(bin_indices)

        points = []
        for idx in unique_bins:
            bin_y = idx // nx
            bin_x = idx % nx
            mask = (y_dig == bin_y) & (x_dig == bin_x)

            if np.any(mask):
                px = int(np.mean(x_indices[mask]))
                py = int(np.mean(y_indices[mask]))
                points.append([px, py])

        points = np.array(points)

        if len(points) > max_points:
            indices = np.linspace(0, len(points) - 1, max_points, dtype=int)
            points = points[indices]

        return points

    def refine_mask(self, image: np.ndarray, input_mask: np.ndarray, kernel_size: int = 21) -> np.ndarray:
        """
        Refine an input mask using the SAM (Segment Anything Model) model.

        Args:
            image: RGB image, shape (H, W, 3), values in [0, 255]
            input_mask: Binary mask, shape (H, W), values in {0, 1}
            kernel_size: Size of morphological kernel (default: 21)

        Returns:
            Refined binary mask, shape (H, W), values in {0, 1}
        """
        points = self.sample_points_from_mask(input_mask, max_points=128)
        if len(points) == 0:
            return input_mask

        self.predictor.set_image(image)
        masks_pred, _, _ = self.predictor.predict(
            point_coords=points,
            point_labels=np.ones(len(points)),
            multimask_output=False
        )
        sam_mask = masks_pred[0]

        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        expanded_input = cv2.dilate(input_mask.astype(np.uint8), kernel)
        preserved_input = cv2.erode(input_mask.astype(np.uint8), kernel)

        sam_mask = np.logical_and(expanded_input, sam_mask).astype(np.uint8)
        sam_mask = np.logical_or(preserved_input, sam_mask).astype(np.uint8)

        return sam_mask
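Note: a quick sketch of driving the refiner on its own (assumes the EfficientViT package from requirements.txt is installed; photo.png and the rectangular rough mask are hypothetical):

import numpy as np
from PIL import Image
from utils.refine_mask import SamMaskRefiner

refiner = SamMaskRefiner('l0')  # fetches checkpoints/efficientvit_sam_l0.pt on first use

image = np.array(Image.open('photo.png').convert('RGB'))  # hypothetical input
rough = np.zeros(image.shape[:2], dtype=np.uint8)
rough[100:300, 150:400] = 1  # rough rectangular guess around the object

# SAM is prompted with points sampled from the rough mask; the prediction is
# clipped to a dilated copy of the input mask and unioned with an eroded copy.
refined = refiner.refine_mask(image, rough, kernel_size=21)
Image.fromarray(refined * 255).save('refined_mask.png')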
utils/ui_utils.py
ADDED
@@ -0,0 +1,271 @@
import os
import pickle
from time import perf_counter

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler

from utils.drag import bi_warp
from utils.refine_mask import SamMaskRefiner


__all__ = [
    'clear_all', 'resize',
    'visualize_user_drag', 'preview_out_image', 'inpaint',
    'add_point', 'undo_point', 'clear_point',
]

# UI functions
def clear_all(length):
    """Reset UI by clearing all input images and parameters."""
    return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 21, 2, "output/app", None)

def resize(canvas, gen_length, canvas_length):
    """Resize canvas while maintaining aspect ratio."""
    if not canvas:
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3

    image = process_canvas(canvas)[0]
    aspect_ratio = image.shape[1] / image.shape[0]
    is_landscape = aspect_ratio >= 1

    new_dims = (
        (gen_length, round(gen_length / aspect_ratio / 8) * 8) if is_landscape
        else (round(gen_length * aspect_ratio / 8) * 8, gen_length)
    )
    canvas_dims = (
        (canvas_length, round(canvas_length / aspect_ratio)) if is_landscape
        else (round(canvas_length * aspect_ratio), canvas_length)
    )

    return (gr.Image(value=cv2.resize(image, new_dims), width=canvas_dims[0], height=canvas_dims[1]),) * 3

def process_canvas(canvas):
    """Extracts the image (H, W, 3) and the mask (H, W) from a Gradio canvas object."""
    image = canvas["image"].copy()
    mask = np.uint8(canvas["mask"][:, :, 0] > 0).copy()
    return image, mask

# Point manipulation functions
def add_point(canvas, points, sam_ks, if_sam, output_path, evt: gr.SelectData):
    """Add selected point to points list and update image."""
    if canvas is None:
        return None
    points.append(evt.index)
    return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)

def undo_point(canvas, points, sam_ks, if_sam, output_path):
    """Remove last point and update image."""
    if canvas is None:
        return None
    if len(points) > 0:
        points.pop()
    return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)

def clear_point(canvas, points, sam_ks, if_sam, output_path):
    """Clear all points and update image."""
    if canvas is None:
        return None
    points.clear()
    return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)

# Visualization tools
def refine_mask(image, mask, kernel_size):
    """Refine mask using SAM model if available."""
    global sam_refiner
    try:
        if 'sam_refiner' not in globals():
            sam_refiner = SamMaskRefiner()
        return sam_refiner.refine_mask(image, mask, kernel_size)
    except ImportError:
        gr.Warning("EfficientVit not installed. Please install with: pip install git+https://github.com/mit-han-lab/efficientvit.git")
        return mask
    except Exception as e:
        gr.Warning(f"Error refining mask: {str(e)}")
        return mask

def visualize_user_drag(canvas, points, sam_ks, if_sam=False, output_path=None):
    """Visualize control points and motion vectors on the input image.

    Args:
        canvas (dict): Gradio canvas containing image and mask
        points (list): List of (x,y) coordinate pairs for control points
        sam_ks (int): Kernel size for SAM mask refinement
        if_sam (bool): Whether to use SAM refinement on mask
        output_path (str, optional): Directory to save the visualization
    """
    if canvas is None:
        return None

    image, mask = process_canvas(canvas)
    mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask

    # Apply colored mask overlay
    result = image.copy()
    result[mask == 1] = [255, 0, 0]  # Red color
    image = cv2.addWeighted(result, 0.3, image, 0.7, 0)

    # Draw mask outline
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(image, contours, -1, (255, 255, 255), 2)

    # Draw control points and motion vectors
    for idx, point in enumerate(points, 1):
        if idx % 2 == 0:
            cv2.circle(image, tuple(point), 10, (0, 0, 255), -1)  # End point
            cv2.arrowedLine(image, prev_point, point, (255, 255, 255), 4, tipLength=0.5)
        else:
            cv2.circle(image, tuple(point), 10, (255, 0, 0), -1)  # Start point
            prev_point = point

    if output_path:
        os.makedirs(output_path, exist_ok=True)
        Image.fromarray(image).save(os.path.join(output_path, 'user_drag_i4p.png'))
    return image

def preview_out_image(canvas, points, sam_ks, inpaint_ks, if_sam=False, output_path=None):
    """Preview warped image result and generate inpainting mask.

    Args:
        canvas (dict): Gradio canvas containing the input image and mask
        points (list): List of (x,y) coordinate pairs defining source and target positions for warping
        sam_ks (int): Kernel size parameter for SAM mask refinement
        inpaint_ks (int): Kernel size parameter for inpainting mask generation
        if_sam (bool): Whether to use SAM model for mask refinement
        output_path (str, optional): Directory path to save original image and metadata

    Returns:
        tuple:
            - ndarray: Warped image with grid pattern overlay on regions needing inpainting
            - ndarray: Binary mask (255 for inpainting regions, 0 elsewhere)
            Returns (None, None) if the canvas is empty, and (image, None) if fewer
            than 2 control points are provided.
    """
    if canvas is None:
        return None, None

    image, mask = process_canvas(canvas)
    if len(points) < 2:
        return image, None

    # ensure H, W divisible by 8 and longer edge 512
    shapes_valid = all(s % 8 == 0 for s in mask.shape + image.shape[:2])
    size_valid = all(max(x.shape[:2] if len(x.shape) > 2 else x.shape) == 512 for x in (image, mask))
    if not (shapes_valid and size_valid):
        gr.Warning('Click Resize Image Button first.')

    mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask

    if output_path:
        os.makedirs(output_path, exist_ok=True)
        Image.fromarray(image).save(os.path.join(output_path, 'original_image.png'))
        metadata = {'mask': mask, 'points': points}
        with open(os.path.join(output_path, 'meta_data_i4p.pkl'), 'wb') as f:
            pickle.dump(metadata, f)

    handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
    image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]

    # Add grid pattern to highlight inpainting regions
    background = np.ones_like(mask) * 255
    background[::10] = background[:, ::10] = 0
    image = np.where(inpaint_mask[..., np.newaxis] == 1, background[..., np.newaxis], image)

    if output_path:
        Image.fromarray(image).save(os.path.join(output_path, 'preview_image.png'))

    return image, (inpaint_mask * 255).astype(np.uint8)

# Inpaint tools
def setup_pipeline(device='cuda', model_version='v1-5'):
    """Initialize optimized inpainting pipeline with specified model configuration."""
    MODEL_CONFIGS = {
        'v1-5': ('runwayml/stable-diffusion-inpainting', 'latent-consistency/lcm-lora-sdv1-5', 'madebyollin/taesd'),
        'xl': ('diffusers/stable-diffusion-xl-1.0-inpainting-0.1', 'latent-consistency/lcm-lora-sdxl', 'madebyollin/taesdxl')
    }
    model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]

    pipe = AutoPipelineForInpainting.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora()
    pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch.float16)
    pipe = pipe.to(device)

    # Pre-compute prompt embeddings during setup
    if model_version == 'v1-5':
        pipe.cached_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0]
    else:
        pipe.cached_prompt_embeds, pipe.cached_pooled_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0::2]

    return pipe

pipe = setup_pipeline(model_version='v1-5')
pipe.cached_prompt_embeds = pipe.encode_prompt('', 'cuda', 1, False)[0]

def inpaint(image, inpaint_mask):
    """Perform efficient inpainting on masked regions using Stable Diffusion.

    Args:
        image (ndarray): Input RGB image array (warped preview image)
        inpaint_mask (ndarray): Binary mask array where 255 indicates regions to inpaint

    Returns:
        ndarray: Inpainted image with masked regions filled in
    """
    if image is None:
        return None

    if inpaint_mask is None:
        return image

    start = perf_counter()
    pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
    inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0

    # Convert inputs to PIL
    image_pil = Image.fromarray(image)
    inpaint_mask_pil = Image.fromarray(inpaint_mask)

    width, height = inpaint_mask_pil.size
    if width % 8 != 0 or height % 8 != 0:
        width, height = round(width / 8) * 8, round(height / 8) * 8
        image_pil = image_pil.resize((width, height))
        image = np.array(image_pil)
        inpaint_mask_pil = inpaint_mask_pil.resize((width, height), Image.NEAREST)
        inpaint_mask = np.array(inpaint_mask_pil)

    # Common pipeline parameters
    common_params = {
        'image': image_pil,
        'mask_image': inpaint_mask_pil,
        'height': height,
        'width': width,
        'guidance_scale': 1.0,
        'num_inference_steps': 8,
        'strength': inpaint_strength,
        'output_type': 'np'
    }

    # Run pipeline
    if pipe_id == 'v1-5':
        inpainted = pipe(
            prompt_embeds=pipe.cached_prompt_embeds,
            **common_params
        ).images[0]
    else:
        inpainted = pipe(
            prompt_embeds=pipe.cached_prompt_embeds,
            pooled_prompt_embeds=pipe.cached_pooled_prompt_embeds,
            **common_params
        ).images[0]

    # Post-process results
    inpaint_mask = (inpaint_mask[..., np.newaxis] / 255).astype(np.uint8)
    return (inpainted * 255).astype(np.uint8) * inpaint_mask + image * (1 - inpaint_mask)
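Note: the warp and the diffusion fill can also be driven headlessly. A hedged end-to-end sketch (importing utils.ui_utils builds the Stable Diffusion pipeline at module import time, so this assumes a CUDA machine; the file names and the drag pair are made up):

import numpy as np
from PIL import Image

from utils.drag import bi_warp
from utils.ui_utils import inpaint  # loads the SD inpainting pipeline on import

# Hypothetical 512x512 inputs: an RGB image, a binary drag region, one drag pair.
image = np.array(Image.open('input_512.png').convert('RGB'))
mask = (np.array(Image.open('region_512.png').convert('L')) > 0).astype(np.uint8)
points = np.array([[200, 256], [260, 256]], dtype=np.float32)  # source, target

# 1) Warp pixels inside the region and get the hole left behind.
src, dst, hole = bi_warp(mask, points, kernel_size=5)
image[dst[:, 1], dst[:, 0]] = image[src[:, 1], src[:, 0]]

# 2) Let the LCM-accelerated inpainter fill the uncovered area.
result = inpaint(image, (hole * 255).astype(np.uint8))
Image.fromarray(result).save('dragged.png')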