[debug] first test on lambda for fastsam prediction
Browse files- README.md +11 -0
- dockerfiles/dockerfile-lambda-gdal-runner +6 -3
- dockerfiles/dockerfile-lambda-samgeo-api +4 -0
- models/.gitignore +0 -0
- requirements.txt +5 -4
- requirements_dev.txt +7 -5
- src/__init__.py +4 -0
- src/app.py +16 -6
- src/prediction_api/predictors.py +82 -0
- src/prediction_api/sam_onnx.py +203 -0
- src/prediction_api/samgeo_predictors.py +0 -56
- src/utilities/constants.py +2 -1
- src/utilities/serialize.py +85 -0
README.md
CHANGED
@@ -1,5 +1,16 @@
|
|
1 |
# Segment Geospatial
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
Build the docker image:
|
4 |
|
5 |
```bash
|
|
|
1 |
# Segment Geospatial
|
2 |
|
3 |
+
## todo
|
4 |
+
|
5 |
+
1. export output to mask: OK local, OK aws lambda
|
6 |
+
2. resolve model paths: OK local
|
7 |
+
3. inference:
|
8 |
+
4. from mask to json (rasterio + geopandas, check for re-projection to EPSG_4326)
|
9 |
+
5. check mandatory dependencies
|
10 |
+
6. check for alternative python interpreters
|
11 |
+
|
12 |
+
## Build instructions
|
13 |
+
|
14 |
Build the docker image:
|
15 |
|
16 |
```bash
|
dockerfiles/dockerfile-lambda-gdal-runner
CHANGED
@@ -9,12 +9,15 @@ RUN echo "ENV RIE: $RIE ..."
|
|
9 |
|
10 |
# Set working directory to function root directory
|
11 |
WORKDIR ${LAMBDA_TASK_ROOT}
|
12 |
-
COPY
|
13 |
|
14 |
-
|
|
|
|
|
15 |
RUN which python
|
16 |
RUN python --version
|
17 |
-
RUN python -m pip install
|
|
|
18 |
|
19 |
RUN curl -Lo /usr/local/bin/aws-lambda-rie ${RIE}
|
20 |
RUN chmod +x /usr/local/bin/aws-lambda-rie
|
|
|
9 |
|
10 |
# Set working directory to function root directory
|
11 |
WORKDIR ${LAMBDA_TASK_ROOT}
|
12 |
+
COPY requirements_dev.txt ${LAMBDA_TASK_ROOT}/requirements_dev.txt
|
13 |
|
14 |
+
# avoid segment-geospatial exception caused by missing libGL.so.1 library
|
15 |
+
RUN apt update && apt install -y libgl1 curl python3-pip
|
16 |
+
RUN ls -ld /usr/lib/*linux-gnu/libGL.so* || echo "libGL.so* not found..."
|
17 |
RUN which python
|
18 |
RUN python --version
|
19 |
+
RUN python -m pip install -r ${LAMBDA_TASK_ROOT}/requirements_dev.txt --target ${LAMBDA_TASK_ROOT}
|
20 |
+
# RUN python -m pip install pillow awslambdaric aws-lambda-powertools httpx jmespath --target ${LAMBDA_TASK_ROOT}
|
21 |
|
22 |
RUN curl -Lo /usr/local/bin/aws-lambda-rie ${RIE}
|
23 |
RUN chmod +x /usr/local/bin/aws-lambda-rie
|
dockerfiles/dockerfile-lambda-samgeo-api
CHANGED
@@ -7,6 +7,7 @@ ARG PYTHONPATH="${LAMBDA_TASK_ROOT}:${PYTHONPATH}:/usr/local/lib/python3/dist-pa
|
|
7 |
# Set working directory to function root directory
|
8 |
WORKDIR ${LAMBDA_TASK_ROOT}
|
9 |
COPY ./src ${LAMBDA_TASK_ROOT}/src
|
|
|
10 |
|
11 |
RUN ls -l /usr/bin/which
|
12 |
RUN /usr/bin/which python
|
@@ -16,8 +17,11 @@ RUN echo "PATH: ${PATH}."
|
|
16 |
RUN echo "LAMBDA_TASK_ROOT: ${LAMBDA_TASK_ROOT}."
|
17 |
RUN ls -l ${LAMBDA_TASK_ROOT}
|
18 |
RUN ls -ld ${LAMBDA_TASK_ROOT}
|
|
|
19 |
RUN python -c "import sys; print(sys.path)"
|
20 |
RUN python -c "import osgeo"
|
|
|
|
|
21 |
# RUN python -c "import rasterio"
|
22 |
RUN python -c "import awslambdaric"
|
23 |
RUN python -m pip list
|
|
|
7 |
# Set working directory to function root directory
|
8 |
WORKDIR ${LAMBDA_TASK_ROOT}
|
9 |
COPY ./src ${LAMBDA_TASK_ROOT}/src
|
10 |
+
COPY ./models ${LAMBDA_TASK_ROOT}/models
|
11 |
|
12 |
RUN ls -l /usr/bin/which
|
13 |
RUN /usr/bin/which python
|
|
|
17 |
RUN echo "LAMBDA_TASK_ROOT: ${LAMBDA_TASK_ROOT}."
|
18 |
RUN ls -l ${LAMBDA_TASK_ROOT}
|
19 |
RUN ls -ld ${LAMBDA_TASK_ROOT}
|
20 |
+
RUN ls -l ${LAMBDA_TASK_ROOT}/models
|
21 |
RUN python -c "import sys; print(sys.path)"
|
22 |
RUN python -c "import osgeo"
|
23 |
+
RUN python -c "import cv2"
|
24 |
+
RUN python -c "import onnxruntime"
|
25 |
# RUN python -c "import rasterio"
|
26 |
RUN python -c "import awslambdaric"
|
27 |
RUN python -m pip list
|
models/.gitignore
ADDED
File without changes
|
requirements.txt
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
aws-lambda-powertools
|
2 |
awslambdaric
|
3 |
bson
|
4 |
-
|
5 |
jmespath
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
1 |
aws-lambda-powertools
|
2 |
awslambdaric
|
3 |
bson
|
4 |
+
httpx
|
5 |
jmespath
|
6 |
+
numpy
|
7 |
+
onnxruntime
|
8 |
+
opencv-python
|
9 |
+
pillow
|
requirements_dev.txt
CHANGED
@@ -1,7 +1,9 @@
|
|
|
|
1 |
awslambdaric
|
2 |
-
|
3 |
-
|
4 |
-
geojson-pydantic
|
5 |
jmespath
|
6 |
-
|
7 |
-
|
|
|
|
|
|
1 |
+
aws-lambda-powertools
|
2 |
awslambdaric
|
3 |
+
bson
|
4 |
+
httpx
|
|
|
5 |
jmespath
|
6 |
+
numpy
|
7 |
+
onnxruntime
|
8 |
+
opencv-python
|
9 |
+
pillow
|
src/__init__.py
CHANGED
@@ -1,4 +1,8 @@
|
|
1 |
from aws_lambda_powertools import Logger
|
|
|
|
|
2 |
|
3 |
|
|
|
|
|
4 |
app_logger = Logger()
|
|
|
1 |
from aws_lambda_powertools import Logger
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
|
5 |
|
6 |
+
PROJECT_ROOT_FOLDER = Path(globals().get("__file__", "./_")).absolute().parent.parent
|
7 |
+
MODEL_FOLDER = Path(os.path.join(PROJECT_ROOT_FOLDER, "models"))
|
8 |
app_logger = Logger()
|
src/app.py
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
import json
|
2 |
import time
|
3 |
from http import HTTPStatus
|
4 |
-
|
5 |
from aws_lambda_powertools.event_handler import content_types
|
6 |
from aws_lambda_powertools.utilities.typing import LambdaContext
|
7 |
|
8 |
from src import app_logger
|
9 |
-
from src.
|
10 |
-
from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES
|
11 |
|
12 |
|
13 |
-
def get_response(status: int, start_time: float, request_id: str, response_body = None) -> str:
|
14 |
"""
|
15 |
Return a response for frontend clients.
|
16 |
|
@@ -51,10 +51,20 @@ def lambda_handler(event: dict, context: LambdaContext):
|
|
51 |
app_logger.info(f"context:{context}...")
|
52 |
|
53 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
pt0 = 45.699, 127.1
|
55 |
pt1 = 30.1, 148.492
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
58 |
response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
|
59 |
except Exception as ve:
|
60 |
app_logger.error(f"validation error:{ve}.")
|
|
|
1 |
import json
|
2 |
import time
|
3 |
from http import HTTPStatus
|
4 |
+
from typing import Dict
|
5 |
from aws_lambda_powertools.event_handler import content_types
|
6 |
from aws_lambda_powertools.utilities.typing import LambdaContext
|
7 |
|
8 |
from src import app_logger
|
9 |
+
from src.prediction_api.predictors import samexporter_predict
|
10 |
+
from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES
|
11 |
|
12 |
|
13 |
+
def get_response(status: int, start_time: float, request_id: str, response_body: Dict = None) -> str:
|
14 |
"""
|
15 |
Return a response for frontend clients.
|
16 |
|
|
|
51 |
app_logger.info(f"context:{context}...")
|
52 |
|
53 |
try:
|
54 |
+
"""
|
55 |
+
img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], 6)
|
56 |
+
model_path = Path(MODEL_FOLDER) / "mobile_sam.encoder.onnx"
|
57 |
+
model_path_isfile = model_path.is_file()
|
58 |
+
model_path_stats = model_path.stat()
|
59 |
+
app_logger.info(f"model_path:{model_path_isfile}, {model_path_stats}.")
|
60 |
+
"""
|
61 |
pt0 = 45.699, 127.1
|
62 |
pt1 = 30.1, 148.492
|
63 |
+
bbox = [pt0, pt1]
|
64 |
+
zoom = 6
|
65 |
+
prompt = [{"type": "rectangle", "data": [400, 460, 524, 628]}]
|
66 |
+
body_response = {"geojson": samexporter_predict(bbox, prompt, zoom)}
|
67 |
+
app_logger.info(f"body_response::output:{body_response}.")
|
68 |
response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
|
69 |
except Exception as ve:
|
70 |
app_logger.error(f"validation error:{ve}.")
|
src/prediction_api/predictors.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Press the green button in the gutter to run the script.
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from src import app_logger, MODEL_FOLDER
|
6 |
+
from src.io.tms2geotiff import download_extent
|
7 |
+
from src.prediction_api.sam_onnx import SegmentAnythingONNX
|
8 |
+
from src.utilities.constants import ROOT, MODEL_ENCODER_NAME, ZOOM, SOURCE_TYPE, DEFAULT_TMS, MODEL_DECODER_NAME
|
9 |
+
from src.utilities.serialize import serialize
|
10 |
+
from src.utilities.type_hints import input_float_tuples
|
11 |
+
from src.utilities.utilities import get_system_info
|
12 |
+
|
13 |
+
|
14 |
+
def zip_arrays(arr1, arr2):
|
15 |
+
arr1_list = arr1.tolist()
|
16 |
+
arr2_list = arr2.tolist()
|
17 |
+
# return {serialize(k): serialize(v) for k, v in zip(arr1_list, arr2_list)}
|
18 |
+
d = {}
|
19 |
+
for n1, n2 in enumerate(zip(arr1_list, arr2_list)):
|
20 |
+
app_logger.info(f"n1:{n1}, type {type(n1)}, n2:{n2}, type {type(n2)}.")
|
21 |
+
n1f = str(n1)
|
22 |
+
n2f = str(n2)
|
23 |
+
app_logger.info(f"n1:{n1}=>{n1f}, n2:{n2}=>{n2f}.")
|
24 |
+
d[n1f] = n2f
|
25 |
+
app_logger.info(f"zipped dict:{d}.")
|
26 |
+
return d
|
27 |
+
|
28 |
+
|
29 |
+
def samexporter_predict(bbox: input_float_tuples, prompt: list[dict], zoom: float = ZOOM) -> dict:
|
30 |
+
import tempfile
|
31 |
+
|
32 |
+
try:
|
33 |
+
os.environ['MPLCONFIGDIR'] = ROOT
|
34 |
+
get_system_info()
|
35 |
+
except Exception as e:
|
36 |
+
app_logger.error(f"Error while setting 'MPLCONFIGDIR':{e}.")
|
37 |
+
|
38 |
+
with tempfile.NamedTemporaryFile(prefix=f"{SOURCE_TYPE}_", suffix=".tif", dir=ROOT) as image_input_tmp:
|
39 |
+
for coord in bbox:
|
40 |
+
app_logger.info(f"bbox coord:{coord}, type:{type(coord)}.")
|
41 |
+
app_logger.info(f"start download_extent using bbox:{bbox}, type:{type(bbox)}, download image...")
|
42 |
+
|
43 |
+
pt0 = bbox[0]
|
44 |
+
pt1 = bbox[1]
|
45 |
+
img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
|
46 |
+
|
47 |
+
app_logger.info(f"img type {type(img)}, matrix type {type(matrix)}.")
|
48 |
+
np_img = np.array(img)
|
49 |
+
app_logger.info(f"np_img type {type(np_img)}.")
|
50 |
+
app_logger.info(f"np_img dtype {np_img.dtype}, shape {np_img.shape}.")
|
51 |
+
app_logger.info(f"geotiff created with size/shape {img.size} and transform matrix {str(matrix)}, start to initialize SamGeo instance:")
|
52 |
+
app_logger.info(f"use ENCODER model {MODEL_ENCODER_NAME} from {MODEL_FOLDER})...")
|
53 |
+
app_logger.info(f"use DECODER model {MODEL_DECODER_NAME} from {MODEL_FOLDER})...")
|
54 |
+
|
55 |
+
model = SegmentAnythingONNX(
|
56 |
+
encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
|
57 |
+
decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
|
58 |
+
)
|
59 |
+
app_logger.info(f"model instantiated, creating embedding...")
|
60 |
+
embedding = model.encode(np_img)
|
61 |
+
app_logger.info(f"embedding created, running predict_masks...")
|
62 |
+
prediction_masks = model.predict_masks(embedding, prompt)
|
63 |
+
app_logger.info(f"predict_masks terminated")
|
64 |
+
app_logger.info(f"prediction masks shape:{prediction_masks.shape}, {prediction_masks.dtype}.")
|
65 |
+
|
66 |
+
mask = np.zeros((prediction_masks.shape[2], prediction_masks.shape[3]), dtype=np.uint8)
|
67 |
+
for m in prediction_masks[0, :, :, :]:
|
68 |
+
mask[m > 0.0] = 255
|
69 |
+
|
70 |
+
mask_unique_values, mask_unique_values_count = serialize(np.unique(mask, return_counts=True))
|
71 |
+
app_logger.info(f"mask_unique_values:{mask_unique_values}.")
|
72 |
+
app_logger.info(f"mask_unique_values_count:{mask_unique_values_count}.")
|
73 |
+
|
74 |
+
output = {
|
75 |
+
"img_size": serialize(img.size),
|
76 |
+
"mask_unique_values_count": zip_arrays(mask_unique_values, mask_unique_values_count),
|
77 |
+
"masks_dtype": serialize(prediction_masks.dtype),
|
78 |
+
"masks_shape": serialize(prediction_masks.shape),
|
79 |
+
"matrix": serialize(matrix)
|
80 |
+
}
|
81 |
+
app_logger.info(f"output:{output}.")
|
82 |
+
return output
|
src/prediction_api/sam_onnx.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime
|
6 |
+
|
7 |
+
from src import app_logger
|
8 |
+
|
9 |
+
|
10 |
+
class SegmentAnythingONNX:
|
11 |
+
"""Segmentation model using SegmentAnything"""
|
12 |
+
|
13 |
+
def __init__(self, encoder_model_path, decoder_model_path) -> None:
|
14 |
+
self.target_size = 1024
|
15 |
+
self.input_size = (684, 1024)
|
16 |
+
|
17 |
+
# Load models
|
18 |
+
providers = onnxruntime.get_available_providers()
|
19 |
+
|
20 |
+
# Pop TensorRT Runtime due to crashing issues
|
21 |
+
# TODO: Add back when TensorRT backend is stable
|
22 |
+
providers = [p for p in providers if p != "TensorrtExecutionProvider"]
|
23 |
+
|
24 |
+
if providers:
|
25 |
+
app_logger.info(
|
26 |
+
"Available providers for ONNXRuntime: %s", ", ".join(providers)
|
27 |
+
)
|
28 |
+
else:
|
29 |
+
app_logger.warning("No available providers for ONNXRuntime")
|
30 |
+
self.encoder_session = onnxruntime.InferenceSession(
|
31 |
+
encoder_model_path, providers=providers
|
32 |
+
)
|
33 |
+
self.encoder_input_name = self.encoder_session.get_inputs()[0].name
|
34 |
+
self.decoder_session = onnxruntime.InferenceSession(
|
35 |
+
decoder_model_path, providers=providers
|
36 |
+
)
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def get_input_points(prompt):
|
40 |
+
"""Get input points"""
|
41 |
+
points = []
|
42 |
+
labels = []
|
43 |
+
for mark in prompt:
|
44 |
+
if mark["type"] == "point":
|
45 |
+
points.append(mark["data"])
|
46 |
+
labels.append(mark["label"])
|
47 |
+
elif mark["type"] == "rectangle":
|
48 |
+
points.append([mark["data"][0], mark["data"][1]]) # top left
|
49 |
+
points.append(
|
50 |
+
[mark["data"][2], mark["data"][3]]
|
51 |
+
) # bottom right
|
52 |
+
labels.append(2)
|
53 |
+
labels.append(3)
|
54 |
+
points, labels = np.array(points), np.array(labels)
|
55 |
+
return points, labels
|
56 |
+
|
57 |
+
def run_encoder(self, encoder_inputs):
|
58 |
+
"""Run encoder"""
|
59 |
+
output = self.encoder_session.run(None, encoder_inputs)
|
60 |
+
image_embedding = output[0]
|
61 |
+
return image_embedding
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int):
|
65 |
+
"""
|
66 |
+
Compute the output size given input size and target long side length.
|
67 |
+
"""
|
68 |
+
scale = long_side_length * 1.0 / max(old_h, old_w)
|
69 |
+
new_h, new_w = old_h * scale, old_w * scale
|
70 |
+
new_w = int(new_w + 0.5)
|
71 |
+
new_h = int(new_h + 0.5)
|
72 |
+
return new_h, new_w
|
73 |
+
|
74 |
+
def apply_coords(self, coords: np.ndarray, original_size, target_length):
|
75 |
+
"""
|
76 |
+
Expects a numpy array of length 2 in the final dimension. Requires the
|
77 |
+
original image size in (H, W) format.
|
78 |
+
"""
|
79 |
+
old_h, old_w = original_size
|
80 |
+
new_h, new_w = self.get_preprocess_shape(
|
81 |
+
original_size[0], original_size[1], target_length
|
82 |
+
)
|
83 |
+
coords = deepcopy(coords).astype(float)
|
84 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
85 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
86 |
+
return coords
|
87 |
+
|
88 |
+
def run_decoder(
|
89 |
+
self, image_embedding, original_size, transform_matrix, prompt
|
90 |
+
):
|
91 |
+
"""Run decoder"""
|
92 |
+
input_points, input_labels = self.get_input_points(prompt)
|
93 |
+
|
94 |
+
# Add a batch index, concatenate a padding point, and transform.
|
95 |
+
onnx_coord = np.concatenate(
|
96 |
+
[input_points, np.array([[0.0, 0.0]])], axis=0
|
97 |
+
)[None, :, :]
|
98 |
+
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
|
99 |
+
None, :
|
100 |
+
].astype(np.float32)
|
101 |
+
onnx_coord = self.apply_coords(
|
102 |
+
onnx_coord, self.input_size, self.target_size
|
103 |
+
).astype(np.float32)
|
104 |
+
|
105 |
+
# Apply the transformation matrix to the coordinates.
|
106 |
+
onnx_coord = np.concatenate(
|
107 |
+
[
|
108 |
+
onnx_coord,
|
109 |
+
np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
|
110 |
+
],
|
111 |
+
axis=2,
|
112 |
+
)
|
113 |
+
onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
|
114 |
+
onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
|
115 |
+
|
116 |
+
# Create an empty mask input and an indicator for no mask.
|
117 |
+
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
|
118 |
+
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
|
119 |
+
|
120 |
+
decoder_inputs = {
|
121 |
+
"image_embeddings": image_embedding,
|
122 |
+
"point_coords": onnx_coord,
|
123 |
+
"point_labels": onnx_label,
|
124 |
+
"mask_input": onnx_mask_input,
|
125 |
+
"has_mask_input": onnx_has_mask_input,
|
126 |
+
"orig_im_size": np.array(self.input_size, dtype=np.float32),
|
127 |
+
}
|
128 |
+
masks, _, _ = self.decoder_session.run(None, decoder_inputs)
|
129 |
+
|
130 |
+
# Transform the masks back to the original image size.
|
131 |
+
inv_transform_matrix = np.linalg.inv(transform_matrix)
|
132 |
+
transformed_masks = self.transform_masks(
|
133 |
+
masks, original_size, inv_transform_matrix
|
134 |
+
)
|
135 |
+
|
136 |
+
return transformed_masks
|
137 |
+
|
138 |
+
@staticmethod
|
139 |
+
def transform_masks(masks, original_size, transform_matrix):
|
140 |
+
"""Transform masks
|
141 |
+
Transform the masks back to the original image size.
|
142 |
+
"""
|
143 |
+
output_masks = []
|
144 |
+
for batch in range(masks.shape[0]):
|
145 |
+
batch_masks = []
|
146 |
+
for mask_id in range(masks.shape[1]):
|
147 |
+
mask = masks[batch, mask_id]
|
148 |
+
mask = cv2.warpAffine(
|
149 |
+
mask,
|
150 |
+
transform_matrix[:2],
|
151 |
+
(original_size[1], original_size[0]),
|
152 |
+
flags=cv2.INTER_LINEAR,
|
153 |
+
)
|
154 |
+
batch_masks.append(mask)
|
155 |
+
output_masks.append(batch_masks)
|
156 |
+
return np.array(output_masks)
|
157 |
+
|
158 |
+
def encode(self, cv_image):
|
159 |
+
"""
|
160 |
+
Calculate embedding and metadata for a single image.
|
161 |
+
"""
|
162 |
+
original_size = cv_image.shape[:2]
|
163 |
+
|
164 |
+
# Calculate a transformation matrix to convert to self.input_size
|
165 |
+
scale_x = self.input_size[1] / cv_image.shape[1]
|
166 |
+
scale_y = self.input_size[0] / cv_image.shape[0]
|
167 |
+
scale = min(scale_x, scale_y)
|
168 |
+
transform_matrix = np.array(
|
169 |
+
[
|
170 |
+
[scale, 0, 0],
|
171 |
+
[0, scale, 0],
|
172 |
+
[0, 0, 1],
|
173 |
+
]
|
174 |
+
)
|
175 |
+
cv_image = cv2.warpAffine(
|
176 |
+
cv_image,
|
177 |
+
transform_matrix[:2],
|
178 |
+
(self.input_size[1], self.input_size[0]),
|
179 |
+
flags=cv2.INTER_LINEAR,
|
180 |
+
)
|
181 |
+
|
182 |
+
encoder_inputs = {
|
183 |
+
self.encoder_input_name: cv_image.astype(np.float32),
|
184 |
+
}
|
185 |
+
image_embedding = self.run_encoder(encoder_inputs)
|
186 |
+
return {
|
187 |
+
"image_embedding": image_embedding,
|
188 |
+
"original_size": original_size,
|
189 |
+
"transform_matrix": transform_matrix,
|
190 |
+
}
|
191 |
+
|
192 |
+
def predict_masks(self, embedding, prompt):
|
193 |
+
"""
|
194 |
+
Predict masks for a single image.
|
195 |
+
"""
|
196 |
+
masks = self.run_decoder(
|
197 |
+
embedding["image_embedding"],
|
198 |
+
embedding["original_size"],
|
199 |
+
embedding["transform_matrix"],
|
200 |
+
prompt,
|
201 |
+
)
|
202 |
+
|
203 |
+
return masks
|
src/prediction_api/samgeo_predictors.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
# Press the green button in the gutter to run the script.
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
|
5 |
-
from src import app_logger
|
6 |
-
from src.utilities.constants import ROOT, MODEL_NAME, ZOOM, SOURCE_TYPE
|
7 |
-
from src.utilities.type_hints import input_floatlist
|
8 |
-
from src.utilities.utilities import get_system_info
|
9 |
-
|
10 |
-
|
11 |
-
def samgeo_fast_predict(
|
12 |
-
bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MODEL_NAME, root_folder: str = ROOT, source_type: str = SOURCE_TYPE, crs="EPSG:4326"
|
13 |
-
) -> dict:
|
14 |
-
import tempfile
|
15 |
-
from samgeo import tms_to_geotiff
|
16 |
-
from samgeo.fast_sam import SamGeo
|
17 |
-
|
18 |
-
try:
|
19 |
-
os.environ['MPLCONFIGDIR'] = root_folder
|
20 |
-
get_system_info()
|
21 |
-
except Exception as e:
|
22 |
-
app_logger.error(f"Error while setting 'MPLCONFIGDIR':{e}.")
|
23 |
-
|
24 |
-
with tempfile.NamedTemporaryFile(prefix=f"{source_type}_", suffix=".tif", dir=root_folder) as image_input_tmp:
|
25 |
-
app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
|
26 |
-
for coord in bbox:
|
27 |
-
app_logger.info(f"coord:{coord}, type:{type(coord)}.")
|
28 |
-
|
29 |
-
# bbox: image input coordinate
|
30 |
-
tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source=source_type, overwrite=True, crs=crs)
|
31 |
-
app_logger.info(f"geotiff created, start to initialize SamGeo instance (read model {model_name} from {root_folder})...")
|
32 |
-
|
33 |
-
predictor = SamGeo(
|
34 |
-
model_type=model_name,
|
35 |
-
checkpoint_dir=root_folder,
|
36 |
-
automatic=False,
|
37 |
-
sam_kwargs=None,
|
38 |
-
)
|
39 |
-
app_logger.info(f"initialized SamGeo instance, start to use SamGeo.set_image({image_input_tmp.name})...")
|
40 |
-
predictor.set_image(image_input_tmp.name)
|
41 |
-
|
42 |
-
with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
|
43 |
-
app_logger.info(f"done set_image, start prediction using {image_output_tmp.name} as output...")
|
44 |
-
predictor.everything_prompt(output=image_output_tmp.name)
|
45 |
-
|
46 |
-
# geotiff to geojson
|
47 |
-
with tempfile.NamedTemporaryFile(prefix="feats_", suffix=".geojson", dir=root_folder) as vector_tmp:
|
48 |
-
app_logger.info(f"done prediction, start conversion SamGeo.tiff_to_geojson({image_output_tmp.name}) => {vector_tmp.name}.")
|
49 |
-
predictor.tiff_to_geojson(image_output_tmp.name, vector_tmp.name, bidx=1)
|
50 |
-
|
51 |
-
app_logger.info(f"start reading geojson {vector_tmp.name}...")
|
52 |
-
with open(vector_tmp.name) as out_gdf:
|
53 |
-
out_gdf_str = out_gdf.read()
|
54 |
-
out_gdf_json = json.loads(out_gdf_str)
|
55 |
-
app_logger.info(f"geojson {vector_tmp.name} string has length: {len(out_gdf_str)}.")
|
56 |
-
return out_gdf_json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utilities/constants.py
CHANGED
@@ -22,7 +22,8 @@ CUSTOM_RESPONSE_MESSAGES = {
|
|
22 |
422: "Missing required parameter",
|
23 |
500: "Internal server error"
|
24 |
}
|
25 |
-
|
|
|
26 |
ZOOM = 13
|
27 |
SOURCE_TYPE = "Satellite"
|
28 |
|
|
|
22 |
422: "Missing required parameter",
|
23 |
500: "Internal server error"
|
24 |
}
|
25 |
+
MODEL_ENCODER_NAME = "mobile_sam.encoder.onnx"
|
26 |
+
MODEL_DECODER_NAME = "sam_vit_h_4b8939.decoder.onnx"
|
27 |
ZOOM = 13
|
28 |
SOURCE_TYPE = "Satellite"
|
29 |
|
src/utilities/serialize.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Serialize objects"""
|
2 |
+
from typing import Mapping
|
3 |
+
|
4 |
+
from src import app_logger
|
5 |
+
from src.utilities.type_hints import ts_dict_str2, ts_dict_str3
|
6 |
+
|
7 |
+
|
8 |
+
def serialize(obj: any, include_none: bool = False) -> object:
|
9 |
+
"""
|
10 |
+
Return the input object into a serializable one
|
11 |
+
|
12 |
+
Args:
|
13 |
+
obj: Object to serialize
|
14 |
+
include_none: bool to indicate if include also keys with None values during dict serialization
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
object: serialized object
|
18 |
+
|
19 |
+
"""
|
20 |
+
return _serialize(obj, include_none)
|
21 |
+
|
22 |
+
|
23 |
+
def _serialize(obj: any, include_none: bool) -> any:
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
primitive = (int, float, str, bool)
|
27 |
+
# print(type(obj))
|
28 |
+
try:
|
29 |
+
if obj is None:
|
30 |
+
return None
|
31 |
+
elif isinstance(obj, np.integer):
|
32 |
+
return int(obj)
|
33 |
+
elif isinstance(obj, np.floating):
|
34 |
+
return float(obj)
|
35 |
+
elif isinstance(obj, np.ndarray):
|
36 |
+
return obj.tolist()
|
37 |
+
elif isinstance(obj, primitive):
|
38 |
+
return obj
|
39 |
+
elif type(obj) is list:
|
40 |
+
return _serialize_list(obj, include_none)
|
41 |
+
elif type(obj) is tuple:
|
42 |
+
return list(obj)
|
43 |
+
elif type(obj) is bytes:
|
44 |
+
return _serialize_bytes(obj)
|
45 |
+
elif isinstance(obj, Exception):
|
46 |
+
return _serialize_exception(obj)
|
47 |
+
# elif isinstance(obj, object):
|
48 |
+
# return _serialize_object(obj, include_none)
|
49 |
+
else:
|
50 |
+
return _serialize_object(obj, include_none)
|
51 |
+
except Exception as e_serialize:
|
52 |
+
app_logger.error(f"e_serialize::{e_serialize}, type_obj:{type(obj)}, obj:{obj}.")
|
53 |
+
return f"object_name:{str(obj)}__object_type_str:{str(type(obj))}."
|
54 |
+
|
55 |
+
|
56 |
+
def _serialize_object(obj: Mapping[any, object], include_none: bool) -> dict[any]:
|
57 |
+
from bson import ObjectId
|
58 |
+
|
59 |
+
res = {}
|
60 |
+
if type(obj) is not dict:
|
61 |
+
keys = [i for i in obj.__dict__.keys() if (getattr(obj, i) is not None) or include_none]
|
62 |
+
else:
|
63 |
+
keys = [i for i in obj.keys() if (obj[i] is not None) or include_none]
|
64 |
+
for key in keys:
|
65 |
+
if type(obj) is not dict:
|
66 |
+
res[key] = _serialize(getattr(obj, key), include_none)
|
67 |
+
elif isinstance(obj[key], ObjectId):
|
68 |
+
continue
|
69 |
+
else:
|
70 |
+
res[key] = _serialize(obj[key], include_none)
|
71 |
+
return res
|
72 |
+
|
73 |
+
|
74 |
+
def _serialize_list(ls: list, include_none: bool) -> list:
|
75 |
+
return [_serialize(elem, include_none) for elem in ls]
|
76 |
+
|
77 |
+
|
78 |
+
def _serialize_bytes(b: bytes) -> ts_dict_str2:
|
79 |
+
import base64
|
80 |
+
encoded = base64.b64encode(b)
|
81 |
+
return {"value": encoded.decode('ascii'), "type": "bytes"}
|
82 |
+
|
83 |
+
|
84 |
+
def _serialize_exception(e: Exception) -> ts_dict_str3:
|
85 |
+
return {"msg": str(e), "type": str(type(e)), **e.__dict__}
|