File size: 4,401 Bytes
fa76f5f
3041467
 
 
6d1f220
 
fa76f5f
6c70a52
 
fa76f5f
4d26ef2
3041467
2b9a42c
fa76f5f
43d87b3
 
 
5b88544
fa76f5f
43d87b3
5b88544
43d87b3
2f5b9e0
 
 
43d87b3
5b88544
43d87b3
 
6c70a52
 
 
 
 
5b88544
3041467
 
 
6c70a52
3041467
4d26ef2
6c70a52
 
 
 
3041467
6c70a52
4d26ef2
 
 
 
 
 
 
 
 
 
 
 
3041467
 
 
 
 
 
4d26ef2
 
 
 
 
3041467
 
4d26ef2
 
 
 
 
 
3041467
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Segmentation inference helpers: fetch a geo-referenced raster extent from a tile server, run SAM ONNX on it and vectorize the resulting mask as GeoJSON.
import json
import tempfile

import numpy as np

from src import app_logger, MODEL_FOLDER
from src.io.geo_helpers import get_vectorized_raster_as_geojson, get_affine_transform_from_gdal
from src.io.tms2geotiff import download_extent
from src.prediction_api.sam_onnx import SegmentAnythingONNX
from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_TMS
from src.utilities.serialize import serialize


# Cache of model instances keyed by model name. Each entry's "instance" starts as None and is
# filled lazily by samexporter_predict the first time that model is requested.
models_dict = {"fastsam": dict(instance=None)}


def samexporter_predict(bbox, prompt: list[dict], zoom: float, model_name: str = "fastsam") -> dict:
    """Download a raster extent, run SAM ONNX inference on it and return the mask as GeoJSON.

    The model for ``model_name`` is instantiated lazily on first use and cached in ``models_dict``.

    Args:
        bbox: pair of corner points ``(pt0, pt1)``; ``pt0[0], pt0[1], pt1[0], pt1[1]`` are passed
            positionally to ``download_extent`` (coordinate order / CRS is not visible here —
            TODO confirm with caller).
        prompt: prompt list forwarded unchanged to the model's ``predict_masks``.
        zoom: zoom level forwarded to ``download_extent``.
        model_name: key of ``models_dict``; only ``"fastsam"`` is registered in this module.

    Returns:
        A dict with ``"n_predictions"`` merged with the keys returned by
        ``get_vectorized_raster_as_geojson(mask, matrix)``, or ``None`` when an ``ImportError``
        is raised during processing (it is logged and swallowed, not propagated).

    Raises:
        KeyError: if ``model_name`` is not a key of ``models_dict``.
    """
    try:
        if model_name not in models_dict:
            raise KeyError(f"unsupported model_name {model_name!r}, expected one of {list(models_dict)}")
        if models_dict[model_name]["instance"] is None:
            app_logger.info(f"missing instance model {model_name}, instantiating it now!")
            models_dict[model_name]["instance"] = SegmentAnythingONNX(
                encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
                decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
            )
        app_logger.debug(f"using a {model_name} instance model...")
        models_instance = models_dict[model_name]["instance"]

        app_logger.info(f'tile_source: {DEFAULT_TMS}!')
        pt0, pt1 = bbox
        app_logger.info(f"downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
        img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
        app_logger.info(f"img type {type(img)} with shape/size:{img.size}, matrix:{matrix}.")

        # Debug dump of the geo-transform matrix. delete=False keeps the file after close, so its
        # path is logged; otherwise each call leaves behind a temp file nobody can find.
        with tempfile.NamedTemporaryFile(mode='w', prefix="matrix_", suffix=".json", delete=False) as temp_f1:
            json.dump({"matrix": serialize(matrix)}, temp_f1)
        app_logger.debug(f"saved matrix to {temp_f1.name}.")

        transform = get_affine_transform_from_gdal(matrix)
        app_logger.debug(f"transform to consume with rasterio.shapes: {type(transform)}, {transform}.")

        mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
        app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
        return {
            "n_predictions": n_predictions,
            **get_vectorized_raster_as_geojson(mask, matrix)
        }
    except ImportError as e_import_module:
        app_logger.error(f"Error trying import module:{e_import_module}.")
        # NOTE(review): the ImportError is swallowed, so callers receive None rather than a dict;
        # consider re-raising once callers are confirmed not to rely on the None result.
        return None


def _save_debug_array(array: np.ndarray, prefix: str, error_label: str) -> str:
    """Dump ``array`` to a new ``.npy`` file in the temp directory and return its path.

    The file is opened in binary mode with an explicit ``.npy`` suffix and written through the
    open handle. Passing a suffix-less file *name* to ``np.save`` makes NumPy append ``.npy`` and
    write a second file, leaving the temp file itself behind empty. Files are created with
    ``delete=False`` and are not removed here.

    Raises:
        Exception: whatever the write raised, after logging it as ``"<error_label>:<exc>."``.
    """
    try:
        with tempfile.NamedTemporaryFile(mode="wb", prefix=prefix, suffix=".npy", delete=False) as temp_f:
            np.save(temp_f, array)
    except Exception as e_save:
        app_logger.error(f"{error_label}:{e_save}.")
        raise
    app_logger.debug(f"saved debug array {prefix}* to {temp_f.name}.")
    return temp_f.name


def get_raster_inference(img, prompt, models_instance, model_name):
    """Run SAM ONNX inference on an image and merge all predicted masks into one binary raster.

    Args:
        img: input image, converted with ``np.array`` (presumably the PIL image returned by
            ``download_extent`` — confirm with caller).
        prompt: prompt payload passed unchanged to ``models_instance.predict_masks``.
        models_instance: object exposing ``encode(np_img)`` and ``predict_masks(embedding, prompt)``;
            the latter must return a 4-D array.
        model_name: model name, used only in log messages.

    Returns:
        ``(mask, n_predictions)``. ``mask`` is a ``uint8`` array shaped like the last two axes of the
        prediction, set to 255 where any predicted mask is ``> 0.0`` and 0 elsewhere.
        ``n_predictions`` is the number of masks along axis 1 of the prediction's first entry.

    Raises:
        Exception: any error writing the debug ``.npy`` dumps is logged and re-raised.
    """
    np_img = np.array(img)
    app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
    app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
    # np.array() always returns an ndarray, so .shape cannot fail: no try/except needed.
    app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
    _save_debug_array(np_img, prefix="get_raster_inference__img_", error_label="e_save")

    app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
                    f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
    embedding = models_instance.encode(np_img)
    app_logger.debug(f"embedding created, running predict_masks with prompt {prompt}...")
    inference_out = models_instance.predict_masks(embedding, prompt)

    # The prediction is indexed as 4-D; only its first entry along axis 0 is used, and each item
    # along the next axis is one mask of shape (H, W).
    predictions = inference_out[0, :, :, :]
    n_predictions = len(predictions)
    app_logger.info(f"Created {n_predictions} prediction_masks, "
                    f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")

    # Union of all masks in one vectorized pass (an empty mask stack yields an all-zero mask).
    mask = np.zeros(predictions.shape[1:], dtype=np.uint8)
    mask[(predictions > 0.0).any(axis=0)] = 255
    app_logger.debug(f"merged {n_predictions} prediction_masks => mask shape:{mask.shape}, {mask.dtype}.")

    _save_debug_array(mask, prefix="get_raster_inference__mask_", error_label="e_save1")
    _save_debug_array(inference_out, prefix="get_raster_inference__inference_out_", error_label="e_save1")
    return mask, n_predictions