liuyizhang
commited on
Commit
•
5247a47
1
Parent(s):
68cab41
update app.py
Browse files
app.py
CHANGED
@@ -242,7 +242,8 @@ groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
|
242 |
|
243 |
# initialize SAM
|
244 |
logger.info(f"initialize SAM model...")
|
245 |
-
|
|
|
246 |
sam_predictor = SamPredictor(sam_model)
|
247 |
sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
|
248 |
|
@@ -558,7 +559,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
558 |
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
559 |
boxes_filt[i][2:] += boxes_filt[i][:2]
|
560 |
|
561 |
-
boxes_filt = boxes_filt.
|
562 |
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
|
563 |
|
564 |
masks, _, _, _ = sam_predictor.predict_torch(
|
|
|
242 |
|
243 |
# initialize SAM
|
244 |
logger.info(f"initialize SAM model...")
|
245 |
+
sam_device = device
|
246 |
+
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
|
247 |
sam_predictor = SamPredictor(sam_model)
|
248 |
sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
|
249 |
|
|
|
559 |
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
560 |
boxes_filt[i][2:] += boxes_filt[i][:2]
|
561 |
|
562 |
+
boxes_filt = boxes_filt.to(sam_device)
|
563 |
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
|
564 |
|
565 |
masks, _, _, _ = sam_predictor.predict_torch(
|