Spaces:
Runtime error
Runtime error
File size: 4,759 Bytes
cfa5142 d3e66e1 cfa5142 8d52a7d cfa5142 8d52a7d cfa5142 8d52a7d cfa5142 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import os
from datetime import datetime
import numpy as np
from modules.model_downloader import (
AVAILABLE_MODELS, DEFAULT_MODEL_TYPE, OUTPUT_DIR,
is_sam_exist,
download_sam_model_url
)
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
from modules.mask_utils import (
save_psd_with_masks,
create_mask_combined_images,
create_mask_gallery
)
# Maps every supported SAM2 model type to the path of its YAML config file
# inside SAM2_CONFIGS_DIR. Keys are the model-type names used to index this
# table when a model is built.
CONFIGS = {
    model_type: os.path.join(SAM2_CONFIGS_DIR, config_filename)
    for model_type, config_filename in (
        ("sam2_hiera_tiny", "sam2_hiera_t.yaml"),
        ("sam2_hiera_small", "sam2_hiera_s.yaml"),
        ("sam2_hiera_base_plus", "sam2_hiera_b+.yaml"),
        ("sam2_hiera_large", "sam2_hiera_l.yaml"),
    )
}
class SamInference:
    """Manage a SAM2 model plus the predictor / mask generator built on top of it.

    The instance caches the loaded model, the ``SAM2ImagePredictor`` and the
    ``SAM2AutomaticMaskGenerator`` so that repeated calls only rebuild what
    actually changed (model type or mask-generator hyper-parameters).
    """

    def __init__(self,
                 model_dir: str = MODELS_DIR,
                 output_dir: str = OUTPUT_DIR
                 ):
        """Set up paths, device and default hyper-parameters; no model is loaded yet.

        Args:
            model_dir: Directory that holds (or will receive) the checkpoint files.
            output_dir: Base directory for generated results; PSD files are
                written to its ``psd`` sub-directory.
        """
        self.model = None
        self.available_models = list(AVAILABLE_MODELS.keys())
        self.model_type = DEFAULT_MODEL_TYPE
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None
        self.image_predictor = None

        # Tunable Parameters , All default values by https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb
        self.maskgen_hparams = {
            "points_per_side": 64,
            "points_per_batch": 128,
            "pred_iou_thresh": 0.7,
            "stability_score_thresh": 0.92,
            "stability_score_offset": 0.7,
            "crop_n_layers": 1,
            "box_nms_thresh": 0.7,
            "crop_n_points_downscale_factor": 2,
            "min_mask_region_area": 25.0,
            "use_m2m": True,
        }

    def load_model(self):
        """Build the SAM2 model for ``self.model_type``, downloading the checkpoint if missing.

        Any previously loaded model and the predictors derived from it are
        discarded first, so a failed or changed load can never leave the
        instance serving masks from a model that no longer matches
        ``self.model_type``. A build failure is logged and leaves
        ``self.model`` as ``None``; it is not re-raised here.
        """
        config = CONFIGS[self.model_type]
        filename = AVAILABLE_MODELS[self.model_type][0]
        model_path = os.path.join(self.model_dir, filename)

        # Keep the public attribute in sync with the checkpoint actually in use.
        self.model_path = model_path

        # Everything built from the old model is stale from this point on.
        self.model = None
        self.image_predictor = None
        self.mask_generator = None

        if not is_sam_exist(self.model_type):
            print(f"\nLayer Divider Extension : No SAM2 model found, downloading {self.model_type} model...")
            download_sam_model_url(self.model_type)
        print("\nLayer Divider Extension : applying configs to model..")

        try:
            self.model = build_sam2(
                config_file=config,
                ckpt_path=model_path,
                device=self.device
            )
        except Exception as e:
            print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")

    def set_predictors(self):
        """(Re)create the image predictor and the automatic mask generator.

        Loads the model first when none is available. The mask generator is
        configured from ``self.maskgen_hparams``.

        Raises:
            RuntimeError: If the model is still unavailable after a load attempt.
        """
        if self.model is None:
            self.load_model()
        if self.model is None:
            # Fail with a clear message instead of handing None to the SAM2 wrappers.
            raise RuntimeError(
                f"Layer Divider Extension : SAM2 model '{self.model_type}' could not be loaded."
            )

        self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            **self.maskgen_hparams
        )

    def generate_mask(self,
                      image: np.ndarray):
        """Run the automatic mask generator on ``image`` and return its result.

        The predictors are created on first use if they do not exist yet.

        Args:
            image: Input image array passed straight to the mask generator.

        Returns:
            Whatever ``SAM2AutomaticMaskGenerator.generate`` returns for the image.
        """
        if self.mask_generator is None:
            self.set_predictors()
        return self.mask_generator.generate(image)

    def generate_mask_app(self,
                          image: np.ndarray,
                          model_type: str,
                          *params
                          ):
        """Generate masks for ``image``, save them as a PSD and build preview images.

        Args:
            image: Input image array.
            model_type: Key into ``AVAILABLE_MODELS`` / ``CONFIGS`` selecting the model.
            *params: Ten positional mask-generator settings, in this order:
                points_per_side, points_per_batch, pred_iou_thresh,
                stability_score_thresh, stability_score_offset, crop_n_layers,
                box_nms_thresh, crop_n_points_downscale_factor,
                min_mask_region_area, use_m2m.

        Returns:
            A tuple ``(images, output_path)`` where ``images`` is a list holding
            the combined mask image followed by the gallery entries, and
            ``output_path`` is the path of the written PSD file.
        """
        maskgen_hparams = {
            'points_per_side': int(params[0]),
            'points_per_batch': int(params[1]),
            'pred_iou_thresh': float(params[2]),
            'stability_score_thresh': float(params[3]),
            'stability_score_offset': float(params[4]),
            'crop_n_layers': int(params[5]),
            'box_nms_thresh': float(params[6]),
            'crop_n_points_downscale_factor': int(params[7]),
            'min_mask_region_area': int(params[8]),
            'use_m2m': bool(params[9])
        }
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        output_file_name = f"result-{timestamp}.psd"
        output_path = os.path.join(self.output_dir, "psd", output_file_name)
        # Make sure the target directory exists before anything is written to it.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        if self.model is None or self.model_type != model_type:
            self.model_type = model_type
            # load_model() drops the old predictors, which forces the rebuild below
            # even when the hyper-parameters did not change.
            self.load_model()

        if self.mask_generator is None or self.maskgen_hparams != maskgen_hparams:
            self.maskgen_hparams = maskgen_hparams
            self.set_predictors()

        masks = self.mask_generator.generate(image)

        save_psd_with_masks(image, masks, output_path)
        combined_image = create_mask_combined_images(image, masks)
        gallery = create_mask_gallery(image, masks)
        return [combined_image] + gallery, output_path
|