File size: 4,420 Bytes
84a883f
924419a
6d1f220
fa76f5f
924419a
6c70a52
fa76f5f
4d26ef2
924419a
84a883f
43d87b3
 
 
6f2f547
9b0c3be
924419a
9b0c3be
4e35839
 
9c09a5a
6f2f547
 
9b0c3be
6f2f547
 
 
 
 
 
 
 
924419a
6f2f547
4e35839
6f2f547
 
9c09a5a
6f2f547
84a883f
 
 
 
 
 
 
 
 
43d87b3
4e35839
84a883f
 
924419a
 
 
4d26ef2
84a883f
 
 
 
924419a
84a883f
4d26ef2
 
6f2f547
924419a
9c09a5a
 
 
6f2f547
 
 
 
 
 
 
 
9b0c3be
6f2f547
fd9de0f
4d26ef2
 
 
 
 
 
 
 
 
 
 
3041467
 
4d26ef2
fd9de0f
4d26ef2
 
 
 
3041467
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""functions using machine learning instance model(s)"""
from numpy import array as np_array, uint8, zeros, ndarray

from src import app_logger, MODEL_FOLDER
from src.io.geo_helpers import get_vectorized_raster_as_geojson
from src.io.tms2geotiff import download_extent
from src.prediction_api.sam_onnx import SegmentAnythingONNX
from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_TMS
from src.utilities.type_hints import llist_float, dict_str_int, list_dict, tuple_ndarr_int, PIL_Image

models_dict = {"fastsam": {"instance": None}}


def samexporter_predict(
        bbox: llist_float,
        prompt: list_dict,
        zoom: float,
        model_name: str = "fastsam",
        url_tile: str = DEFAULT_TMS
) -> dict_str_int:
    """
    Return predictions as a geojson from a geo-referenced image using the given input prompt.

    1. if necessary instantiate a segment anything machine learning instance model
    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
    3. get a prediction image from the segment anything instance model using the input prompt
    4. get a geo-referenced geojson from the prediction image

    Args:
        bbox: coordinates bounding box
        prompt: machine learning input prompt
        zoom: Level of detail
        model_name: machine learning model name
        url_tile: server url tile

    Returns:
        dict with the number of predicted masks ("n_predictions") merged with the
        geojson-serialized vectorized raster produced by get_vectorized_raster_as_geojson()
    """
    # Lazily instantiate the model once per process and cache it in the
    # module-level models_dict so later calls reuse the loaded ONNX sessions.
    if models_dict[model_name]["instance"] is None:
        app_logger.info(f"missing instance model {model_name}, instantiating it now!")
        model_instance = SegmentAnythingONNX(
            encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
            decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
        )
        models_dict[model_name]["instance"] = model_instance
    app_logger.debug(f"using a {model_name} instance model...")
    models_instance = models_dict[model_name]["instance"]

    app_logger.info(f'tile_source: {url_tile}!')
    # bbox is a pair of corner points; pt0 supplies north/east and pt1 south/west
    # to download_extent. Assumes each point is ordered (lat, lon) — TODO confirm
    # against the caller that fills bbox.
    pt0, pt1 = bbox
    app_logger.info(f"downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=url_tile)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")

    mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
    app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
    return {
        "n_predictions": n_predictions,
        **get_vectorized_raster_as_geojson(mask, transform)
    }


def get_raster_inference(
        img: "PIL_Image | ndarray", prompt: list_dict, models_instance: SegmentAnythingONNX, model_name: str
     ) -> tuple_ndarr_int:
    """
    Run a SegmentAnythingONNX inference on a raster image and merge the predicted
    masks into a single uint8 raster mask.

    The image is converted to a numpy array, encoded into an embedding by the
    instance model, then the decoder predicts one mask per prompt entry; all
    predicted masks are OR-merged (any positive pixel becomes 255) into one
    2D uint8 array.

    Args:
        img: input PIL Image or numpy ndarray
        prompt: list of prompt dict
        models_instance: SegmentAnythingONNX instance model
        model_name: model name string (used for logging only)

    Returns:
        raster prediction mask, prediction number
    """
    np_img = np_array(img)
    app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
    app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
    try:
        app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
    except Exception as e_shape:
        app_logger.error(f"e_shape:{e_shape}.")
    app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
                    f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
    embedding = models_instance.encode(np_img)
    app_logger.debug(f"embedding created, running predict_masks with prompt {prompt}...")
    # inference_out is assumed to be a 4D array shaped (batch, n_masks, H, W)
    # with a single batch entry — TODO confirm against SegmentAnythingONNX.
    inference_out = models_instance.predict_masks(embedding, prompt)
    len_inference_out = inference_out.shape[1]
    app_logger.info(f"Created {len_inference_out} prediction_masks,"
                    f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")
    mask = zeros((inference_out.shape[2], inference_out.shape[3]), dtype=uint8)
    for n, m in enumerate(inference_out[0, :, :, :]):
        app_logger.debug(f"{n}th of prediction_masks shape {inference_out.shape}"
                         f" => mask shape:{mask.shape}, {mask.dtype}.")
        # any strictly-positive prediction pixel marks the merged mask as foreground
        mask[m > 0.0] = 255
    return mask, len_inference_out