"""functions using machine learning instance model(s)"""
from numpy import array as np_array, uint8, zeros, ndarray
from src import app_logger, MODEL_FOLDER
from src.io.geo_helpers import get_vectorized_raster_as_geojson
from src.io.tms2geotiff import download_extent
from src.prediction_api.sam_onnx import SegmentAnythingONNX
from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_TMS
from src.utilities.type_hints import llist_float, dict_str_int, list_dict, tuple_ndarr_int, PIL_Image
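
# Module-level registry used as a lazy cache: the ONNX instance model is created
# on first use (inside samexporter_predict) and then reused across calls.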
models_dict = {"fastsam": {"instance": None}}


def samexporter_predict(
        bbox: llist_float,
        prompt: list_dict,
        zoom: float,
        model_name: str = "fastsam",
        url_tile: str = DEFAULT_TMS
) -> dict_str_int:
"""
Return predictions as a geojson from a geo-referenced image using the given input prompt.
1. if necessary instantiate a segment anything machine learning instance model
2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
3. get a prediction image from the segment anything instance model using the input prompt
4. get a geo-referenced geojson from the prediction image
Args:
bbox: coordinates bounding box
prompt: machine learning input prompt
zoom: Level of detail
model_name: machine learning model name
url_tile: server url tile
Returns:
Affine transform
"""
    if models_dict[model_name]["instance"] is None:
        app_logger.info(f"missing instance model {model_name}, instantiating it now!")
        model_instance = SegmentAnythingONNX(
            encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
            decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
        )
        models_dict[model_name]["instance"] = model_instance
    app_logger.debug(f"using a {model_name} instance model...")
    models_instance = models_dict[model_name]["instance"]

    app_logger.info(f'tile_source: {url_tile}!')
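    # bbox arrives as two (lat, lon) points: given the download_extent keyword mapping
    # below, pt0 is the (north, east) corner and pt1 the (south, west) corner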
    pt0, pt1 = bbox
    app_logger.info(f"downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=url_tile)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")

    mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
    app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
    return {
        "n_predictions": n_predictions,
        **get_vectorized_raster_as_geojson(mask, transform)
    }


def get_raster_inference(
        img: Union[PIL_Image, ndarray], prompt: list_dict, models_instance: SegmentAnythingONNX, model_name: str
) -> tuple_ndarr_int:
"""
Wrapper for rasterio Affine from_gdal method
Args:
img: input PIL Image
prompt: list of prompt dict
models_instance: SegmentAnythingONNX instance model
model_name: model name string
Returns:
raster prediction mask, prediction number
"""
    np_img = np_array(img)
    app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
    app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
    try:
        app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
    except Exception as e_shape:
        app_logger.error(f"e_shape:{e_shape}.")
    app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
                    f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
    embedding = models_instance.encode(np_img)
    app_logger.debug(f"embedding created, running predict_masks with prompt {prompt}...")
    inference_out = models_instance.predict_masks(embedding, prompt)
    len_inference_out = len(inference_out[0, :, :, :])
    app_logger.info(f"Created {len_inference_out} prediction_masks, "
                    f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")
    mask = zeros((inference_out.shape[2], inference_out.shape[3]), dtype=uint8)
    for n, m in enumerate(inference_out[0, :, :, :]):
        app_logger.debug(f"{n}th of prediction_masks shape {inference_out.shape}"
                         f" => mask shape:{mask.shape}, {mask.dtype}.")
        mask[m > 0.0] = 255
    return mask, len_inference_out
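

if __name__ == "__main__":
    # Illustrative usage sketch. The bbox is a [[north, east], [south, west]] pair of
    # lat/lon values, as implied by the download_extent() keyword mapping above; the
    # point prompt below only shows a plausible shape for the list of dicts consumed
    # by SegmentAnythingONNX.predict_masks(), since the exact schema is defined by
    # that class. Coordinates, zoom and pixel values are made-up example data.
    example_bbox = [[46.1, 9.6], [45.9, 9.3]]
    example_prompt = [{"type": "point", "data": [256, 256], "label": 1}]
    output = samexporter_predict(bbox=example_bbox, prompt=example_prompt, zoom=10)
    app_logger.info(f"n_predictions: {output['n_predictions']}.")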