jiuface committed
Commit: efc40ec
1 Parent(s): 69ffe98

init project

app.py CHANGED
@@ -1,142 +1,419 @@
  import gradio as gr
  import numpy as np
  import random
- #import spaces #[uncomment to use ZeroGPU]
  from diffusers import DiffusionPipeline
  import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
  MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
- #@spaces.GPU #[uncomment to use ZeroGPU]
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
-
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt = prompt,
-         negative_prompt = negative_prompt,
-         guidance_scale = guidance_scale,
-         num_inference_steps = num_inference_steps,
-         width = width,
-         height = height,
-         generator = generator
-     ).images[0]
-
-     return image, seed
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css="""
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""
-         # Text-to-Image Gradio Template
-         """)
-
-         with gr.Row():
-
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0)
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, #Replace with defaults that work for your model
                  )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, #Replace with defaults that work for your model
                  )
-
-             with gr.Row():
-
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
                      step=0.1,
-                     value=0.0, #Replace with defaults that work for your model
                  )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
                      step=1,
-                     value=2, #Replace with defaults that work for your model
                  )
-
-         gr.Examples(
-             examples = examples,
-             inputs = [prompt]
-         )
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn = infer,
-         inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-         outputs = [result, seed]
-     )
-
  demo.queue().launch()
 
+ from typing import Tuple, Optional
+
  import gradio as gr
  import numpy as np
  import random
+ import os  # needed for os.environ below; missing in the original commit
+ import spaces
  from diffusers import DiffusionPipeline
+ from diffusers import FluxInpaintPipeline
+ from diffusers import AutoencoderKL, AutoencoderTiny  # needed for the VAE setup below; missing in the original commit
  import torch
+ import cv2  # needed by process_mask below; missing in the original commit
+ from PIL import Image, ImageFilter
+ from huggingface_hub import login
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
+ import copy
+ import random
+ import time
+ import boto3
+ from io import BytesIO
+ from datetime import datetime
+ from diffusers.utils import load_image
+ import json
+
+ from utils.florence import load_florence_model, run_florence_inference, \
+     FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
+ from utils.sam import load_sam_image_model, run_sam_inference
+
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ login(token=HF_TOKEN)
+
  MAX_SEED = np.iinfo(np.int32).max
+ IMAGE_SIZE = 1024
+
+ # init
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ base_model = "black-forest-labs/FLUX.1-dev"
+
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)  # full-quality VAE (loaded but not wired up yet)
+ pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
+
+
+ class calculateDuration:
+     def __init__(self, activity_name=""):
+         self.activity_name = activity_name
+
+     def __enter__(self):
+         self.start_time = time.time()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.end_time = time.time()
+         self.elapsed_time = self.end_time - self.start_time
+         if self.activity_name:
+             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+         else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+
+ def calculate_image_dimensions_for_flux(
+     original_resolution_wh: Tuple[int, int],
+     maximum_dimension: int = IMAGE_SIZE
+ ) -> Tuple[int, int]:
+     width, height = original_resolution_wh
+
+     if width > height:
+         scaling_factor = maximum_dimension / width
+     else:
+         scaling_factor = maximum_dimension / height
+
+     new_width = int(width * scaling_factor)
+     new_height = int(height * scaling_factor)
+
+     # FLUX expects dimensions that are multiples of 32
+     new_width = new_width - (new_width % 32)
+     new_height = new_height - (new_height % 32)
+
+     return new_width, new_height
+
+
+ def is_mask_empty(image: Image.Image) -> bool:
+     gray_img = image.convert("L")
+     pixels = list(gray_img.getdata())
+     return all(pixel == 0 for pixel in pixels)
+
+
+ def process_mask(
+     mask: Image.Image,
+     mask_inflation: Optional[int] = None,
+     mask_blur: Optional[int] = None
+ ) -> Image.Image:
+     """
+     Inflates and blurs the white regions of a mask.
+     Args:
+         mask (Image.Image): The input mask image.
+         mask_inflation (Optional[int]): The number of pixels to inflate the mask by.
+         mask_blur (Optional[int]): The radius of the Gaussian blur to apply.
+     Returns:
+         Image.Image: The processed mask with inflated and/or blurred regions.
+     """
+     if mask_inflation and mask_inflation > 0:
+         mask_array = np.array(mask)
+         kernel = np.ones((mask_inflation, mask_inflation), np.uint8)
+         mask_array = cv2.dilate(mask_array, kernel, iterations=1)
+         mask = Image.fromarray(mask_array)
+
+     if mask_blur and mask_blur > 0:
+         mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
+
+     return mask
+
+
+ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
+     print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
+     connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
+
+     s3 = boto3.client(
+         's3',
+         endpoint_url=connectionUrl,
+         region_name='auto',
+         aws_access_key_id=access_key,
+         aws_secret_access_key=secret_key
+     )
+
+     current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
+     image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
+     buffer = BytesIO()
+     image.save(buffer, "PNG")
+     buffer.seek(0)
+     s3.upload_fileobj(buffer, bucket_name, image_file)
+     print("upload finish", image_file)
+     return image_file
+
+
+ @spaces.GPU(duration=50)
+ def run_flux(
+     image: Image.Image,
+     mask: Image.Image,
+     prompt: str,
+     lora_path: str,
+     lora_weights: str,
+     lora_scale: float,
+     seed_slicer: int,
+     randomize_seed_checkbox: bool,
+     strength_slider: float,
+     num_inference_steps_slider: int,
+     resolution_wh: Tuple[int, int],
+ ) -> Image.Image:
+     print("Running FLUX...")
+
+     with calculateDuration("load lora"):
+         print("start to load lora", lora_path, lora_weights)
+         pipe.load_lora_weights(lora_path, weight_name=lora_weights)
+
+     width, height = resolution_wh
+     if randomize_seed_checkbox:
+         seed_slicer = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed_slicer)
+
+     return pipe(
+         prompt=prompt,
+         image=image,
+         mask_image=mask,
+         width=width,
+         height=height,
+         strength=strength_slider,
+         generator=generator,
+         num_inference_steps=num_inference_steps_slider,
+         max_sequence_length=256,
+         joint_attention_kwargs={"scale": lora_scale},
+     ).images[0]
+
+
+ @spaces.GPU(duration=50)
+ def generate_mask(image: Image.Image, masking_prompt_text: str) -> Image.Image:
+     # generate mask with Florence-2 & SAM2 (not implemented yet)
+     print("Generating mask...")
+
+     return
+
+
+ def process(
+     image_url: str,
+     inpainting_prompt_text: str,
+     masking_prompt_text: str,
+     mask_inflation_slider: int,
+     mask_blur_slider: int,
+     seed_slicer: int,
+     randomize_seed_checkbox: bool,
+     strength_slider: float,
+     num_inference_steps_slider: int,
+     lora_path: str,
+     lora_weights: str,
+     lora_scale: float,
+     upload_to_r2: bool,
+     account_id: str,
+     access_key: str,
+     secret_key: str,
+     bucket: str
+ ):
+     result = {"status": "false", "message": ""}
+     if not image_url:
+         gr.Info("Please enter an image URL for inpainting.")
+         result["message"] = "invalid image url"
+         return None, None, json.dumps(result)
+
+     if not inpainting_prompt_text:
+         gr.Info("Please enter an inpainting text prompt.")
+         result["message"] = "invalid inpainting prompt"
+         return None, None, json.dumps(result)
+
+     if not masking_prompt_text:
+         gr.Info("Please enter a masking text prompt.")
+         result["message"] = "invalid masking prompt"
+         return None, None, json.dumps(result)
+
+     image = load_image(image_url)
+     mask = generate_mask(image, masking_prompt_text)
+
+     if not image:
+         gr.Info("Please upload an image.")
+         result["message"] = "can not load image"
+         return None, None, json.dumps(result)
+
+     if is_mask_empty(mask):
+         gr.Info("Please draw a mask or enter a masking prompt.")
+         result["message"] = "can not generate mask"
+         return None, None, json.dumps(result)
+
+     # generate
+     width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
+     image = image.resize((width, height), Image.LANCZOS)
+     mask = mask.resize((width, height), Image.LANCZOS)
+     mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
+     image = run_flux(
+         image=image,
+         mask=mask,
+         prompt=inpainting_prompt_text,
+         lora_path=lora_path,
+         lora_scale=lora_scale,
+         lora_weights=lora_weights,
+         seed_slicer=seed_slicer,
+         randomize_seed_checkbox=randomize_seed_checkbox,
+         strength_slider=strength_slider,
+         num_inference_steps_slider=num_inference_steps_slider,
+         resolution_wh=(width, height)
+     )
+     if upload_to_r2:
+         url = upload_image_to_r2(image, account_id, access_key, secret_key, bucket)
+         result = {"status": "success", "url": url}
+     else:
+         result = {"status": "success", "message": "Image generated but not uploaded"}
+
+     return image, mask, json.dumps(result)
+
+
+ with gr.Blocks() as demo:
+
+     with gr.Row():
+         with gr.Column():
+
+             image_url = gr.Text(
+                 label="Image url for inpainting",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter image url for inpainting",
+                 container=False,
+             )
+
+             masking_prompt_text_component = gr.Text(
+                 label="Masking prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter text to generate masking",
+                 container=False,
              )
+
+             inpainting_prompt_text_component = gr.Text(
+                 label="Inpainting prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter text to generate inpainting",
+                 container=False,
              )
+
+             submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
+
+             with gr.Accordion("Lora Settings", open=True):
+                 lora_path = gr.Textbox(
+                     label="Lora model path",
+                     show_label=True,
+                     max_lines=1,
+                     placeholder="Enter your model path",
+                     info="Currently, only LoRA weights hosted on Hugging Face can be loaded properly.",
+                     value="XLabs-AI/flux-RealismLora"
+                 )
+                 lora_weights = gr.Textbox(
+                     label="Lora weights",
+                     show_label=True,
+                     max_lines=1,
+                     placeholder="Enter your lora weights name",
+                     value="lora.safetensors"
+                 )
+                 lora_scale = gr.Slider(
+                     label="Lora scale",
+                     show_label=True,
+                     minimum=0,
+                     maximum=1,
                      step=0.1,
+                     value=0.9,
                  )
+
+             with gr.Accordion("Advanced Settings", open=False):
+
+                 with gr.Row():
+                     mask_inflation_slider_component = gr.Slider(
+                         label="Mask inflation",
+                         info="Adjusts the amount of mask edge expansion before "
+                              "inpainting.",
+                         minimum=0,
+                         maximum=20,
+                         step=1,
+                         value=5,
+                     )
+
+                     mask_blur_slider_component = gr.Slider(
+                         label="Mask blur",
+                         info="Controls the intensity of the Gaussian blur applied to "
+                              "the mask edges.",
+                         minimum=0,
+                         maximum=20,
+                         step=1,
+                         value=5,
+                     )
+
+                 seed_slicer_component = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
                      step=1,
+                     value=42,
                  )
+
+                 randomize_seed_checkbox_component = gr.Checkbox(
+                     label="Randomize seed", value=True)
+
+                 with gr.Row():
+
+                     strength_slider_component = gr.Slider(
+                         label="Strength",
+                         info="Indicates extent to transform the reference `image`. "
+                              "Must be between 0 and 1. `image` is used as a starting "
+                              "point and more noise is added the higher the `strength`.",
+                         minimum=0,
+                         maximum=1,
+                         step=0.01,
+                         value=0.85,
+                     )
+
+                     num_inference_steps_slider_component = gr.Slider(
+                         label="Number of inference steps",
+                         info="The number of denoising steps. More denoising steps "
+                              "usually lead to a higher quality image at the expense "
+                              "of slower inference.",
+                         minimum=1,
+                         maximum=50,
+                         step=1,
+                         value=20,
+                     )
+
+                 upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
+                 account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
+                 access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
+                 secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
+                 bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
+
+         with gr.Column():
+
+             output_image_component = gr.Image(
+                 type='pil', image_mode='RGB', label='Generated image', format="png")
+
+             with gr.Accordion("Debug", open=False):
+                 output_mask_component = gr.Image(
+                     type='pil', image_mode='RGB', label='Input mask', format="png")
+
+                 output_json_component = gr.Textbox()
+
+     submit_button_component.click(
+         fn=process,
+         inputs=[
+             image_url,
+             inpainting_prompt_text_component,
+             masking_prompt_text_component,
+             mask_inflation_slider_component,
+             mask_blur_slider_component,
+             seed_slicer_component,
+             randomize_seed_checkbox_component,
+             strength_slider_component,
+             num_inference_steps_slider_component,
+             lora_path,
+             lora_weights,
+             lora_scale,
+             upload_to_r2,
+             account_id,
+             access_key,
+             secret_key,
+             bucket
+         ],
+         outputs=[
+             output_image_component,
+             output_mask_component,
+             output_json_component
+         ]
+     )
+
  demo.queue().launch()
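Note: `generate_mask` is left as a stub in this commit even though the Florence-2 and SAM2 helpers are already imported at the top of app.py. Below is a minimal sketch of how it could be wired up using only the helpers added in utils/florence.py and utils/sam.py later in this commit; the module-level model names, the single merged-mask strategy, and the `sv.Detections.from_lmm` conversion from the `supervision` package are assumptions, not part of the commit. It also relies on the `np`, `Image`, and `device` names already defined in app.py.

```python
# Sketch only -- not part of this commit. FLORENCE_MODEL, FLORENCE_PROCESSOR and
# SAM_IMAGE_MODEL are hypothetical module-level names, loaded once at startup.
import supervision as sv  # already listed in requirements.txt

FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
SAM_IMAGE_MODEL = load_sam_image_model(device=device)


def generate_mask(image: Image.Image, masking_prompt_text: str) -> Image.Image:
    # 1) Florence-2 open-vocabulary detection: text prompt -> bounding boxes
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=device,
        image=image,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=masking_prompt_text,
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2, result=result, resolution_wh=image.size
    )
    # 2) SAM2: boxes -> per-object segmentation masks
    detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)
    # 3) Merge all object masks into one grayscale mask suitable for inpainting
    if detections.mask is None or len(detections.mask) == 0:
        return Image.new("L", image.size, 0)
    merged = np.any(detections.mask, axis=0).astype(np.uint8) * 255
    return Image.fromarray(merged, mode="L")
```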
configs/__init__.py ADDED
File without changes
configs/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
+ # @package _global_
+
+ # Model
+ model:
+   _target_: sam2.modeling.sam2_base.SAM2Base
+   image_encoder:
+     _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+     scalp: 1
+     trunk:
+       _target_: sam2.modeling.backbones.hieradet.Hiera
+       embed_dim: 112
+       num_heads: 2
+     neck:
+       _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+       position_encoding:
+         _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+         num_pos_feats: 256
+         normalize: true
+         scale: null
+         temperature: 10000
+       d_model: 256
+       backbone_channel_list: [896, 448, 224, 112]
+       fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+       fpn_interp_model: nearest
+
+   memory_attention:
+     _target_: sam2.modeling.memory_attention.MemoryAttention
+     d_model: 256
+     pos_enc_at_input: true
+     layer:
+       _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+       activation: relu
+       dim_feedforward: 2048
+       dropout: 0.1
+       pos_enc_at_attn: false
+       self_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+       d_model: 256
+       pos_enc_at_cross_attn_keys: true
+       pos_enc_at_cross_attn_queries: false
+       cross_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         rope_k_repeat: True
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+         kv_in_dim: 64
+     num_layers: 4
+
+   memory_encoder:
+     _target_: sam2.modeling.memory_encoder.MemoryEncoder
+     out_dim: 64
+     position_encoding:
+       _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+       num_pos_feats: 64
+       normalize: true
+       scale: null
+       temperature: 10000
+     mask_downsampler:
+       _target_: sam2.modeling.memory_encoder.MaskDownSampler
+       kernel_size: 3
+       stride: 2
+       padding: 1
+     fuser:
+       _target_: sam2.modeling.memory_encoder.Fuser
+       layer:
+         _target_: sam2.modeling.memory_encoder.CXBlock
+         dim: 256
+         kernel_size: 7
+         padding: 3
+         layer_scale_init_value: 1e-6
+         use_dwconv: True  # depth-wise convs
+       num_layers: 2
+
+   num_maskmem: 7
+   image_size: 1024
+   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+   sigmoid_scale_for_mem_enc: 20.0
+   sigmoid_bias_for_mem_enc: -10.0
+   use_mask_input_as_output_without_sam: true
+   # Memory
+   directly_add_no_mem_embed: true
+   # use high-resolution feature map in the SAM mask decoder
+   use_high_res_features_in_sam: true
+   # output 3 masks on the first click on initial conditioning frames
+   multimask_output_in_sam: true
+   # SAM heads
+   iou_prediction_use_sigmoid: True
+   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+   use_obj_ptrs_in_encoder: true
+   add_tpos_enc_to_obj_ptrs: false
+   only_obj_ptrs_in_the_past_for_eval: true
+   # object occlusion prediction
+   pred_obj_scores: true
+   pred_obj_scores_mlp: true
+   fixed_no_obj_ptr: true
+   # multimask tracking settings
+   multimask_output_for_tracking: true
+   use_multimask_token_for_obj_ptr: true
+   multimask_min_pt_num: 0
+   multimask_max_pt_num: 1
+   use_mlp_for_obj_ptr_proj: true
+   # Compilation flag
+   compile_image_encoder: False
configs/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
+ # @package _global_
+
+ # Model
+ model:
+   _target_: sam2.modeling.sam2_base.SAM2Base
+   image_encoder:
+     _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+     scalp: 1
+     trunk:
+       _target_: sam2.modeling.backbones.hieradet.Hiera
+       embed_dim: 144
+       num_heads: 2
+       stages: [2, 6, 36, 4]
+       global_att_blocks: [23, 33, 43]
+       window_pos_embed_bkg_spatial_size: [7, 7]
+       window_spec: [8, 4, 16, 8]
+     neck:
+       _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+       position_encoding:
+         _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+         num_pos_feats: 256
+         normalize: true
+         scale: null
+         temperature: 10000
+       d_model: 256
+       backbone_channel_list: [1152, 576, 288, 144]
+       fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+       fpn_interp_model: nearest
+
+   memory_attention:
+     _target_: sam2.modeling.memory_attention.MemoryAttention
+     d_model: 256
+     pos_enc_at_input: true
+     layer:
+       _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+       activation: relu
+       dim_feedforward: 2048
+       dropout: 0.1
+       pos_enc_at_attn: false
+       self_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+       d_model: 256
+       pos_enc_at_cross_attn_keys: true
+       pos_enc_at_cross_attn_queries: false
+       cross_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         rope_k_repeat: True
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+         kv_in_dim: 64
+     num_layers: 4
+
+   memory_encoder:
+     _target_: sam2.modeling.memory_encoder.MemoryEncoder
+     out_dim: 64
+     position_encoding:
+       _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+       num_pos_feats: 64
+       normalize: true
+       scale: null
+       temperature: 10000
+     mask_downsampler:
+       _target_: sam2.modeling.memory_encoder.MaskDownSampler
+       kernel_size: 3
+       stride: 2
+       padding: 1
+     fuser:
+       _target_: sam2.modeling.memory_encoder.Fuser
+       layer:
+         _target_: sam2.modeling.memory_encoder.CXBlock
+         dim: 256
+         kernel_size: 7
+         padding: 3
+         layer_scale_init_value: 1e-6
+         use_dwconv: True  # depth-wise convs
+       num_layers: 2
+
+   num_maskmem: 7
+   image_size: 1024
+   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+   sigmoid_scale_for_mem_enc: 20.0
+   sigmoid_bias_for_mem_enc: -10.0
+   use_mask_input_as_output_without_sam: true
+   # Memory
+   directly_add_no_mem_embed: true
+   # use high-resolution feature map in the SAM mask decoder
+   use_high_res_features_in_sam: true
+   # output 3 masks on the first click on initial conditioning frames
+   multimask_output_in_sam: true
+   # SAM heads
+   iou_prediction_use_sigmoid: True
+   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+   use_obj_ptrs_in_encoder: true
+   add_tpos_enc_to_obj_ptrs: false
+   only_obj_ptrs_in_the_past_for_eval: true
+   # object occlusion prediction
+   pred_obj_scores: true
+   pred_obj_scores_mlp: true
+   fixed_no_obj_ptr: true
+   # multimask tracking settings
+   multimask_output_for_tracking: true
+   use_multimask_token_for_obj_ptr: true
+   multimask_min_pt_num: 0
+   multimask_max_pt_num: 1
+   use_mlp_for_obj_ptr_proj: true
+   # Compilation flag
+   compile_image_encoder: False
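utils/sam.py (added later in this commit) selects this config by name (`SAM_CONFIG = "sam2_hiera_l.yaml"`) and pairs it with `checkpoints/sam2_hiera_large.pt`. A minimal loading sketch, assuming that checkpoint file has already been placed under `checkpoints/` (the commit itself does not download it):

```python
# Sketch: build a SAM2 image predictor from this config, mirroring utils/sam.py below.
# Assumes checkpoints/sam2_hiera_large.pt exists locally.
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam2_model = build_sam2("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt", device=device)
predictor = SAM2ImagePredictor(sam_model=sam2_model)
```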
configs/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
+ # @package _global_
+
+ # Model
+ model:
+   _target_: sam2.modeling.sam2_base.SAM2Base
+   image_encoder:
+     _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+     scalp: 1
+     trunk:
+       _target_: sam2.modeling.backbones.hieradet.Hiera
+       embed_dim: 96
+       num_heads: 1
+       stages: [1, 2, 11, 2]
+       global_att_blocks: [7, 10, 13]
+       window_pos_embed_bkg_spatial_size: [7, 7]
+     neck:
+       _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+       position_encoding:
+         _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+         num_pos_feats: 256
+         normalize: true
+         scale: null
+         temperature: 10000
+       d_model: 256
+       backbone_channel_list: [768, 384, 192, 96]
+       fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+       fpn_interp_model: nearest
+
+   memory_attention:
+     _target_: sam2.modeling.memory_attention.MemoryAttention
+     d_model: 256
+     pos_enc_at_input: true
+     layer:
+       _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+       activation: relu
+       dim_feedforward: 2048
+       dropout: 0.1
+       pos_enc_at_attn: false
+       self_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+       d_model: 256
+       pos_enc_at_cross_attn_keys: true
+       pos_enc_at_cross_attn_queries: false
+       cross_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         rope_k_repeat: True
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+         kv_in_dim: 64
+     num_layers: 4
+
+   memory_encoder:
+     _target_: sam2.modeling.memory_encoder.MemoryEncoder
+     out_dim: 64
+     position_encoding:
+       _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+       num_pos_feats: 64
+       normalize: true
+       scale: null
+       temperature: 10000
+     mask_downsampler:
+       _target_: sam2.modeling.memory_encoder.MaskDownSampler
+       kernel_size: 3
+       stride: 2
+       padding: 1
+     fuser:
+       _target_: sam2.modeling.memory_encoder.Fuser
+       layer:
+         _target_: sam2.modeling.memory_encoder.CXBlock
+         dim: 256
+         kernel_size: 7
+         padding: 3
+         layer_scale_init_value: 1e-6
+         use_dwconv: True  # depth-wise convs
+       num_layers: 2
+
+   num_maskmem: 7
+   image_size: 1024
+   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+   sigmoid_scale_for_mem_enc: 20.0
+   sigmoid_bias_for_mem_enc: -10.0
+   use_mask_input_as_output_without_sam: true
+   # Memory
+   directly_add_no_mem_embed: true
+   # use high-resolution feature map in the SAM mask decoder
+   use_high_res_features_in_sam: true
+   # output 3 masks on the first click on initial conditioning frames
+   multimask_output_in_sam: true
+   # SAM heads
+   iou_prediction_use_sigmoid: True
+   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+   use_obj_ptrs_in_encoder: true
+   add_tpos_enc_to_obj_ptrs: false
+   only_obj_ptrs_in_the_past_for_eval: true
+   # object occlusion prediction
+   pred_obj_scores: true
+   pred_obj_scores_mlp: true
+   fixed_no_obj_ptr: true
+   # multimask tracking settings
+   multimask_output_for_tracking: true
+   use_multimask_token_for_obj_ptr: true
+   multimask_min_pt_num: 0
+   multimask_max_pt_num: 1
+   use_mlp_for_obj_ptr_proj: true
+   # Compilation flag
+   compile_image_encoder: False
configs/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
+ # @package _global_
+
+ # Model
+ model:
+   _target_: sam2.modeling.sam2_base.SAM2Base
+   image_encoder:
+     _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+     scalp: 1
+     trunk:
+       _target_: sam2.modeling.backbones.hieradet.Hiera
+       embed_dim: 96
+       num_heads: 1
+       stages: [1, 2, 7, 2]
+       global_att_blocks: [5, 7, 9]
+       window_pos_embed_bkg_spatial_size: [7, 7]
+     neck:
+       _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+       position_encoding:
+         _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+         num_pos_feats: 256
+         normalize: true
+         scale: null
+         temperature: 10000
+       d_model: 256
+       backbone_channel_list: [768, 384, 192, 96]
+       fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+       fpn_interp_model: nearest
+
+   memory_attention:
+     _target_: sam2.modeling.memory_attention.MemoryAttention
+     d_model: 256
+     pos_enc_at_input: true
+     layer:
+       _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+       activation: relu
+       dim_feedforward: 2048
+       dropout: 0.1
+       pos_enc_at_attn: false
+       self_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+       d_model: 256
+       pos_enc_at_cross_attn_keys: true
+       pos_enc_at_cross_attn_queries: false
+       cross_attention:
+         _target_: sam2.modeling.sam.transformer.RoPEAttention
+         rope_theta: 10000.0
+         feat_sizes: [32, 32]
+         rope_k_repeat: True
+         embedding_dim: 256
+         num_heads: 1
+         downsample_rate: 1
+         dropout: 0.1
+         kv_in_dim: 64
+     num_layers: 4
+
+   memory_encoder:
+     _target_: sam2.modeling.memory_encoder.MemoryEncoder
+     out_dim: 64
+     position_encoding:
+       _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+       num_pos_feats: 64
+       normalize: true
+       scale: null
+       temperature: 10000
+     mask_downsampler:
+       _target_: sam2.modeling.memory_encoder.MaskDownSampler
+       kernel_size: 3
+       stride: 2
+       padding: 1
+     fuser:
+       _target_: sam2.modeling.memory_encoder.Fuser
+       layer:
+         _target_: sam2.modeling.memory_encoder.CXBlock
+         dim: 256
+         kernel_size: 7
+         padding: 3
+         layer_scale_init_value: 1e-6
+         use_dwconv: True  # depth-wise convs
+       num_layers: 2
+
+   num_maskmem: 7
+   image_size: 1024
+   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+   # SAM decoder
+   sigmoid_scale_for_mem_enc: 20.0
+   sigmoid_bias_for_mem_enc: -10.0
+   use_mask_input_as_output_without_sam: true
+   # Memory
+   directly_add_no_mem_embed: true
+   # use high-resolution feature map in the SAM mask decoder
+   use_high_res_features_in_sam: true
+   # output 3 masks on the first click on initial conditioning frames
+   multimask_output_in_sam: true
+   # SAM heads
+   iou_prediction_use_sigmoid: True
+   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+   use_obj_ptrs_in_encoder: true
+   add_tpos_enc_to_obj_ptrs: false
+   only_obj_ptrs_in_the_past_for_eval: true
+   # object occlusion prediction
+   pred_obj_scores: true
+   pred_obj_scores_mlp: true
+   fixed_no_obj_ptr: true
+   # multimask tracking settings
+   multimask_output_for_tracking: true
+   use_multimask_token_for_obj_ptr: true
+   multimask_min_pt_num: 0
+   multimask_max_pt_num: 1
+   use_mlp_for_obj_ptr_proj: true
+   # Compilation flag
+   # HieraT does not currently support compilation, should always be set to False
+   compile_image_encoder: False
requirements.txt CHANGED
@@ -1,6 +1,17 @@
  accelerate
- diffusers
  invisible_watermark
  torch
  transformers
- xformers

  accelerate
  invisible_watermark
  torch
  transformers
+ xformers
+ tqdm
+ einops
+ spaces
+ timm
+ samv2
+ gradio
+ supervision
+ opencv-python
+ pytest
+ requests
+ git+https://github.com/Gothos/diffusers.git@flux-inpaint
+ boto3
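The stock `diffusers` entry is dropped in favor of a pinned fork branch (`flux-inpaint`), which is what app.py relies on for `FluxInpaintPipeline`. A quick sanity check after installing, assuming the fork exposes the class under the usual `diffusers` namespace:

```python
# Verify the pinned diffusers fork provides the FLUX inpainting pipeline used by app.py.
from diffusers import FluxInpaintPipeline  # raises ImportError if an incompatible diffusers build is installed

print(FluxInpaintPipeline.__module__)
```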
utils/__init__.py ADDED
File without changes
utils/florence.py ADDED
@@ -0,0 +1,64 @@
+ import os
+ from typing import Union, Any, Tuple, Dict
+ from unittest.mock import patch
+
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor
+ from transformers.dynamic_module_utils import get_imports
+
+ # FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
+ FLORENCE_CHECKPOINT = "microsoft/Florence-2-large-ft"
+ FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
+ FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
+ FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
+ FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
+
+
+ def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+     """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+     if not str(filename).endswith("/modeling_florence2.py"):
+         return get_imports(filename)
+     imports = get_imports(filename)
+     imports.remove("flash_attn")
+     return imports
+
+
+ def load_florence_model(
+     device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
+ ) -> Tuple[Any, Any]:
+     with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+         model = AutoModelForCausalLM.from_pretrained(
+             checkpoint, trust_remote_code=True).to(device).eval()
+         processor = AutoProcessor.from_pretrained(
+             checkpoint, trust_remote_code=True)
+         return model, processor
+
+
+ def run_florence_inference(
+     model: Any,
+     processor: Any,
+     device: torch.device,
+     image: Image,
+     task: str,
+     text: str = None
+ ) -> Tuple[str, Dict]:
+     if text:
+         prompt = task + text
+     else:
+         prompt = task
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+     print(inputs)
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         num_beams=3
+     )
+     generated_text = processor.batch_decode(
+         generated_ids, skip_special_tokens=False)[0]
+     response = processor.post_process_generation(
+         generated_text, task=task, image_size=image.size)
+     print(generated_text, response)
+     return generated_text, response
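A short usage sketch for the helpers above: open-vocabulary detection driven by a free-text prompt. The image path and prompt are placeholders, not values from this commit.

```python
# Sketch: detect objects matching a text prompt with Florence-2.
import torch
from PIL import Image
from utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, processor = load_florence_model(device=device)
image = Image.open("example.jpg").convert("RGB")  # placeholder path

_, response = run_florence_inference(
    model=model,
    processor=processor,
    device=device,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="the red car",  # placeholder prompt
)
# `response` is a dict keyed by the task token, containing bounding boxes and labels.
print(response)
```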
utils/sam.py ADDED
@@ -0,0 +1,50 @@
+ from typing import Any
+
+ import numpy as np
+ import supervision as sv
+ import torch
+ from PIL import Image
+ from sam2.build_sam import build_sam2, build_sam2_video_predictor
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+ # SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
+ # SAM_CONFIG = "sam2_hiera_s.yaml"
+ SAM_CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
+ SAM_CONFIG = "sam2_hiera_l.yaml"
+
+
+ def load_sam_image_model(
+     device: torch.device,
+     config: str = SAM_CONFIG,
+     checkpoint: str = SAM_CHECKPOINT
+ ) -> SAM2ImagePredictor:
+     model = build_sam2(config, checkpoint, device=device)
+     return SAM2ImagePredictor(sam_model=model)
+
+
+ def load_sam_video_model(
+     device: torch.device,
+     config: str = SAM_CONFIG,
+     checkpoint: str = SAM_CHECKPOINT
+ ) -> Any:
+     return build_sam2_video_predictor(config, checkpoint, device=device)
+
+
+ def run_sam_inference(
+     model: Any,
+     image: Image,
+     detections: sv.Detections
+ ) -> sv.Detections:
+     image = np.array(image.convert("RGB"))
+     model.set_image(image)
+     # from left to right
+     bboxes = detections.xyxy
+     bboxes = sorted(bboxes, key=lambda bbox: bbox[0])
+     mask, score, _ = model.predict(box=bboxes, multimask_output=False)
+
+     # dirty fix; remove this later
+     if len(mask.shape) == 4:
+         mask = np.squeeze(mask)
+
+     detections.mask = mask.astype(bool)
+     return detections
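`run_sam_inference` only needs boxes, so it can also be exercised standalone with a hand-specified `sv.Detections`. A minimal sketch; the box coordinates, image path, and output file name below are placeholders, not part of this commit:

```python
# Sketch: run SAM2 on a manually specified box and save the resulting mask.
import numpy as np
import supervision as sv
import torch
from PIL import Image

from utils.sam import load_sam_image_model, run_sam_inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_model = load_sam_image_model(device=device)

image = Image.open("example.jpg").convert("RGB")  # placeholder path
# One xyxy box covering the region to segment (placeholder coordinates).
detections = sv.Detections(xyxy=np.array([[100.0, 120.0, 480.0, 600.0]]))
detections = run_sam_inference(sam_model, image, detections)

# detections.mask is an (N, H, W) boolean array; flatten it into a single binary mask image.
merged = np.any(detections.mask, axis=0).astype(np.uint8) * 255
Image.fromarray(merged, mode="L").save("mask.png")
```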