File size: 5,428 Bytes
fa76f5f
 
43d87b3
fa76f5f
 
 
 
43d87b3
fa76f5f
 
 
 
43d87b3
 
 
fa76f5f
43d87b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa76f5f
 
2f5b9e0
 
 
 
 
 
 
 
 
 
 
 
 
 
43d87b3
fa76f5f
2f5b9e0
 
 
43d87b3
 
 
2f5b9e0
 
 
43d87b3
 
 
 
 
9271aef
43d87b3
 
 
 
 
 
 
9271aef
43d87b3
9271aef
 
43d87b3
9271aef
43d87b3
 
 
9271aef
 
43d87b3
 
 
 
 
 
9271aef
 
43d87b3
 
9271aef
43d87b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f5b9e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Inference helpers: segment a downloaded map-tile extent with a SAM ONNX model and return the predicted masks as GeoJSON polygons.
import numpy as np
from typing import List

from src import app_logger, MODEL_FOLDER
from src.io.tms2geotiff import download_extent
from src.prediction_api.sam_onnx import SegmentAnythingONNX
from src.utilities.constants import MODEL_ENCODER_NAME, ZOOM, DEFAULT_TMS, MODEL_DECODER_NAME
from src.utilities.serialize import serialize
from src.utilities.type_hints import input_float_tuples


# Lazily-populated model cache keyed by model name; samexporter_predict stores the
# instantiated model under "instance" on first use and reuses it afterwards.
models_dict = dict(fastsam=dict(instance=None))


def zip_arrays(arr1, arr2):
    """Pair two array-likes element-wise into a ``{str: str}`` dict.

    Both arguments must expose ``.tolist()`` (e.g. numpy arrays). Every element is
    converted with ``str()``. Pairing stops at the shorter input (``zip`` semantics),
    and a repeated key keeps the last value seen.

    This is a best-effort helper: any failure (e.g. an argument without ``.tolist()``)
    is logged as an error and an empty dict is returned instead of raising.

    Returns:
        dict[str, str]: the zipped mapping, or ``{}`` on failure.
    """
    try:
        zipped = {str(n1): str(n2) for n1, n2 in zip(arr1.tolist(), arr2.tolist())}
    except Exception as e_zip_arrays:
        # best-effort by design, but log at error level so failures are not hidden among info logs
        app_logger.error(f"exception zip_arrays:{e_zip_arrays}.")
        return {}
    app_logger.debug(f"zipped dict:{zipped}.")
    return zipped


def load_affine_transformation_from_matrix(matrix_source_coeffs: List):
    """Build an ``affine.Affine`` transform from six world-file-ordered coefficients.

    The coefficients are read in world-file order ``(a, d, b, e, c, f)``, where
    ``(c, f)`` refer to the centre of the upper-left pixel. The result is shifted by
    half a pixel so that it maps pixel corners instead. This mirrors ``affine.loadsw``,
    but takes a sequence instead of a whitespace-separated string.

    Args:
        matrix_source_coeffs: sized sequence of exactly six values accepted by ``float()``.

    Returns:
        affine.Affine: the corner-based transform.

    Raises:
        ValueError: if there are not exactly six coefficients, or one of them cannot
            be converted to float.
    """
    if len(matrix_source_coeffs) != 6:
        raise ValueError(f"Expected 6 coefficients, found {len(matrix_source_coeffs)}; argument type: {type(matrix_source_coeffs)}.")

    try:
        a, d, b, e, c, f = (float(x) for x in matrix_source_coeffs)
    except (TypeError, ValueError) as err:
        # fail loudly: returning None here would only surface later as a confusing error in polygonization
        raise ValueError(f"Invalid affine coefficients {matrix_source_coeffs}: {err}") from err

    # imported after validation so malformed input is rejected without touching the optional dependency
    from affine import Affine

    # public constructor (g, h, i default to 0, 0, 1) instead of tuple.__new__(Affine, ...),
    # which depends on Affine being implemented as a namedtuple
    center = Affine(a, b, c, d, e, f)
    return center * Affine.translation(-0.5, -0.5)


def _get_model_instance(model_name: str):
    """Return the cached model for ``model_name``, instantiating and caching it on first use.

    Raises:
        KeyError: if ``model_name`` is not a key of ``models_dict``.
    """
    if models_dict[model_name]["instance"] is None:
        app_logger.info(f"missing instance model {model_name}, instantiating it now")
        # NOTE(review): the same encoder/decoder files are used whatever model_name is
        models_dict[model_name]["instance"] = SegmentAnythingONNX(
            encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
            decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
        )
    app_logger.info(f"using a {model_name} instance model...")
    return models_dict[model_name]["instance"]


def _merge_prediction_masks(prediction_masks: np.ndarray) -> np.ndarray:
    """Union every mask of ``prediction_masks[0]`` into one uint8 raster (255 where any mask > 0, else 0).

    Expects a rank-4 array whose last two axes are the mask height and width;
    presumably laid out as (batch, mask, height, width) — TODO confirm against SegmentAnythingONNX.
    """
    foreground = (prediction_masks[0] > 0.0).any(axis=0)
    return np.where(foreground, 255, 0).astype(np.uint8)


def samexporter_predict(bbox: input_float_tuples, prompt: list[dict], zoom: float = ZOOM, model_name: str = "fastsam") -> dict:
    """Segment the map extent ``bbox`` with a SAM ONNX model and return the masks as GeoJSON polygons.

    Steps:
        1. Download the tiles covering ``bbox`` from ``DEFAULT_TMS`` at ``zoom``.
        2. Encode the image and decode masks for ``prompt``.
        3. Merge all masks into one binary raster.
        4. Polygonize that raster with the image's affine transform.

    Args:
        bbox: two corner points. ``bbox[0]`` and ``bbox[1]`` are each indexed as
            ``(pt[0], pt[1])`` and forwarded to ``download_extent`` in that order.
        prompt: forwarded unchanged to ``predict_masks`` of the model instance.
        zoom: tile zoom level used for the download.
        model_name: key into ``models_dict`` selecting the cached model instance.

    Returns:
        dict with:
            ``geojson``: the polygons as a JSON string, converted to WGS84 by geopandas.
            ``n_shapes_geojson``: the number of polygons.
            ``n_predictions``: ``len(prediction_masks)``, i.e. the size of its first axis.

    Raises:
        ImportError: if a lazily imported dependency (rasterio, geopandas, ...) is
            missing; it is logged, then re-raised.
        KeyError: if ``model_name`` is not a key of ``models_dict``.
    """
    try:
        from rasterio.features import shapes
        from geopandas import GeoDataFrame

        models_instance = _get_model_instance(model_name)

        for coord in bbox:
            app_logger.debug(f"bbox coord:{coord}, type:{type(coord)}.")
        app_logger.info(f"start download_extent using bbox:{bbox}, type:{type(bbox)}, download image...")

        pt0, pt1 = bbox[0], bbox[1]
        img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)

        app_logger.info(f"img type {type(img)}, matrix type {type(matrix)}.")
        app_logger.debug(f"matrix values: {serialize(matrix)}.")
        np_img = np.array(img)
        app_logger.debug(f"np_img type {type(np_img)}.")
        app_logger.debug(f"np_img dtype {np_img.dtype}, shape {np_img.shape}.")
        app_logger.info(f"geotiff created with size/shape {img.size} and transform matrix {str(matrix)}, start to initialize SamGeo instance:")
        app_logger.info(f"use {model_name} model, ENCODER model {MODEL_ENCODER_NAME} and {MODEL_DECODER_NAME} from {MODEL_FOLDER}): model instantiated, creating embedding...")
        embedding = models_instance.encode(np_img)
        app_logger.info("embedding created, running predict_masks...")
        prediction_masks = models_instance.predict_masks(embedding, prompt)
        app_logger.info(f"predict_masks terminated, prediction masks shape:{prediction_masks.shape}, {prediction_masks.dtype}.")

        mask = _merge_prediction_masks(prediction_masks)

        mask_unique_values, mask_unique_values_count = serialize(np.unique(mask, return_counts=True))
        app_logger.debug(f"mask_unique_values:{mask_unique_values}.")
        app_logger.debug(f"mask_unique_values_count:{mask_unique_values_count}.")

        transform = load_affine_transformation_from_matrix(matrix)
        app_logger.info(f"image/geojson origin matrix:{matrix}, transform:{transform}: create shapes_generator...")
        # passing the mask as `mask=` excludes zero-valued pixels, so only foreground regions become polygons
        shapes_list = [
            {'properties': {'raster_val': v}, 'geometry': s}
            for s, v in shapes(mask, mask=mask, transform=transform)
        ]
        app_logger.info(f"created {len(shapes_list)} polygons.")
        gpd_polygonized_raster = GeoDataFrame.from_features(shapes_list, crs="EPSG:3857")
        app_logger.info("created a GeoDataFrame...")
        geojson = gpd_polygonized_raster.to_json(to_wgs84=True)
        app_logger.info("created geojson...")

        return {
            "geojson": geojson,
            "n_shapes_geojson": len(shapes_list),
            # NOTE(review): masks are read from prediction_masks[0], so this is the size of
            # the first axis (likely 1); confirm whether prediction_masks.shape[1] was intended
            "n_predictions": len(prediction_masks),
        }
    except ImportError as e:
        app_logger.error(f"Error trying import module:{e}.")
        # re-raise: falling through would return None and break the `-> dict` contract
        raise