LuJingyi-John committed
Commit 6678b47
1 Parent(s): 75fac02

Add Inpaint4Drag application with all components

Files changed (6)
  1. .gitignore +28 -0
  2. app.py +201 -0
  3. requirements.txt +21 -0
  4. utils/drag.py +297 -0
  5. utils/refine_mask.py +168 -0
  6. utils/ui_utils.py +271 -0
.gitignore ADDED
@@ -0,0 +1,28 @@
output/
checkpoints/
drag_data/
webpage/

play.py

__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
app.py ADDED
@@ -0,0 +1,201 @@
import gradio as gr
from utils.ui_utils import *

CANVAS_SIZE = 400
DEFAULT_GEN_SIZE = 512

def create_interface():
    with gr.Blocks() as app:
        # State variables
        state = {
            'canvas_size': gr.Number(value=CANVAS_SIZE, visible=False, precision=0),
            'gen_size': gr.Number(value=DEFAULT_GEN_SIZE, visible=False, precision=0),
            'points_list': gr.State(value=[]),
            'inpaint_mask': gr.State(value=None)
        }

        with gr.Tab(label='Inpaint4Drag'):
            with gr.Row():
                # Draw Region Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">1. Draw Regions</p>""")
                    canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE)
                    with gr.Row():
                        fit_btn = gr.Button("Resize Image")
                        if_sam_box = gr.Checkbox(label='Refine mask (SAM)')

                # Control Points Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">2. Control Points</p>""")
                    input_img = gr.Image(type="numpy", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE, interactive=True)
                    with gr.Row():
                        undo_btn = gr.Button("Undo Point")
                        clear_btn = gr.Button("Clear Points")

                # Results Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Results</p>""")
                    output_img = gr.Image(type="numpy", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE, interactive=False)
                    with gr.Row():
                        run_btn = gr.Button("Inpaint")
                        reset_btn = gr.Button("Reset All")

            # Output Settings
            with gr.Row():  # "Generation Parameters" (gr.Row takes no positional title in Gradio 3.x)
                sam_ks = gr.Slider(minimum=11, maximum=51, value=21, step=2, label='How much to refine mask with SAM', interactive=True)
                inpaint_ks = gr.Slider(minimum=0, maximum=25, value=5, step=1, label='How much to expand inpainting mask', interactive=True)
                output_path = gr.Textbox(value='output/app', label="Output path")

        setup_events(
            components={
                'canvas': canvas,
                'input_img': input_img,
                'output_img': output_img,
                'output_path': output_path,
                'if_sam_box': if_sam_box,
                'sam_ks': sam_ks,
                'inpaint_ks': inpaint_ks,
            },
            state=state,
            buttons={
                'fit': fit_btn,
                'undo': undo_btn,
                'clear': clear_btn,
                'run': run_btn,
                'reset': reset_btn
            }
        )

    return app

def setup_events(components, state, buttons):
    # Reset and clear events
    def setup_reset_events():
        buttons['reset'].click(
            clear_all,
            [state['canvas_size']],
            [components['canvas'], components['input_img'], components['output_img'],
             state['points_list'], components['sam_ks'], components['inpaint_ks'], components['output_path'], state['inpaint_mask']]
        )

        components['canvas'].clear(
            clear_all,
            [state['canvas_size']],
            [components['canvas'], components['input_img'], components['output_img'],
             state['points_list'], components['sam_ks'], components['inpaint_ks'], components['output_path'], state['inpaint_mask']]
        )

    # Image manipulation events
    def setup_image_events():
        buttons['fit'].click(
            clear_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            resize,
            [components['canvas'], state['gen_size'], state['canvas_size']],
            [components['canvas'], components['input_img'], components['output_img']]
        )

    # Canvas interaction events
    def setup_canvas_events():
        components['canvas'].edit(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        components['if_sam_box'].change(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        components['sam_ks'].change(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        components['inpaint_ks'].change(
            visualize_user_drag,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

    # Input image events
    def setup_input_events():
        components['input_img'].select(
            add_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

    # Point manipulation events
    def setup_point_events():
        buttons['undo'].click(
            undo_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

        buttons['clear'].click(
            clear_point,
            [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
            [components['input_img']]
        ).then(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        )

    # Processing events
    def setup_processing_events():
        buttons['run'].click(
            preview_out_image,
            [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
            [components['output_img'], state['inpaint_mask']]
        ).then(
            inpaint,
            [components['output_img'], state['inpaint_mask']],
            [components['output_img']]
        )

    # Setup all events
    setup_reset_events()
    setup_image_events()
    setup_canvas_events()
    setup_input_events()
    setup_point_events()
    setup_processing_events()

def main():
    app = create_interface()
    app.queue().launch(share=True, debug=True)

if __name__ == '__main__':
    main()
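Every interaction in setup_events is wired as a `.click(...).then(...)` (or `.edit(...).then(...)`) chain: the first handler redraws the drag visualization, then the second refreshes the warp preview once the first completes. A minimal, self-contained sketch of that chaining pattern (hypothetical two-step handlers, not part of this commit; same Gradio 3.x API):

import gradio as gr

def step_one(count):
    # Stands in for visualize_user_drag: runs first on each click
    return count + 1

def step_two(count):
    # Stands in for preview_out_image: fires only after step_one completes
    return f"preview #{count}"

with gr.Blocks() as demo:
    clicks = gr.Number(value=0, visible=False, precision=0)  # hidden state, like app.py's state dict
    label = gr.Textbox(label="Preview")
    btn = gr.Button("Update")
    # Each event returns a dependency object whose .then() queues a follow-up handler
    btn.click(step_one, [clicks], [clicks]).then(step_two, [clicks], [label])

if __name__ == "__main__":
    demo.launch()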
requirements.txt ADDED
@@ -0,0 +1,21 @@
# Core ML Libraries
torch
torchvision
transformers
diffusers
accelerate
peft
xformers

# UI and Image Processing
gradio==3.47.1
opencv-python==4.8.0.76
Pillow
numpy

# Evaluation (Optional)
lpips
gdown

# EfficientViT-SAM (Optional)
git+https://github.com/mit-han-lab/efficientvit.git
utils/drag.py ADDED
@@ -0,0 +1,297 @@
import numpy as np
import cv2
import torch
from typing import Union

def contour_to_points_and_mask(contour: np.ndarray, image_shape: tuple) -> tuple[np.ndarray, np.ndarray]:
    """Convert a contour to a set of points and binary mask.

    This function takes a contour and creates both a binary mask and a list of points
    that lie within the contour. The points are represented in (x, y) coordinates.

    Args:
        contour (np.ndarray): Input contour of shape (N, 2) or (N, 1, 2) where N is
            the number of points. Each point should be in (x, y) format.
        image_shape (tuple): Shape of the output mask as (height, width).

    Returns:
        tuple:
            - np.ndarray: Array of points in (x, y) format with shape (M, 2),
              where M is the number of points inside the contour.
              Returns empty array of shape (0, 2) if contour is empty.
            - np.ndarray: Binary mask of shape image_shape where pixels inside
              the contour are 255 and outside are 0.
    """
    if len(contour) == 0:
        return np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8)

    # Create empty mask and fill the contour in the mask
    mask = np.zeros(image_shape, dtype=np.uint8)
    cv2.drawContours(mask, [contour.reshape(-1, 1, 2)], -1, 255, cv2.FILLED)

    # Get points inside contour (y, x) and convert to (x, y)
    points = np.column_stack(np.where(mask)).astype(np.int32)[:, [1, 0]]

    # Return empty array if no points found
    if len(points) == 0:
        points = np.zeros((0, 2), dtype=np.int32)

    return points, mask

def find_control_points(
    region_points: torch.Tensor,
    source_control_points: torch.Tensor,
    target_control_points: torch.Tensor,
    distance_threshold: float = 1e-6
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find control points that match points within a region.

    This function identifies which control points lie within or very close to
    the specified region points. It matches source control points to region points
    and returns both source and corresponding target control points that satisfy
    the distance threshold criterion.

    Args:
        region_points (torch.Tensor): Points defining a region, shape (N, 2).
            Each point is in (x, y) format.
        source_control_points (torch.Tensor): Source control points, shape (M, 2).
            Each point is in (x, y) format.
        target_control_points (torch.Tensor): Target control points, shape (M, 2).
            Must have same first dimension as source_control_points.
        distance_threshold (float, optional): Maximum distance for a point to be
            considered matching. Defaults to 1e-6.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - Matched source control points, shape (K, 2) where K ≤ M
            - Corresponding target control points, shape (K, 2)
            If no matches found or inputs empty, returns empty tensors of shape (0, 2)
    """
    # Handle empty input cases
    if len(region_points) == 0 or len(source_control_points) == 0:
        return (
            torch.zeros((0, 2), device=source_control_points.device),
            torch.zeros((0, 2), device=target_control_points.device)
        )

    # Calculate pairwise distances between source control points and region points
    distances = torch.cdist(source_control_points, region_points)

    # Find points that are within threshold distance of any region point
    min_distances = distances.min(dim=1)[0]
    matching_indices = min_distances < distance_threshold

    # Return matched pairs of control points
    return source_control_points[matching_indices], target_control_points[matching_indices]

def interpolate_points_with_weighted_directions(
    points: torch.Tensor,
    reference_points: torch.Tensor,
    direction_vectors: torch.Tensor,
    max_reference_points: int = 100,
    num_nearest_neighbors: int = 4,
    eps: float = 1e-6
) -> torch.Tensor:
    """Interpolate points based on weighted directions from nearest reference points.

    This function moves each point by a weighted combination of direction vectors.
    The weights are determined by the inverse distances to the nearest reference points.
    If there are too many reference points, they are subsampled for efficiency.

    Args:
        points (torch.Tensor): Points to interpolate, shape (N, 2) in (x, y) format
        reference_points (torch.Tensor): Reference point locations, shape (M, 2)
        direction_vectors (torch.Tensor): Direction vectors for each reference point,
            shape (M, 2), must match reference_points first dimension
        max_reference_points (int, optional): Maximum number of reference points to use.
            If exceeded, points are subsampled. Defaults to 100.
        num_nearest_neighbors (int, optional): Number of nearest neighbors to consider
            for interpolation. Defaults to 4.
        eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.

    Returns:
        torch.Tensor: Interpolated points with shape (N, 2). If input points or
            references are empty, returns the input points unchanged.
    """
    # Handle empty input cases
    if len(points) == 0 or len(reference_points) == 0:
        return points

    # Handle single reference point case
    if len(reference_points) == 1:
        return points + direction_vectors

    # Subsample reference points if too many
    if len(reference_points) > max_reference_points:
        indices = torch.linspace(0, len(reference_points)-1, max_reference_points).long()
        reference_points = reference_points[indices]
        direction_vectors = direction_vectors[indices]

    # Calculate distances to all reference points
    distances = torch.cdist(points, reference_points)

    # Find k nearest neighbors (k = min(num_nearest_neighbors, num_references))
    k = min(num_nearest_neighbors, len(reference_points))
    topk_distances, neighbor_indices = torch.topk(
        distances,
        k=k,
        dim=1,
        largest=False
    )

    # Calculate weights based on inverse distances
    weights = 1.0 / (topk_distances + eps)
    weights = weights / weights.sum(dim=1, keepdim=True)

    # Get directions for nearest neighbors and compute weighted average
    neighbor_directions = direction_vectors[neighbor_indices]
    weighted_directions = (weights.unsqueeze(-1) * neighbor_directions).sum(dim=1)

    # Apply weighted directions and round to nearest integer
    interpolated_points = (points + weighted_directions).round().float()

    return interpolated_points

def get_points_within_image_bounds(
    points: torch.Tensor,
    image_shape: tuple[int, int]
) -> torch.Tensor:
    """Create a boolean mask for points that lie within image boundaries.

    Identifies which points from the input tensor fall within valid image coordinates.
    Points are assumed to be in (x, y) format, while image_shape is in (height, width) format.

    Args:
        points (torch.Tensor): Points to check, shape (N, 2) in (x, y) format.
            x coordinates correspond to width/columns
            y coordinates correspond to height/rows
        image_shape (tuple[int, int]): Image dimensions as (height, width).

    Returns:
        torch.Tensor: Boolean mask of shape (N,) where True indicates the point
            is within bounds. Returns empty tensor of shape (0,) if input is empty.
    """
    # Handle empty input case
    if len(points) == 0:
        return torch.zeros(0, dtype=torch.bool, device=points.device)

    # Unpack image dimensions
    height, width = image_shape

    # Check both x and y coordinates are within bounds
    x_in_bounds = (points[:, 0] >= 0) & (points[:, 0] < width)
    y_in_bounds = (points[:, 1] >= 0) & (points[:, 1] < height)

    # Combine conditions
    valid_points_mask = x_in_bounds & y_in_bounds

    return valid_points_mask

def bi_warp(
    region_mask: np.ndarray,
    control_points: Union[np.ndarray, torch.Tensor],
    kernel_size: int = 5
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate corresponding source/target points and inpainting mask for masked regions.

    Args:
        region_mask: Binary mask defining regions of interest (2D array with 0s and 1s)
        control_points: Alternating source and target control points. Shape (N*2, 2)
        kernel_size: Controls dilation kernel size. Must be odd number or 0.
            Contour thickness will be (kernel_size-1)*2 (default: 5)
            Set to 0 for no contour drawing and no dilation.

    Returns:
        tuple containing:
            - Source points (M, 2)
            - Target points (M, 2)
            - Inpainting mask combined with target contour mask
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_shape = region_mask.shape

    # Ensure kernel_size is odd or 0
    kernel_size = max(0, kernel_size)
    if kernel_size > 0 and kernel_size % 2 == 0:
        kernel_size += 1

    # 1. Initialize tensors and masks
    control_points = torch.tensor(control_points, dtype=torch.float32, device=device) if not isinstance(control_points, torch.Tensor) else control_points
    source_control_points = control_points[0:-1:2]
    target_control_points = control_points[1::2]

    combined_source_mask = np.zeros(image_shape, dtype=np.uint8)
    combined_target_mask = np.zeros(image_shape, dtype=np.uint8)
    region_mask_binary = np.where(region_mask > 0, 1, 0).astype(np.uint8)
    contour_mask = np.zeros(image_shape, dtype=np.uint8)

    # 2. Process regions
    contours = cv2.findContours(region_mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    all_source_points = []
    all_target_points = []

    for contour in contours:
        if len(contour) == 0:
            continue

        # 3. Get source region points and mask
        source_contour = torch.from_numpy(contour[:, 0, :]).float().to(device)
        source_region_points, source_mask = contour_to_points_and_mask(contour[:, 0, :], image_shape)
        source_mask = (source_mask > 0).astype(np.uint8)

        if len(source_region_points) == 0:
            continue

        source_region_points = torch.from_numpy(source_region_points).float().to(device)

        # 4. Transform points
        source, target = find_control_points(source_region_points, source_control_points, target_control_points)
        if len(source) == 0:
            continue

        directions = target - source
        target_contour = interpolate_points_with_weighted_directions(source_contour, source, directions)
        interpolated_target = interpolate_points_with_weighted_directions(source_region_points, source, directions)

        # 5. Get target region points and mask
        target_region_points, target_mask = contour_to_points_and_mask(target_contour.cpu().int().numpy(), image_shape)
        target_mask = (target_mask > 0).astype(np.uint8)

        if len(target_region_points) == 0:
            continue

        # Draw target contour
        target_contour_np = target_contour.cpu().int().numpy()
        if kernel_size > 0:
            cv2.drawContours(contour_mask, [target_contour_np], -1, 1, kernel_size)

        target_region = torch.from_numpy(target_region_points).float().to(device)

        # 6. Apply reverse transformation
        back_directions = source_region_points - interpolated_target
        interpolated_source = interpolate_points_with_weighted_directions(target_region, interpolated_target, back_directions)

        # 7. Filter valid points
        valid_mask = get_points_within_image_bounds(interpolated_source, image_shape)
        if valid_mask.any():
            all_source_points.append(interpolated_source[valid_mask])
            all_target_points.append(target_region[valid_mask])
            combined_source_mask = np.logical_or(combined_source_mask, source_mask).astype(np.uint8)
            combined_target_mask = np.logical_or(combined_target_mask, target_mask).astype(np.uint8)

    # 8. Handle empty case
    if not all_source_points:
        return np.zeros((0, 2), dtype=np.int32), np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8)

    # 9. Finalize outputs
    final_source = torch.cat(all_source_points).cpu().numpy().astype(np.int32)
    final_target = torch.cat(all_target_points).cpu().numpy().astype(np.int32)

    # Create and combine masks
    inpaint_mask = np.logical_and(combined_source_mask, np.logical_not(combined_target_mask)).astype(np.uint8)
    if kernel_size > 0:
        kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
        inpaint_mask = cv2.dilate(inpaint_mask, kernel)
    final_mask = np.logical_or(inpaint_mask, contour_mask).astype(np.uint8)

    return final_source, final_target, final_mask
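bi_warp is the geometric core: for each masked region it finds the drags that touch it, warps the region contour and interior forward, then back-maps target pixels to their sources. A minimal sketch of calling it directly, with a synthetic mask and a single 50-pixel rightward drag (hypothetical inputs, not part of this commit):

import numpy as np
from utils.drag import bi_warp

# Synthetic 512x512 region mask containing one square region
mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 100:200] = 1

# Alternating (source, target) control points in (x, y): drag 50 px right
points = [[150, 150], [200, 150]]

src, dst, inpaint_mask = bi_warp(mask, points, kernel_size=5)
print(src.shape, dst.shape)   # matched (M, 2) pixel correspondences
print(inpaint_mask.sum())     # pixels uncovered by the move, flagged for inpainting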
utils/refine_mask.py ADDED
@@ -0,0 +1,168 @@
import os
import urllib.request
from typing import Optional

import cv2
import numpy as np
import torch
import torch.nn as nn


def download_model(checkpoint_path: str, model_name: str = "efficientvit_sam_l0.pt") -> str:
    """
    Download the model checkpoint if not found locally.

    Args:
        checkpoint_path: Local path where model should be saved
        model_name: Name of the model file to download

    Returns:
        str: Path to the downloaded checkpoint
    """
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    base_url = "https://huggingface.co/mit-han-lab/efficientvit-sam/resolve/main"
    model_url = f"{base_url}/{model_name}"

    try:
        print(f"Downloading model from {model_url}...")
        urllib.request.urlretrieve(model_url, checkpoint_path)
        print(f"Model successfully downloaded to {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        raise RuntimeError(f"Failed to download model: {str(e)}")


class SamMaskRefiner(nn.Module):
    CHECKPOINT_DIR = 'checkpoints'
    MODEL_CONFIGS = {
        'l0': 'efficientvit_sam_l0.pt',
        'l1': 'efficientvit_sam_l1.pt',
        'l2': 'efficientvit_sam_l2.pt'
    }

    def __init__(self, model_name: str = 'l0') -> None:
        """
        Initialize SAM predictor with specified model version.

        Args:
            model_name: Model version to use ('l0', 'l1', or 'l2'). Defaults to 'l0'.

        Raises:
            ValueError: If invalid model_name is provided
            RuntimeError: If model loading fails after download attempt
        """
        super().__init__()

        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Invalid model_name. Choose from: {list(self.MODEL_CONFIGS.keys())}")

        model_filename = self.MODEL_CONFIGS[model_name]
        checkpoint_path = os.path.join(self.CHECKPOINT_DIR, model_filename)

        try:
            from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
            from efficientvit.sam_model_zoo import create_efficientvit_sam_model
        except ImportError:
            raise ImportError(
                "Failed to import EfficientViT modules. Please ensure the package is installed:\n"
                "pip install git+https://github.com/mit-han-lab/efficientvit.git"
            )

        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found at {checkpoint_path}. Attempting to download...")
            checkpoint_path = download_model(checkpoint_path, model_filename)

        try:
            model_type = f'efficientvit-sam-{model_name}'
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model = create_efficientvit_sam_model(model_type, True, checkpoint_path).eval()
            self.model = self.model.requires_grad_(False).to(device)
            self.predictor = EfficientViTSamPredictor(self.model)
            print(f"\033[92mEfficientViT-SAM model loaded from: {checkpoint_path}\033[0m")
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}")

    def sample_points_from_mask(self, mask: np.ndarray, max_points: int = 128) -> np.ndarray:
        """
        Sample points uniformly from masked regions.

        Args:
            mask: Binary mask array of shape (H, W) with 0-1 values.
            max_points: Maximum number of points to sample.

        Returns:
            np.ndarray: Array of shape (N, 2) containing [x,y] coordinates.
        """
        y_indices, x_indices = np.where(mask > 0.5)
        total_points = len(y_indices)

        if total_points <= max_points:
            return np.stack([x_indices, y_indices], axis=1)

        y_min, y_max = y_indices.min(), y_indices.max()
        x_min, x_max = x_indices.min(), x_indices.max()

        aspect_ratio = (x_max - x_min) / max(y_max - y_min, 1)
        ny = int(np.sqrt(max_points / aspect_ratio))
        nx = int(ny * aspect_ratio)

        x_bins = np.linspace(x_min, x_max + 1, nx + 1, dtype=np.int32)
        y_bins = np.linspace(y_min, y_max + 1, ny + 1, dtype=np.int32)

        x_dig = np.digitize(x_indices, x_bins) - 1
        y_dig = np.digitize(y_indices, y_bins) - 1
        bin_indices = y_dig * nx + x_dig
        unique_bins = np.unique(bin_indices)

        points = []
        for idx in unique_bins:
            bin_y = idx // nx
            bin_x = idx % nx
            in_bin = (y_dig == bin_y) & (x_dig == bin_x)  # renamed from `mask` to avoid shadowing the argument

            if np.any(in_bin):
                px = int(np.mean(x_indices[in_bin]))
                py = int(np.mean(y_indices[in_bin]))
                points.append([px, py])

        points = np.array(points)

        if len(points) > max_points:
            indices = np.linspace(0, len(points) - 1, max_points, dtype=int)
            points = points[indices]

        return points

    def refine_mask(self, image: np.ndarray, input_mask: np.ndarray, kernel_size: int = 21) -> np.ndarray:
        """
        Refine an input mask using the SAM (Segment Anything Model) model.

        Args:
            image: RGB image, shape (H, W, 3), values in [0, 255]
            input_mask: Binary mask, shape (H, W), values in {0, 1}
            kernel_size: Size of morphological kernel (default: 21)

        Returns:
            Refined binary mask, shape (H, W), values in {0, 1}
        """
        points = self.sample_points_from_mask(input_mask, max_points=128)
        if len(points) == 0:
            return input_mask

        self.predictor.set_image(image)
        masks_pred, _, _ = self.predictor.predict(
            point_coords=points,
            point_labels=np.ones(len(points)),
            multimask_output=False
        )
        sam_mask = masks_pred[0]

        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        expanded_input = cv2.dilate(input_mask.astype(np.uint8), kernel)
        preserved_input = cv2.erode(input_mask.astype(np.uint8), kernel)

        sam_mask = np.logical_and(expanded_input, sam_mask).astype(np.uint8)
        sam_mask = np.logical_or(preserved_input, sam_mask).astype(np.uint8)

        return sam_mask
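A minimal standalone sketch of using the refiner, assuming the optional EfficientViT dependency from requirements.txt is installed (synthetic inputs, not part of this commit; the checkpoint is fetched automatically on first use):

import numpy as np
from utils.refine_mask import SamMaskRefiner

# Downloads checkpoints/efficientvit_sam_l0.pt if it is missing
refiner = SamMaskRefiner(model_name='l0')

image = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)  # placeholder RGB image
rough = np.zeros((512, 512), dtype=np.uint8)                  # rough binary {0, 1} mask
rough[100:200, 100:200] = 1

# Output stays binary; SAM's prediction is clipped to the dilated input
# and the eroded input is always preserved
refined = refiner.refine_mask(image, rough, kernel_size=21)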
utils/ui_utils.py ADDED
@@ -0,0 +1,271 @@
import os
import pickle
from time import perf_counter

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler

from utils.drag import bi_warp
from utils.refine_mask import SamMaskRefiner


__all__ = [
    'clear_all', 'resize',
    'visualize_user_drag', 'preview_out_image', 'inpaint',
    'add_point', 'undo_point', 'clear_point',
]

# UI functions
def clear_all(length):
    """Reset UI by clearing all input images and parameters."""
    return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 21, 2, "output/app", None)

def resize(canvas, gen_length, canvas_length):
    """Resize canvas while maintaining aspect ratio."""
    if not canvas:
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3

    image = process_canvas(canvas)[0]
    aspect_ratio = image.shape[1] / image.shape[0]
    is_landscape = aspect_ratio >= 1

    new_dims = (
        (gen_length, round(gen_length / aspect_ratio / 8) * 8) if is_landscape
        else (round(gen_length * aspect_ratio / 8) * 8, gen_length)
    )
    canvas_dims = (
        (canvas_length, round(canvas_length / aspect_ratio)) if is_landscape
        else (round(canvas_length * aspect_ratio), canvas_length)
    )

    return (gr.Image(value=cv2.resize(image, new_dims), width=canvas_dims[0], height=canvas_dims[1]),) * 3

def process_canvas(canvas):
    """Extracts the image (H, W, 3) and the mask (H, W) from a Gradio canvas object."""
    image = canvas["image"].copy()
    mask = np.uint8(canvas["mask"][:, :, 0] > 0).copy()
    return image, mask

# Point manipulation functions
def add_point(canvas, points, sam_ks, if_sam, output_path, evt: gr.SelectData):
    """Add selected point to points list and update image."""
    if canvas is None:
        return None
    points.append(evt.index)
    return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)

def undo_point(canvas, points, sam_ks, if_sam, output_path):
    """Remove last point and update image."""
    if canvas is None:
        return None
    if len(points) > 0:
        points.pop()
    return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)

def clear_point(canvas, points, sam_ks, if_sam, output_path):
    """Clear all points and update image."""
    if canvas is None:
        return None
    points.clear()
    return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)

# Visualization tools
def refine_mask(image, mask, kernel_size):
    """Refine mask using SAM model if available."""
    global sam_refiner
    try:
        if 'sam_refiner' not in globals():
            sam_refiner = SamMaskRefiner()
        return sam_refiner.refine_mask(image, mask, kernel_size)
    except ImportError:
        gr.Warning("EfficientVit not installed. Please install with: pip install git+https://github.com/mit-han-lab/efficientvit.git")
        return mask
    except Exception as e:
        gr.Warning(f"Error refining mask: {str(e)}")
        return mask

def visualize_user_drag(canvas, points, sam_ks, if_sam=False, output_path=None):
    """Visualize control points and motion vectors on the input image.

    Args:
        canvas (dict): Gradio canvas containing image and mask
        points (list): List of (x,y) coordinate pairs for control points
        sam_ks (int): Kernel size for SAM mask refinement
        if_sam (bool): Whether to use SAM refinement on mask
    """
    if canvas is None:
        return None

    image, mask = process_canvas(canvas)
    mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask

    # Apply colored mask overlay
    result = image.copy()
    result[mask == 1] = [255, 0, 0]  # Red color
    image = cv2.addWeighted(result, 0.3, image, 0.7, 0)

    # Draw mask outline
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(image, contours, -1, (255, 255, 255), 2)

    # Draw control points and motion vectors
    for idx, point in enumerate(points, 1):
        if idx % 2 == 0:
            cv2.circle(image, tuple(point), 10, (0, 0, 255), -1)  # End point
            cv2.arrowedLine(image, tuple(prev_point), tuple(point), (255, 255, 255), 4, tipLength=0.5)
        else:
            cv2.circle(image, tuple(point), 10, (255, 0, 0), -1)  # Start point
            prev_point = point

    if output_path:
        os.makedirs(output_path, exist_ok=True)
        Image.fromarray(image).save(os.path.join(output_path, 'user_drag_i4p.png'))
    return image

def preview_out_image(canvas, points, sam_ks, inpaint_ks, if_sam=False, output_path=None):
    """Preview warped image result and generate inpainting mask.

    Args:
        canvas (dict): Gradio canvas containing the input image and mask
        points (list): List of (x,y) coordinate pairs defining source and target positions for warping
        sam_ks (int): Kernel size parameter for SAM mask refinement
        inpaint_ks (int): Kernel size parameter for inpainting mask generation
        if_sam (bool): Whether to use SAM model for mask refinement
        output_path (str, optional): Directory path to save original image and metadata

    Returns:
        tuple:
            - ndarray: Warped image with grid pattern overlay on regions needing inpainting
            - ndarray: Binary mask (255 for inpainting regions, 0 elsewhere)
            Returns (None, None) if the canvas is empty, and (image, None) if fewer
            than 2 control points are provided.
    """
    if canvas is None:
        return None, None

    image, mask = process_canvas(canvas)
    if len(points) < 2:
        return image, None

    # Ensure H, W are divisible by 8 and the longer edge equals 512
    shapes_valid = all(s % 8 == 0 for s in mask.shape + image.shape[:2])
    size_valid = all(max(x.shape[:2] if len(x.shape) > 2 else x.shape) == 512 for x in (image, mask))
    if not (shapes_valid and size_valid):
        gr.Warning('Click Resize Image Button first.')

    mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask

    if output_path:
        os.makedirs(output_path, exist_ok=True)
        Image.fromarray(image).save(os.path.join(output_path, 'original_image.png'))
        metadata = {'mask': mask, 'points': points}
        with open(os.path.join(output_path, 'meta_data_i4p.pkl'), 'wb') as f:
            pickle.dump(metadata, f)

    handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
    image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]

    # Add grid pattern to highlight inpainting regions
    background = np.ones_like(mask) * 255
    background[::10] = background[:, ::10] = 0
    image = np.where(inpaint_mask[..., np.newaxis] == 1, background[..., np.newaxis], image)

    if output_path:
        Image.fromarray(image).save(os.path.join(output_path, 'preview_image.png'))

    return image, (inpaint_mask * 255).astype(np.uint8)

# Inpaint tools
def setup_pipeline(device='cuda', model_version='v1-5'):
    """Initialize optimized inpainting pipeline with specified model configuration."""
    MODEL_CONFIGS = {
        'v1-5': ('runwayml/stable-diffusion-inpainting', 'latent-consistency/lcm-lora-sdv1-5', 'madebyollin/taesd'),
        'xl': ('diffusers/stable-diffusion-xl-1.0-inpainting-0.1', 'latent-consistency/lcm-lora-sdxl', 'madebyollin/taesdxl')
    }
    model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]

    pipe = AutoPipelineForInpainting.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora()
    pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch.float16)
    pipe = pipe.to(device)

    # Pre-compute prompt embeddings during setup
    if model_version == 'v1-5':
        pipe.cached_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0]
    else:
        pipe.cached_prompt_embeds, pipe.cached_pooled_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0::2]

    return pipe

pipe = setup_pipeline(model_version='v1-5')
pipe.cached_prompt_embeds = pipe.encode_prompt('', 'cuda', 1, False)[0]  # re-caches the embeddings; redundant with the setup step above

def inpaint(image, inpaint_mask):
    """Perform efficient inpainting on masked regions using Stable Diffusion.

    Args:
        image (ndarray): Input RGB image array (warped preview image)
        inpaint_mask (ndarray): Binary mask array where 255 indicates regions to inpaint

    Returns:
        ndarray: Inpainted image with masked regions filled in
    """
    if image is None:
        return None

    if inpaint_mask is None:
        return image

    start = perf_counter()
    pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
    inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0

    # Convert inputs to PIL
    image_pil = Image.fromarray(image)
    inpaint_mask_pil = Image.fromarray(inpaint_mask)

    width, height = inpaint_mask_pil.size
    if width % 8 != 0 or height % 8 != 0:
        width, height = round(width / 8) * 8, round(height / 8) * 8
        image_pil = image_pil.resize((width, height))
        image = np.array(image_pil)
        inpaint_mask_pil = inpaint_mask_pil.resize((width, height), Image.NEAREST)
        inpaint_mask = np.array(inpaint_mask_pil)

    # Common pipeline parameters
    common_params = {
        'image': image_pil,
        'mask_image': inpaint_mask_pil,
        'height': height,
        'width': width,
        'guidance_scale': 1.0,
        'num_inference_steps': 8,
        'strength': inpaint_strength,
        'output_type': 'np'
    }

    # Run pipeline
    if pipe_id == 'v1-5':
        inpainted = pipe(
            prompt_embeds=pipe.cached_prompt_embeds,
            **common_params
        ).images[0]
    else:
        inpainted = pipe(
            prompt_embeds=pipe.cached_prompt_embeds,
            pooled_prompt_embeds=pipe.cached_pooled_prompt_embeds,
            **common_params
        ).images[0]

    # Post-process results
    inpaint_mask = (inpaint_mask[..., np.newaxis] / 255).astype(np.uint8)
    return (inpainted * 255).astype(np.uint8) * inpaint_mask + image * (1 - inpaint_mask)
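Together with utils/drag.py, the run button's pipeline reduces to warp-then-inpaint. A minimal headless sketch of that flow (hypothetical inputs, not part of this commit; note that importing utils.ui_utils builds the Stable Diffusion pipeline at module load, so a CUDA device and downloaded weights are assumed):

import numpy as np
from utils.drag import bi_warp
from utils.ui_utils import inpaint  # module import loads the SD inpainting pipe (CUDA assumed)

# Hypothetical 512x512 inputs: an RGB image, a drawn region mask, one drag pair
image = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 100:200] = 1
points = [[150, 150], [200, 150]]

# 1. Bidirectional warp: copy source pixels to their target positions,
#    exactly as preview_out_image does
src, dst, inpaint_mask = bi_warp(mask, points, kernel_size=5)
image[dst[:, 1], dst[:, 0]] = image[src[:, 1], src[:, 0]]

# 2. Inpaint the regions the warp left uncovered (mask scaled to 0/255)
result = inpaint(image, (inpaint_mask * 255).astype(np.uint8))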