# samgis / tests/prediction_api/test_predictors.py
# Author: aletrn
# Commit 84a883f: "[test] add test cases for get_vectorized_raster_as_geojson and SegmentAnythingONNX"
# (raw | history | blame) - 3.06 kB
import json
from unittest.mock import patch
import numpy as np
import rasterio
from src.prediction_api import predictors
from src.prediction_api.predictors import get_raster_inference, samexporter_predict
from tests import TEST_EVENTS_FOLDER
@patch.object(predictors, "SegmentAnythingONNX")
def test_get_raster_inference(segment_anything_onnx_mocked):
    """Check get_raster_inference against the recorded samexporter_predict fixtures.

    Every case in ``samexporter_predict.json`` has a matching sub-folder holding
    ``img.npy``, ``inference_out.npy`` and ``mask.npy``. The model's ``embed``
    and ``predict_masks`` methods are mocked to return the recorded arrays. The
    mask and prediction count returned by get_raster_inference are then
    compared with the recorded expected output.
    """
    name_fn = "samexporter_predict"
    events_folder = TEST_EVENTS_FOLDER / name_fn
    with open(TEST_EVENTS_FOLDER / f"{name_fn}.json", encoding="utf-8") as tst_json:
        inputs_outputs = json.load(tst_json)
    # An empty fixture file would otherwise let this test pass without checking anything.
    assert inputs_outputs, f"no test cases found in {name_fn}.json"

    # Calling the patched class returns its (single, shared) return_value mock.
    model_mocked = segment_anything_onnx_mocked()
    for k, input_output in inputs_outputs.items():
        case_folder = events_folder / k
        img = np.load(case_folder / "img.npy")
        inference_out = np.load(case_folder / "inference_out.npy")
        mask = np.load(case_folder / "mask.npy")
        prompt = input_output["input"]["prompt"]
        model_name = input_output["input"]["model_name"]

        model_mocked.embed.return_value = np.array(img)
        model_mocked.predict_masks.return_value = inference_out

        output_mask, len_inference_out = get_raster_inference(
            img=img,
            prompt=prompt,
            models_instance=model_mocked,
            model_name=model_name
        )
        assert np.array_equal(output_mask, mask), f"mask mismatch for case {k}"
        assert len_inference_out == input_output["output"]["n_predictions"], (
            f"n_predictions mismatch for case {k}"
        )
@patch.object(predictors, "get_raster_inference")
@patch.object(predictors, "SegmentAnythingONNX")
@patch.object(predictors, "download_extent")
@patch.object(predictors, "get_vectorized_raster_as_geojson")
@patch.object(predictors, "get_affine_transform_from_gdal")
def test_samexporter_predict(
    get_affine_transform_from_gdal_mocked,
    get_vectorized_raster_as_geojson_mocked,
    download_extent_mocked,
    segment_anything_onnx_mocked,
    get_raster_inference_mocked
):
    """Check that samexporter_predict assembles its result from its helpers.

    The model class, tile download, affine-transform, raster-inference and
    vectorization helpers are patched with mocks returning canned values. The
    test then asserts on the result dict: it must carry the prediction count
    from the mocked get_raster_inference (1) and the keys from the mocked
    get_vectorized_raster_as_geojson ("geojson", "n_shapes_geojson").
    """
    gdal_matrix = (1, 2, 3, 4, 5, 6)
    blank_image = np.zeros((10, 10))
    full_mask = np.ones((10, 10))

    # Canned return values, one per patched collaborator.
    get_raster_inference_mocked.return_value = (full_mask, 1)
    segment_anything_onnx_mocked.return_value = "SegmentAnythingONNX_instance"
    download_extent_mocked.return_value = (blank_image, gdal_matrix)
    get_vectorized_raster_as_geojson_mocked.return_value = {
        "geojson": "{}",
        "n_shapes_geojson": 2,
    }
    get_affine_transform_from_gdal_mocked.return_value = rasterio.Affine.from_gdal(*gdal_matrix)

    result = samexporter_predict(
        bbox=[[1, 2], [3, 4]],
        prompt=[{}],
        zoom=10,
        model_name="fastsam",
    )

    expected = {
        "n_predictions": 1,
        "geojson": "{}",
        "n_shapes_geojson": 2,
    }
    assert result == expected