[test] add test cases for get_vectorized_raster_as_geojson and SegmentAnythingONNX
Browse files- src/io/geo_helpers.py +1 -1
- src/prediction_api/predictors.py +24 -26
- tests/io/test_geo_helpers.py +24 -1
- tests/prediction_api/test_predictors.py +34 -6
src/io/geo_helpers.py
CHANGED
@@ -31,7 +31,7 @@ def load_affine_transformation_from_matrix(matrix_source_coeffs: List[float]) ->
|
|
31 |
raise e
|
32 |
|
33 |
|
34 |
-
def get_affine_transform_from_gdal(matrix_source_coeffs: List[float]) -> Affine:
|
35 |
"""wrapper for rasterio Affine from_gdal method
|
36 |
|
37 |
Args:
|
|
|
31 |
raise e
|
32 |
|
33 |
|
34 |
+
def get_affine_transform_from_gdal(matrix_source_coeffs: List[float] or Tuple[float]) -> Affine:
|
35 |
"""wrapper for rasterio Affine from_gdal method
|
36 |
|
37 |
Args:
|
src/prediction_api/predictors.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
"""functions using
|
2 |
from typing import Dict, Tuple
|
3 |
from PIL.Image import Image
|
4 |
import numpy as np
|
@@ -10,6 +10,7 @@ from src.prediction_api.sam_onnx import SegmentAnythingONNX
|
|
10 |
from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_TMS
|
11 |
from src.utilities.type_hints import llist_float
|
12 |
|
|
|
13 |
models_dict = {"fastsam": {"instance": None}}
|
14 |
|
15 |
|
@@ -36,34 +37,31 @@ def samexporter_predict(
|
|
36 |
Returns:
|
37 |
dict: Affine transform
|
38 |
"""
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
models_instance = models_dict[model_name]["instance"]
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
|
56 |
-
|
57 |
-
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
except ImportError as e_import_module:
|
66 |
-
app_logger.error(f"Error trying import module:{e_import_module}.")
|
67 |
|
68 |
|
69 |
def get_raster_inference(
|
|
|
1 |
+
"""functions using machine learning instance model(s)"""
|
2 |
from typing import Dict, Tuple
|
3 |
from PIL.Image import Image
|
4 |
import numpy as np
|
|
|
10 |
from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_TMS
|
11 |
from src.utilities.type_hints import llist_float
|
12 |
|
13 |
+
|
14 |
models_dict = {"fastsam": {"instance": None}}
|
15 |
|
16 |
|
|
|
37 |
Returns:
|
38 |
dict: Affine transform
|
39 |
"""
|
40 |
+
if models_dict[model_name]["instance"] is None:
|
41 |
+
app_logger.info(f"missing instance model {model_name}, instantiating it now!")
|
42 |
+
model_instance = SegmentAnythingONNX(
|
43 |
+
encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
|
44 |
+
decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
|
45 |
+
)
|
46 |
+
models_dict[model_name]["instance"] = model_instance
|
47 |
+
app_logger.debug(f"using a {model_name} instance model...")
|
48 |
+
models_instance = models_dict[model_name]["instance"]
|
|
|
49 |
|
50 |
+
app_logger.info(f'tile_source: {DEFAULT_TMS}!')
|
51 |
+
pt0, pt1 = bbox
|
52 |
+
app_logger.info(f"downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
|
53 |
+
img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
|
54 |
+
app_logger.info(f"img type {type(img)} with shape/size:{img.size}, matrix:{type(matrix)}, matrix:{matrix}.")
|
55 |
|
56 |
+
transform = get_affine_transform_from_gdal(matrix)
|
57 |
+
app_logger.debug(f"transform to consume with rasterio.shapes: {type(transform)}, {transform}.")
|
58 |
|
59 |
+
mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
|
60 |
+
app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
|
61 |
+
return {
|
62 |
+
"n_predictions": n_predictions,
|
63 |
+
**get_vectorized_raster_as_geojson(mask, matrix)
|
64 |
+
}
|
|
|
|
|
65 |
|
66 |
|
67 |
def get_raster_inference(
|
tests/io/test_geo_helpers.py
CHANGED
@@ -62,7 +62,7 @@ class TestGeoHelpers(unittest.TestCase):
|
|
62 |
"check https://github.com/rasterio/affine project for updates")
|
63 |
raise e
|
64 |
|
65 |
-
def
|
66 |
from src.io.geo_helpers import get_vectorized_raster_as_geojson
|
67 |
|
68 |
name_fn = "samexporter_predict"
|
@@ -78,3 +78,26 @@ class TestGeoHelpers(unittest.TestCase):
|
|
78 |
output_geojson = shapely.from_geojson(output["geojson"])
|
79 |
expected_output_geojson = shapely.from_geojson(input_output["output"]["geojson"])
|
80 |
assert shapely.equals_exact(output_geojson, expected_output_geojson, tolerance=0.000006)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
"check https://github.com/rasterio/affine project for updates")
|
63 |
raise e
|
64 |
|
65 |
+
def test_get_vectorized_raster_as_geojson_ok(self):
|
66 |
from src.io.geo_helpers import get_vectorized_raster_as_geojson
|
67 |
|
68 |
name_fn = "samexporter_predict"
|
|
|
78 |
output_geojson = shapely.from_geojson(output["geojson"])
|
79 |
expected_output_geojson = shapely.from_geojson(input_output["output"]["geojson"])
|
80 |
assert shapely.equals_exact(output_geojson, expected_output_geojson, tolerance=0.000006)
|
81 |
+
|
82 |
+
def test_get_vectorized_raster_as_geojson_fail(self):
|
83 |
+
from src.io.geo_helpers import get_vectorized_raster_as_geojson
|
84 |
+
|
85 |
+
name_fn = "samexporter_predict"
|
86 |
+
|
87 |
+
with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
|
88 |
+
inputs_outputs = json.load(tst_json)
|
89 |
+
for k, input_output in inputs_outputs.items():
|
90 |
+
print(f"k:{k}.")
|
91 |
+
mask = np.load(TEST_EVENTS_FOLDER / name_fn / k / "mask.npy")
|
92 |
+
|
93 |
+
# Could be also another generic Exception, here we intercept TypeError caused by wrong matrix input on
|
94 |
+
# rasterio.Affine.from_gdal() wrapped by get_affine_transform_from_gdal()
|
95 |
+
with self.assertRaises(TypeError):
|
96 |
+
try:
|
97 |
+
wrong_matrix = 1.0,
|
98 |
+
get_vectorized_raster_as_geojson(mask=mask, matrix=wrong_matrix)
|
99 |
+
except TypeError as te:
|
100 |
+
print(f"te:{te}.")
|
101 |
+
msg = "Affine.from_gdal() missing 5 required positional arguments: 'a', 'b', 'f', 'd', and 'e'"
|
102 |
+
self.assertEqual(str(te), msg)
|
103 |
+
raise te
|
tests/prediction_api/test_predictors.py
CHANGED
@@ -2,16 +2,15 @@ import json
|
|
2 |
from unittest.mock import patch
|
3 |
|
4 |
import numpy as np
|
|
|
5 |
|
6 |
-
from src.prediction_api import
|
7 |
-
from src.prediction_api.predictors import get_raster_inference
|
8 |
from tests import TEST_EVENTS_FOLDER
|
9 |
|
10 |
|
11 |
-
@patch.object(
|
12 |
-
def test_get_raster_inference(
|
13 |
-
segment_anything_onnx_mocked
|
14 |
-
):
|
15 |
name_fn = "samexporter_predict"
|
16 |
|
17 |
with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
|
@@ -38,3 +37,32 @@ def test_get_raster_inference(
|
|
38 |
)
|
39 |
assert np.array_equal(output_mask, mask)
|
40 |
assert len_inference_out == input_output["output"]["n_predictions"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from unittest.mock import patch
|
3 |
|
4 |
import numpy as np
|
5 |
+
import rasterio
|
6 |
|
7 |
+
from src.prediction_api import predictors
|
8 |
+
from src.prediction_api.predictors import get_raster_inference, samexporter_predict
|
9 |
from tests import TEST_EVENTS_FOLDER
|
10 |
|
11 |
|
12 |
+
@patch.object(predictors, "SegmentAnythingONNX")
|
13 |
+
def test_get_raster_inference(segment_anything_onnx_mocked):
|
|
|
|
|
14 |
name_fn = "samexporter_predict"
|
15 |
|
16 |
with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
|
|
|
37 |
)
|
38 |
assert np.array_equal(output_mask, mask)
|
39 |
assert len_inference_out == input_output["output"]["n_predictions"]
|
40 |
+
|
41 |
+
|
42 |
+
@patch.object(predictors, "get_raster_inference")
|
43 |
+
@patch.object(predictors, "SegmentAnythingONNX")
|
44 |
+
@patch.object(predictors, "download_extent")
|
45 |
+
@patch.object(predictors, "get_vectorized_raster_as_geojson")
|
46 |
+
@patch.object(predictors, "get_affine_transform_from_gdal")
|
47 |
+
def test_samexporter_predict(
|
48 |
+
get_affine_transform_from_gdal_mocked,
|
49 |
+
get_vectorized_raster_as_geojson_mocked,
|
50 |
+
download_extent_mocked,
|
51 |
+
segment_anything_onnx_mocked,
|
52 |
+
get_raster_inference_mocked
|
53 |
+
):
|
54 |
+
"""
|
55 |
+
model_instance = SegmentAnythingONNX()
|
56 |
+
img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
|
57 |
+
transform = get_affine_transform_from_gdal(matrix)
|
58 |
+
mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
|
59 |
+
get_vectorized_raster_as_geojson(mask, matrix)
|
60 |
+
"""
|
61 |
+
aff = 1, 2, 3, 4, 5, 6
|
62 |
+
segment_anything_onnx_mocked.return_value = "SegmentAnythingONNX_instance"
|
63 |
+
download_extent_mocked.return_value = np.zeros((10, 10)), aff
|
64 |
+
get_affine_transform_from_gdal_mocked.return_value = rasterio.Affine.from_gdal(*aff)
|
65 |
+
get_raster_inference_mocked.return_value = np.ones((10, 10)), 1
|
66 |
+
get_vectorized_raster_as_geojson_mocked.return_value = {"geojson": "{}", "n_shapes_geojson": 2}
|
67 |
+
output = samexporter_predict(bbox=[[1, 2], [3, 4]], prompt=[{}], zoom=10, model_name="fastsam")
|
68 |
+
assert output == {"n_predictions": 1, "geojson": "{}", "n_shapes_geojson": 2}
|