# (extraction residue removed: file-size header, VCS blame-hash gutter, line-number gutter)
# Press the green button in the gutter to run the script.
import json
import tempfile
import numpy as np
from src import app_logger, MODEL_FOLDER
from src.io.geo_helpers import get_vectorized_raster_as_geojson, get_affine_transform_from_gdal
from src.io.tms2geotiff import download_extent
from src.prediction_api.sam_onnx import SegmentAnythingONNX
from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_TMS
from src.utilities.serialize import serialize
# Module-level cache of model instances, keyed by model name; the "instance"
# slot starts as None and is lazily filled on first use (see samexporter_predict).
models_dict = {"fastsam": {"instance": None}}
def samexporter_predict(bbox, prompt: list[dict], zoom: float, model_name: str = "fastsam") -> dict:
    """Download a geo-referenced raster for ``bbox`` at ``zoom``, run SAM-ONNX
    inference driven by ``prompt`` and return the vectorized masks as geojson.

    Args:
        bbox: pair of two points (``pt0``, ``pt1``), each indexable as
            ``pt[0]``/``pt[1]``, describing the extent to download.
        prompt: list of prompt dicts forwarded to the ONNX decoder.
        zoom: tile zoom level used when downloading the extent.
        model_name: key into the module-level ``models_dict`` cache.

    Returns:
        dict with ``n_predictions`` plus the keys produced by
        ``get_vectorized_raster_as_geojson``.

    Raises:
        ImportError: re-raised after logging when an import fails at call time.
        KeyError: if ``model_name`` is not a key of ``models_dict``.
    """
    try:
        # Lazy, cached instantiation: building the ONNX sessions is expensive,
        # so reuse the instance stored in the module-level cache.
        if models_dict[model_name]["instance"] is None:
            app_logger.info(f"missing instance model {model_name}, instantiating it now!")
            model_instance = SegmentAnythingONNX(
                encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
                decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
            )
            models_dict[model_name]["instance"] = model_instance
        app_logger.debug(f"using a {model_name} instance model...")
        models_instance = models_dict[model_name]["instance"]

        app_logger.info(f'tile_source: {DEFAULT_TMS}!')
        pt0, pt1 = bbox
        app_logger.info(f"downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
        img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
        app_logger.info(f"img type {type(img)} with shape/size:{img.size}, matrix:{matrix}.")

        # Debug dump of the geotransform; delete=False keeps it for inspection.
        with tempfile.NamedTemporaryFile(mode='w', prefix="matrix_", delete=False) as temp_f1:
            json.dump({"matrix": serialize(matrix)}, temp_f1)

        transform = get_affine_transform_from_gdal(matrix)
        app_logger.debug(f"transform to consume with rasterio.shapes: {type(transform)}, {transform}.")
        # NOTE(review): ``transform`` is computed but only logged; the raw
        # ``matrix`` is what is passed to get_vectorized_raster_as_geojson
        # below — confirm that is intended and not a leftover.

        mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
        app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
        return {
            "n_predictions": n_predictions,
            **get_vectorized_raster_as_geojson(mask, matrix)
        }
    except ImportError as e_import_module:
        # Previously this swallowed the error and implicitly returned None,
        # breaking the declared ``-> dict`` contract; log and re-raise so the
        # caller sees the missing-dependency failure explicitly.
        app_logger.error(f"Error trying import module:{e_import_module}.")
        raise
def _dump_ndarray(arr, prefix: str) -> None:
    """Persist ``arr`` to a kept temp file (delete=False) for offline debugging.

    ``np.save`` is given the file *name*, so numpy writes a sibling
    ``<name>.npy`` file while the empty placeholder keeps the name reserved.
    """
    with tempfile.NamedTemporaryFile(mode='w', prefix=prefix, delete=False) as temp_f:
        np.save(temp_f.name, arr)


def get_raster_inference(img, prompt, models_instance, model_name):
    """Run SAM-ONNX inference on ``img`` and merge every predicted mask into a
    single binary uint8 raster.

    Args:
        img: input image; anything ``np.array`` accepts (e.g. a PIL image).
        prompt: prompt payload forwarded to ``models_instance.predict_masks``.
        models_instance: object exposing ``encode`` and ``predict_masks``
            (e.g. a SegmentAnythingONNX instance).
        model_name: model identifier, used only for logging here.

    Returns:
        tuple ``(mask, n)`` where ``mask`` is a 2-D uint8 array with 255 at
        any pixel covered by at least one prediction and ``n`` is the number
        of prediction masks.

    Raises:
        Exception: re-raised if any debug dump to a temp file fails.
    """
    np_img = np.array(img)
    app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
    app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
    try:
        app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
    except Exception as e_shape:
        app_logger.error(f"e_shape:{e_shape}.")

    # Debug dump of the raw input image for offline reproduction.
    try:
        _dump_ndarray(np_img, "get_raster_inference__img_")
    except Exception as e_save:
        app_logger.error(f"e_save:{e_save}.")
        raise e_save

    app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
                    f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
    embedding = models_instance.encode(np_img)
    app_logger.debug(f"embedding created, running predict_masks with prompt {prompt}...")
    inference_out = models_instance.predict_masks(embedding, prompt)
    # assumes inference_out is 4-D (batch, n_masks, h, w) — TODO confirm
    len_inference_out = len(inference_out[0, :, :, :])
    app_logger.info(f"Created {len_inference_out} prediction_masks,"
                    f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")

    # Union of all predicted masks: any pixel > 0 in any mask becomes 255.
    mask = np.zeros((inference_out.shape[2], inference_out.shape[3]), dtype=np.uint8)
    for n, m in enumerate(inference_out[0, :, :, :]):
        app_logger.debug(f"{n}th of prediction_masks shape {inference_out.shape}"
                         f" => mask shape:{mask.shape}, {mask.dtype}.")
        mask[m > 0.0] = 255

    # Debug dumps of the merged mask and of the raw model output.
    try:
        _dump_ndarray(mask, "get_raster_inference__mask_")
        _dump_ndarray(inference_out, "get_raster_inference__inference_out_")
    except Exception as e_save1:
        app_logger.error(f"e_save1:{e_save1}.")
        raise e_save1
    return mask, len_inference_out