[refactor] don't import complete modules, use only "from" syntax
Browse files- src/__init__.py +0 -1
- src/app.py +2 -2
- src/io/coordinates_pixel_conversion.py +8 -6
- src/io/geo_helpers.py +2 -5
- src/io/lambda_helpers.py +12 -9
- src/prediction_api/predictors.py +3 -3
- src/prediction_api/sam_onnx.py +31 -32
- src/utilities/serialize.py +4 -4
- src/utilities/utilities.py +8 -8
src/__init__.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
"""Get machine learning predictions from geodata raster images"""
|
| 2 |
from aws_lambda_powertools import Logger
|
| 3 |
-
import os
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
from src.utilities.constants import SERVICE_NAME
|
|
|
|
| 1 |
"""Get machine learning predictions from geodata raster images"""
|
| 2 |
from aws_lambda_powertools import Logger
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
from src.utilities.constants import SERVICE_NAME
|
src/app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
"""Lambda entry point"""
|
| 2 |
-
import time
|
| 3 |
from http import HTTPStatus
|
| 4 |
from typing import Dict
|
| 5 |
|
|
@@ -24,8 +23,9 @@ def lambda_handler(event: Dict, context: LambdaContext) -> str:
|
|
| 24 |
json response from get_response() function
|
| 25 |
|
| 26 |
"""
|
|
|
|
| 27 |
app_logger.info(f"start with aws_request_id:{context.aws_request_id}.")
|
| 28 |
-
start_time = time
|
| 29 |
|
| 30 |
if "version" in event:
|
| 31 |
app_logger.info(f"event version: {event['version']}.")
|
|
|
|
| 1 |
"""Lambda entry point"""
|
|
|
|
| 2 |
from http import HTTPStatus
|
| 3 |
from typing import Dict
|
| 4 |
|
|
|
|
| 23 |
json response from get_response() function
|
| 24 |
|
| 25 |
"""
|
| 26 |
+
from time import time
|
| 27 |
app_logger.info(f"start with aws_request_id:{context.aws_request_id}.")
|
| 28 |
+
start_time = time()
|
| 29 |
|
| 30 |
if "version" in event:
|
| 31 |
app_logger.info(f"event version: {event['version']}.")
|
src/io/coordinates_pixel_conversion.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
"""functions useful to convert to/from latitude-longitude coordinates to pixel image coordinates"""
|
| 2 |
-
import math
|
| 3 |
-
|
| 4 |
from src import app_logger
|
| 5 |
from src.utilities.constants import TILE_SIZE
|
| 6 |
from src.utilities.type_hints import ImagePixelCoordinates
|
|
@@ -8,17 +6,19 @@ from src.utilities.type_hints import LatLngDict
|
|
| 8 |
|
| 9 |
|
| 10 |
def _get_latlng2pixel_projection(latlng: LatLngDict) -> ImagePixelCoordinates:
|
|
|
|
|
|
|
| 11 |
app_logger.debug(f"latlng: {type(latlng)}, value:{latlng}.")
|
| 12 |
app_logger.debug(f'latlng lat: {type(latlng.lat)}, value:{latlng.lat}.')
|
| 13 |
app_logger.debug(f'latlng lng: {type(latlng.lng)}, value:{latlng.lng}.')
|
| 14 |
try:
|
| 15 |
-
sin_y: float =
|
| 16 |
app_logger.debug(f"sin_y, #1:{sin_y}.")
|
| 17 |
sin_y = min(max(sin_y, -0.9999), 0.9999)
|
| 18 |
app_logger.debug(f"sin_y, #2:{sin_y}.")
|
| 19 |
x = TILE_SIZE * (0.5 + latlng.lng / 360)
|
| 20 |
app_logger.debug(f"x:{x}.")
|
| 21 |
-
y = TILE_SIZE * (0.5 -
|
| 22 |
app_logger.debug(f"y:{y}.")
|
| 23 |
|
| 24 |
return {"x": x, "y": y}
|
|
@@ -28,14 +28,16 @@ def _get_latlng2pixel_projection(latlng: LatLngDict) -> ImagePixelCoordinates:
|
|
| 28 |
|
| 29 |
|
| 30 |
def _get_point_latlng_to_pixel_coordinates(latlng: LatLngDict, zoom: int | float) -> ImagePixelCoordinates:
|
|
|
|
|
|
|
| 31 |
try:
|
| 32 |
world_coordinate: ImagePixelCoordinates = _get_latlng2pixel_projection(latlng)
|
| 33 |
app_logger.debug(f"world_coordinate:{world_coordinate}.")
|
| 34 |
scale: int = pow(2, zoom)
|
| 35 |
app_logger.debug(f"scale:{scale}.")
|
| 36 |
return ImagePixelCoordinates(
|
| 37 |
-
x=
|
| 38 |
-
y=
|
| 39 |
)
|
| 40 |
except Exception as e_format_latlng_to_pixel_coordinates:
|
| 41 |
app_logger.error(f'format_latlng_to_pixel_coordinates:{e_format_latlng_to_pixel_coordinates}.')
|
|
|
|
| 1 |
"""functions useful to convert to/from latitude-longitude coordinates to pixel image coordinates"""
|
|
|
|
|
|
|
| 2 |
from src import app_logger
|
| 3 |
from src.utilities.constants import TILE_SIZE
|
| 4 |
from src.utilities.type_hints import ImagePixelCoordinates
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def _get_latlng2pixel_projection(latlng: LatLngDict) -> ImagePixelCoordinates:
|
| 9 |
+
from math import log, pi, sin
|
| 10 |
+
|
| 11 |
app_logger.debug(f"latlng: {type(latlng)}, value:{latlng}.")
|
| 12 |
app_logger.debug(f'latlng lat: {type(latlng.lat)}, value:{latlng.lat}.')
|
| 13 |
app_logger.debug(f'latlng lng: {type(latlng.lng)}, value:{latlng.lng}.')
|
| 14 |
try:
|
| 15 |
+
sin_y: float = sin(latlng.lat * pi / 180)
|
| 16 |
app_logger.debug(f"sin_y, #1:{sin_y}.")
|
| 17 |
sin_y = min(max(sin_y, -0.9999), 0.9999)
|
| 18 |
app_logger.debug(f"sin_y, #2:{sin_y}.")
|
| 19 |
x = TILE_SIZE * (0.5 + latlng.lng / 360)
|
| 20 |
app_logger.debug(f"x:{x}.")
|
| 21 |
+
y = TILE_SIZE * (0.5 - log((1 + sin_y) / (1 - sin_y)) / (4 * pi))
|
| 22 |
app_logger.debug(f"y:{y}.")
|
| 23 |
|
| 24 |
return {"x": x, "y": y}
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def _get_point_latlng_to_pixel_coordinates(latlng: LatLngDict, zoom: int | float) -> ImagePixelCoordinates:
|
| 31 |
+
from math import floor
|
| 32 |
+
|
| 33 |
try:
|
| 34 |
world_coordinate: ImagePixelCoordinates = _get_latlng2pixel_projection(latlng)
|
| 35 |
app_logger.debug(f"world_coordinate:{world_coordinate}.")
|
| 36 |
scale: int = pow(2, zoom)
|
| 37 |
app_logger.debug(f"scale:{scale}.")
|
| 38 |
return ImagePixelCoordinates(
|
| 39 |
+
x=floor(world_coordinate["x"] * scale),
|
| 40 |
+
y=floor(world_coordinate["y"] * scale)
|
| 41 |
)
|
| 42 |
except Exception as e_format_latlng_to_pixel_coordinates:
|
| 43 |
app_logger.error(f'format_latlng_to_pixel_coordinates:{e_format_latlng_to_pixel_coordinates}.')
|
src/io/geo_helpers.py
CHANGED
|
@@ -1,9 +1,6 @@
|
|
| 1 |
"""handle geo-referenced raster images"""
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
from typing import Dict
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
from affine import Affine
|
|
|
|
| 7 |
|
| 8 |
from src import app_logger
|
| 9 |
from src.utilities.type_hints import list_float, tuple_float, dict_str_int
|
|
@@ -45,7 +42,7 @@ def get_affine_transform_from_gdal(matrix_source_coefficients: list_float or tup
|
|
| 45 |
return Affine.from_gdal(*matrix_source_coefficients)
|
| 46 |
|
| 47 |
|
| 48 |
-
def get_vectorized_raster_as_geojson(mask:
|
| 49 |
"""
|
| 50 |
Get shapes and values of connected regions in a dataset or array
|
| 51 |
|
|
|
|
| 1 |
"""handle geo-referenced raster images"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from affine import Affine
|
| 3 |
+
from numpy import ndarray as np_ndarray
|
| 4 |
|
| 5 |
from src import app_logger
|
| 6 |
from src.utilities.type_hints import list_float, tuple_float, dict_str_int
|
|
|
|
| 42 |
return Affine.from_gdal(*matrix_source_coefficients)
|
| 43 |
|
| 44 |
|
| 45 |
+
def get_vectorized_raster_as_geojson(mask: np_ndarray, matrix: tuple_float) -> dict_str_int:
|
| 46 |
"""
|
| 47 |
Get shapes and values of connected regions in a dataset or array
|
| 48 |
|
src/io/lambda_helpers.py
CHANGED
|
@@ -1,7 +1,4 @@
|
|
| 1 |
"""lambda helper functions"""
|
| 2 |
-
import json
|
| 3 |
-
import logging
|
| 4 |
-
import time
|
| 5 |
from typing import Dict
|
| 6 |
from aws_lambda_powertools.event_handler import content_types
|
| 7 |
|
|
@@ -26,19 +23,22 @@ def get_response(status: int, start_time: float, request_id: str, response_body:
|
|
| 26 |
json response
|
| 27 |
|
| 28 |
"""
|
|
|
|
|
|
|
|
|
|
| 29 |
app_logger.debug(f"response_body:{response_body}.")
|
| 30 |
-
response_body["duration_run"] = time
|
| 31 |
response_body["message"] = CUSTOM_RESPONSE_MESSAGES[status]
|
| 32 |
response_body["request_id"] = request_id
|
| 33 |
|
| 34 |
response = {
|
| 35 |
"statusCode": status,
|
| 36 |
"header": {"Content-Type": content_types.APPLICATION_JSON},
|
| 37 |
-
"body":
|
| 38 |
"isBase64Encoded": False
|
| 39 |
}
|
| 40 |
app_logger.debug(f"response type:{type(response)} => {response}.")
|
| 41 |
-
return
|
| 42 |
|
| 43 |
|
| 44 |
def get_parsed_bbox_points(request_input: RawRequestInput) -> Dict:
|
|
@@ -98,7 +98,10 @@ def get_parsed_request_body(event: Dict) -> RawRequestInput:
|
|
| 98 |
Returns:
|
| 99 |
parsed request input
|
| 100 |
"""
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
| 102 |
try:
|
| 103 |
raw_body = event["body"]
|
| 104 |
except Exception as e_constants1:
|
|
@@ -108,12 +111,12 @@ def get_parsed_request_body(event: Dict) -> RawRequestInput:
|
|
| 108 |
if isinstance(raw_body, str):
|
| 109 |
body_decoded_str = base64_decode(raw_body)
|
| 110 |
app_logger.debug(f"body_decoded_str: {type(body_decoded_str)}, {body_decoded_str}...")
|
| 111 |
-
raw_body =
|
| 112 |
app_logger.info(f"body, #2: {type(raw_body)}, {raw_body}...")
|
| 113 |
|
| 114 |
parsed_body = RawRequestInput.model_validate(raw_body)
|
| 115 |
log_level = "DEBUG" if parsed_body.debug else "INFO"
|
| 116 |
app_logger.setLevel(log_level)
|
| 117 |
-
app_logger.warning(f"set log level to {
|
| 118 |
|
| 119 |
return parsed_body
|
|
|
|
| 1 |
"""lambda helper functions"""
|
|
|
|
|
|
|
|
|
|
| 2 |
from typing import Dict
|
| 3 |
from aws_lambda_powertools.event_handler import content_types
|
| 4 |
|
|
|
|
| 23 |
json response
|
| 24 |
|
| 25 |
"""
|
| 26 |
+
from json import dumps
|
| 27 |
+
from time import time
|
| 28 |
+
|
| 29 |
app_logger.debug(f"response_body:{response_body}.")
|
| 30 |
+
response_body["duration_run"] = time() - start_time
|
| 31 |
response_body["message"] = CUSTOM_RESPONSE_MESSAGES[status]
|
| 32 |
response_body["request_id"] = request_id
|
| 33 |
|
| 34 |
response = {
|
| 35 |
"statusCode": status,
|
| 36 |
"header": {"Content-Type": content_types.APPLICATION_JSON},
|
| 37 |
+
"body": dumps(response_body),
|
| 38 |
"isBase64Encoded": False
|
| 39 |
}
|
| 40 |
app_logger.debug(f"response type:{type(response)} => {response}.")
|
| 41 |
+
return dumps(response)
|
| 42 |
|
| 43 |
|
| 44 |
def get_parsed_bbox_points(request_input: RawRequestInput) -> Dict:
|
|
|
|
| 98 |
Returns:
|
| 99 |
parsed request input
|
| 100 |
"""
|
| 101 |
+
from json import dumps, loads
|
| 102 |
+
from logging import getLevelName
|
| 103 |
+
|
| 104 |
+
app_logger.info(f"event:{dumps(event)}...")
|
| 105 |
try:
|
| 106 |
raw_body = event["body"]
|
| 107 |
except Exception as e_constants1:
|
|
|
|
| 111 |
if isinstance(raw_body, str):
|
| 112 |
body_decoded_str = base64_decode(raw_body)
|
| 113 |
app_logger.debug(f"body_decoded_str: {type(body_decoded_str)}, {body_decoded_str}...")
|
| 114 |
+
raw_body = loads(body_decoded_str)
|
| 115 |
app_logger.info(f"body, #2: {type(raw_body)}, {raw_body}...")
|
| 116 |
|
| 117 |
parsed_body = RawRequestInput.model_validate(raw_body)
|
| 118 |
log_level = "DEBUG" if parsed_body.debug else "INFO"
|
| 119 |
app_logger.setLevel(log_level)
|
| 120 |
+
app_logger.warning(f"set log level to {getLevelName(app_logger.log_level)}.")
|
| 121 |
|
| 122 |
return parsed_body
|
src/prediction_api/predictors.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""functions using machine learning instance model(s)"""
|
| 2 |
from PIL.Image import Image
|
| 3 |
-
|
| 4 |
|
| 5 |
from src import app_logger, MODEL_FOLDER
|
| 6 |
from src.io.geo_helpers import get_vectorized_raster_as_geojson, get_affine_transform_from_gdal
|
|
@@ -80,7 +80,7 @@ def get_raster_inference(
|
|
| 80 |
Returns:
|
| 81 |
raster prediction mask, prediction number
|
| 82 |
"""
|
| 83 |
-
np_img =
|
| 84 |
app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
|
| 85 |
app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
|
| 86 |
try:
|
|
@@ -95,7 +95,7 @@ def get_raster_inference(
|
|
| 95 |
len_inference_out = len(inference_out[0, :, :, :])
|
| 96 |
app_logger.info(f"Created {len_inference_out} prediction_masks,"
|
| 97 |
f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")
|
| 98 |
-
mask =
|
| 99 |
for n, m in enumerate(inference_out[0, :, :, :]):
|
| 100 |
app_logger.debug(f"{n}th of prediction_masks shape {inference_out.shape}"
|
| 101 |
f" => mask shape:{mask.shape}, {mask.dtype}.")
|
|
|
|
| 1 |
"""functions using machine learning instance model(s)"""
|
| 2 |
from PIL.Image import Image
|
| 3 |
+
from numpy import array as np_array, uint8, zeros
|
| 4 |
|
| 5 |
from src import app_logger, MODEL_FOLDER
|
| 6 |
from src.io.geo_helpers import get_vectorized_raster_as_geojson, get_affine_transform_from_gdal
|
|
|
|
| 80 |
Returns:
|
| 81 |
raster prediction mask, prediction number
|
| 82 |
"""
|
| 83 |
+
np_img = np_array(img)
|
| 84 |
app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
|
| 85 |
app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
|
| 86 |
try:
|
|
|
|
| 95 |
len_inference_out = len(inference_out[0, :, :, :])
|
| 96 |
app_logger.info(f"Created {len_inference_out} prediction_masks,"
|
| 97 |
f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")
|
| 98 |
+
mask = zeros((inference_out.shape[2], inference_out.shape[3]), dtype=uint8)
|
| 99 |
for n, m in enumerate(inference_out[0, :, :, :]):
|
| 100 |
app_logger.debug(f"{n}th of prediction_masks shape {inference_out.shape}"
|
| 101 |
f" => mask shape:{mask.shape}, {mask.dtype}.")
|
src/prediction_api/sam_onnx.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
Define a machine learning model executed by ONNX Runtime (https://
|
| 3 |
for Segment Anything (https://segment-anything.com).
|
| 4 |
Modified from https://github.com/vietanhdev/samexporter/
|
| 5 |
|
|
@@ -24,10 +24,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
| 24 |
SOFTWARE.
|
| 25 |
"""
|
| 26 |
from copy import deepcopy
|
| 27 |
-
|
| 28 |
-
import
|
| 29 |
-
import
|
| 30 |
-
import onnxruntime
|
| 31 |
|
| 32 |
from src import app_logger
|
| 33 |
|
|
@@ -40,7 +39,7 @@ class SegmentAnythingONNX:
|
|
| 40 |
self.input_size = (684, 1024)
|
| 41 |
|
| 42 |
# Load models
|
| 43 |
-
providers =
|
| 44 |
|
| 45 |
# Pop TensorRT Runtime due to crashing issues
|
| 46 |
# TODO: Add back when TensorRT backend is stable
|
|
@@ -52,11 +51,11 @@ class SegmentAnythingONNX:
|
|
| 52 |
)
|
| 53 |
else:
|
| 54 |
app_logger.warning("No available providers for ONNXRuntime")
|
| 55 |
-
self.encoder_session =
|
| 56 |
encoder_model_path, providers=providers
|
| 57 |
)
|
| 58 |
self.encoder_input_name = self.encoder_session.get_inputs()[0].name
|
| 59 |
-
self.decoder_session =
|
| 60 |
decoder_model_path, providers=providers
|
| 61 |
)
|
| 62 |
|
|
@@ -76,7 +75,7 @@ class SegmentAnythingONNX:
|
|
| 76 |
) # bottom right
|
| 77 |
labels.append(2)
|
| 78 |
labels.append(3)
|
| 79 |
-
points, labels =
|
| 80 |
return points, labels
|
| 81 |
|
| 82 |
def run_encoder(self, encoder_inputs):
|
|
@@ -96,9 +95,9 @@ class SegmentAnythingONNX:
|
|
| 96 |
new_h = int(new_h + 0.5)
|
| 97 |
return new_h, new_w
|
| 98 |
|
| 99 |
-
def apply_coords(self, coords:
|
| 100 |
"""
|
| 101 |
-
Expects a numpy
|
| 102 |
original image size in (H, W) format.
|
| 103 |
"""
|
| 104 |
old_h, old_w = original_size
|
|
@@ -117,30 +116,30 @@ class SegmentAnythingONNX:
|
|
| 117 |
input_points, input_labels = self.get_input_points(prompt)
|
| 118 |
|
| 119 |
# Add a batch index, concatenate a padding point, and transform.
|
| 120 |
-
onnx_coord =
|
| 121 |
-
[input_points,
|
| 122 |
)[None, :, :]
|
| 123 |
-
onnx_label =
|
| 124 |
None, :
|
| 125 |
-
].astype(
|
| 126 |
onnx_coord = self.apply_coords(
|
| 127 |
onnx_coord, self.input_size, self.target_size
|
| 128 |
-
).astype(
|
| 129 |
|
| 130 |
# Apply the transformation matrix to the coordinates.
|
| 131 |
-
onnx_coord =
|
| 132 |
[
|
| 133 |
onnx_coord,
|
| 134 |
-
|
| 135 |
],
|
| 136 |
axis=2,
|
| 137 |
)
|
| 138 |
-
onnx_coord =
|
| 139 |
-
onnx_coord = onnx_coord[:, :, :2].astype(
|
| 140 |
|
| 141 |
# Create an empty mask input and an indicator for no mask.
|
| 142 |
-
onnx_mask_input =
|
| 143 |
-
onnx_has_mask_input =
|
| 144 |
|
| 145 |
decoder_inputs = {
|
| 146 |
"image_embeddings": image_embedding,
|
|
@@ -148,12 +147,12 @@ class SegmentAnythingONNX:
|
|
| 148 |
"point_labels": onnx_label,
|
| 149 |
"mask_input": onnx_mask_input,
|
| 150 |
"has_mask_input": onnx_has_mask_input,
|
| 151 |
-
"orig_im_size":
|
| 152 |
}
|
| 153 |
masks, _, _ = self.decoder_session.run(None, decoder_inputs)
|
| 154 |
|
| 155 |
# Transform the masks back to the original image size.
|
| 156 |
-
inv_transform_matrix =
|
| 157 |
transformed_masks = self.transform_masks(
|
| 158 |
masks, original_size, inv_transform_matrix
|
| 159 |
)
|
|
@@ -175,11 +174,11 @@ class SegmentAnythingONNX:
|
|
| 175 |
app_logger.debug(f"mask_shape transform_masks:{mask.shape}, dtype:{mask.dtype}.")
|
| 176 |
except Exception as e_mask_shape_transform_masks:
|
| 177 |
app_logger.error(f"e_mask_shape_transform_masks:{e_mask_shape_transform_masks}.")
|
| 178 |
-
mask =
|
| 179 |
mask,
|
| 180 |
transform_matrix[:2],
|
| 181 |
(original_size[1], original_size[0]),
|
| 182 |
-
flags=
|
| 183 |
)
|
| 184 |
except Exception as e_warp_affine1:
|
| 185 |
app_logger.error(f"e_warp_affine1 mask shape:{mask.shape}, dtype:{mask.dtype}.")
|
|
@@ -188,7 +187,7 @@ class SegmentAnythingONNX:
|
|
| 188 |
raise e_warp_affine1
|
| 189 |
batch_masks.append(mask)
|
| 190 |
output_masks.append(batch_masks)
|
| 191 |
-
return
|
| 192 |
|
| 193 |
def encode(self, cv_image):
|
| 194 |
"""
|
|
@@ -200,7 +199,7 @@ class SegmentAnythingONNX:
|
|
| 200 |
scale_x = self.input_size[1] / cv_image.shape[1]
|
| 201 |
scale_y = self.input_size[0] / cv_image.shape[0]
|
| 202 |
scale = min(scale_x, scale_y)
|
| 203 |
-
transform_matrix =
|
| 204 |
[
|
| 205 |
[scale, 0, 0],
|
| 206 |
[0, scale, 0],
|
|
@@ -208,22 +207,22 @@ class SegmentAnythingONNX:
|
|
| 208 |
]
|
| 209 |
)
|
| 210 |
try:
|
| 211 |
-
cv_image =
|
| 212 |
cv_image,
|
| 213 |
transform_matrix[:2],
|
| 214 |
(self.input_size[1], self.input_size[0]),
|
| 215 |
-
flags=
|
| 216 |
)
|
| 217 |
except Exception as e_warp_affine2:
|
| 218 |
app_logger.error(f"e_warp_affine2:{e_warp_affine2}.")
|
| 219 |
-
np_cv_image =
|
| 220 |
app_logger.error(f"e_warp_affine2 cv_image shape:{np_cv_image.shape}, dtype:{np_cv_image.dtype}.")
|
| 221 |
app_logger.error(f"e_warp_affine2 transform_matrix:{transform_matrix}, [:2] {transform_matrix[:2]}")
|
| 222 |
app_logger.error(f"e_warp_affine2 self.input_size:{self.input_size}.")
|
| 223 |
raise e_warp_affine2
|
| 224 |
|
| 225 |
encoder_inputs = {
|
| 226 |
-
self.encoder_input_name: cv_image.astype(
|
| 227 |
}
|
| 228 |
image_embedding = self.run_encoder(encoder_inputs)
|
| 229 |
return {
|
|
|
|
| 1 |
"""
|
| 2 |
+
Define a machine learning model executed by ONNX Runtime (https://ai/)
|
| 3 |
for Segment Anything (https://segment-anything.com).
|
| 4 |
Modified from https://github.com/vietanhdev/samexporter/
|
| 5 |
|
|
|
|
| 24 |
SOFTWARE.
|
| 25 |
"""
|
| 26 |
from copy import deepcopy
|
| 27 |
+
from numpy import array as np_array, concatenate, float32, linalg, matmul, ndarray, ones, zeros
|
| 28 |
+
from cv2 import INTER_LINEAR, warpAffine
|
| 29 |
+
from onnxruntime import get_available_providers, InferenceSession
|
|
|
|
| 30 |
|
| 31 |
from src import app_logger
|
| 32 |
|
|
|
|
| 39 |
self.input_size = (684, 1024)
|
| 40 |
|
| 41 |
# Load models
|
| 42 |
+
providers = get_available_providers()
|
| 43 |
|
| 44 |
# Pop TensorRT Runtime due to crashing issues
|
| 45 |
# TODO: Add back when TensorRT backend is stable
|
|
|
|
| 51 |
)
|
| 52 |
else:
|
| 53 |
app_logger.warning("No available providers for ONNXRuntime")
|
| 54 |
+
self.encoder_session = InferenceSession(
|
| 55 |
encoder_model_path, providers=providers
|
| 56 |
)
|
| 57 |
self.encoder_input_name = self.encoder_session.get_inputs()[0].name
|
| 58 |
+
self.decoder_session = InferenceSession(
|
| 59 |
decoder_model_path, providers=providers
|
| 60 |
)
|
| 61 |
|
|
|
|
| 75 |
) # bottom right
|
| 76 |
labels.append(2)
|
| 77 |
labels.append(3)
|
| 78 |
+
points, labels = np_array(points), np_array(labels)
|
| 79 |
return points, labels
|
| 80 |
|
| 81 |
def run_encoder(self, encoder_inputs):
|
|
|
|
| 95 |
new_h = int(new_h + 0.5)
|
| 96 |
return new_h, new_w
|
| 97 |
|
| 98 |
+
def apply_coords(self, coords: ndarray, original_size, target_length):
|
| 99 |
"""
|
| 100 |
+
Expects a numpy np_array of length 2 in the final dimension. Requires the
|
| 101 |
original image size in (H, W) format.
|
| 102 |
"""
|
| 103 |
old_h, old_w = original_size
|
|
|
|
| 116 |
input_points, input_labels = self.get_input_points(prompt)
|
| 117 |
|
| 118 |
# Add a batch index, concatenate a padding point, and transform.
|
| 119 |
+
onnx_coord = concatenate(
|
| 120 |
+
[input_points, np_array([[0.0, 0.0]])], axis=0
|
| 121 |
)[None, :, :]
|
| 122 |
+
onnx_label = concatenate([input_labels, np_array([-1])], axis=0)[
|
| 123 |
None, :
|
| 124 |
+
].astype(float32)
|
| 125 |
onnx_coord = self.apply_coords(
|
| 126 |
onnx_coord, self.input_size, self.target_size
|
| 127 |
+
).astype(float32)
|
| 128 |
|
| 129 |
# Apply the transformation matrix to the coordinates.
|
| 130 |
+
onnx_coord = concatenate(
|
| 131 |
[
|
| 132 |
onnx_coord,
|
| 133 |
+
ones((1, onnx_coord.shape[1], 1), dtype=float32),
|
| 134 |
],
|
| 135 |
axis=2,
|
| 136 |
)
|
| 137 |
+
onnx_coord = matmul(onnx_coord, transform_matrix.T)
|
| 138 |
+
onnx_coord = onnx_coord[:, :, :2].astype(float32)
|
| 139 |
|
| 140 |
# Create an empty mask input and an indicator for no mask.
|
| 141 |
+
onnx_mask_input = zeros((1, 1, 256, 256), dtype=float32)
|
| 142 |
+
onnx_has_mask_input = zeros(1, dtype=float32)
|
| 143 |
|
| 144 |
decoder_inputs = {
|
| 145 |
"image_embeddings": image_embedding,
|
|
|
|
| 147 |
"point_labels": onnx_label,
|
| 148 |
"mask_input": onnx_mask_input,
|
| 149 |
"has_mask_input": onnx_has_mask_input,
|
| 150 |
+
"orig_im_size": np_array(self.input_size, dtype=float32),
|
| 151 |
}
|
| 152 |
masks, _, _ = self.decoder_session.run(None, decoder_inputs)
|
| 153 |
|
| 154 |
# Transform the masks back to the original image size.
|
| 155 |
+
inv_transform_matrix = linalg.inv(transform_matrix)
|
| 156 |
transformed_masks = self.transform_masks(
|
| 157 |
masks, original_size, inv_transform_matrix
|
| 158 |
)
|
|
|
|
| 174 |
app_logger.debug(f"mask_shape transform_masks:{mask.shape}, dtype:{mask.dtype}.")
|
| 175 |
except Exception as e_mask_shape_transform_masks:
|
| 176 |
app_logger.error(f"e_mask_shape_transform_masks:{e_mask_shape_transform_masks}.")
|
| 177 |
+
mask = warpAffine(
|
| 178 |
mask,
|
| 179 |
transform_matrix[:2],
|
| 180 |
(original_size[1], original_size[0]),
|
| 181 |
+
flags=INTER_LINEAR,
|
| 182 |
)
|
| 183 |
except Exception as e_warp_affine1:
|
| 184 |
app_logger.error(f"e_warp_affine1 mask shape:{mask.shape}, dtype:{mask.dtype}.")
|
|
|
|
| 187 |
raise e_warp_affine1
|
| 188 |
batch_masks.append(mask)
|
| 189 |
output_masks.append(batch_masks)
|
| 190 |
+
return np_array(output_masks)
|
| 191 |
|
| 192 |
def encode(self, cv_image):
|
| 193 |
"""
|
|
|
|
| 199 |
scale_x = self.input_size[1] / cv_image.shape[1]
|
| 200 |
scale_y = self.input_size[0] / cv_image.shape[0]
|
| 201 |
scale = min(scale_x, scale_y)
|
| 202 |
+
transform_matrix = np_array(
|
| 203 |
[
|
| 204 |
[scale, 0, 0],
|
| 205 |
[0, scale, 0],
|
|
|
|
| 207 |
]
|
| 208 |
)
|
| 209 |
try:
|
| 210 |
+
cv_image = warpAffine(
|
| 211 |
cv_image,
|
| 212 |
transform_matrix[:2],
|
| 213 |
(self.input_size[1], self.input_size[0]),
|
| 214 |
+
flags=INTER_LINEAR,
|
| 215 |
)
|
| 216 |
except Exception as e_warp_affine2:
|
| 217 |
app_logger.error(f"e_warp_affine2:{e_warp_affine2}.")
|
| 218 |
+
np_cv_image = np_array(cv_image)
|
| 219 |
app_logger.error(f"e_warp_affine2 cv_image shape:{np_cv_image.shape}, dtype:{np_cv_image.dtype}.")
|
| 220 |
app_logger.error(f"e_warp_affine2 transform_matrix:{transform_matrix}, [:2] {transform_matrix[:2]}")
|
| 221 |
app_logger.error(f"e_warp_affine2 self.input_size:{self.input_size}.")
|
| 222 |
raise e_warp_affine2
|
| 223 |
|
| 224 |
encoder_inputs = {
|
| 225 |
+
self.encoder_input_name: cv_image.astype(float32),
|
| 226 |
}
|
| 227 |
image_embedding = self.run_encoder(encoder_inputs)
|
| 228 |
return {
|
src/utilities/serialize.py
CHANGED
|
@@ -20,18 +20,18 @@ def serialize(obj: any, include_none: bool = False):
|
|
| 20 |
|
| 21 |
|
| 22 |
def _serialize(obj: any, include_none: bool):
|
| 23 |
-
|
| 24 |
|
| 25 |
primitive = (int, float, str, bool)
|
| 26 |
# print(type(obj))
|
| 27 |
try:
|
| 28 |
if obj is None:
|
| 29 |
return None
|
| 30 |
-
elif isinstance(obj,
|
| 31 |
return int(obj)
|
| 32 |
-
elif isinstance(obj,
|
| 33 |
return float(obj)
|
| 34 |
-
elif isinstance(obj,
|
| 35 |
return obj.tolist()
|
| 36 |
elif isinstance(obj, primitive):
|
| 37 |
return obj
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def _serialize(obj: any, include_none: bool):
|
| 23 |
+
from numpy import ndarray as np_ndarray, floating as np_floating, integer as np_integer
|
| 24 |
|
| 25 |
primitive = (int, float, str, bool)
|
| 26 |
# print(type(obj))
|
| 27 |
try:
|
| 28 |
if obj is None:
|
| 29 |
return None
|
| 30 |
+
elif isinstance(obj, np_integer):
|
| 31 |
return int(obj)
|
| 32 |
+
elif isinstance(obj, np_floating):
|
| 33 |
return float(obj)
|
| 34 |
+
elif isinstance(obj, np_ndarray):
|
| 35 |
return obj.tolist()
|
| 36 |
elif isinstance(obj, primitive):
|
| 37 |
return obj
|
src/utilities/utilities.py
CHANGED
|
@@ -66,26 +66,26 @@ def hash_calculate(arr) -> str or bytes:
|
|
| 66 |
Returns:
|
| 67 |
computed hash from input variable
|
| 68 |
"""
|
| 69 |
-
import
|
| 70 |
-
import numpy as np
|
| 71 |
from base64 import b64encode
|
|
|
|
| 72 |
|
| 73 |
-
if isinstance(arr,
|
| 74 |
-
hash_fn =
|
| 75 |
elif isinstance(arr, dict):
|
| 76 |
import json
|
| 77 |
|
| 78 |
serialized = serialize(arr)
|
| 79 |
variable_to_hash = json.dumps(serialized, sort_keys=True).encode('utf-8')
|
| 80 |
-
hash_fn =
|
| 81 |
elif isinstance(arr, str):
|
| 82 |
try:
|
| 83 |
-
hash_fn =
|
| 84 |
except TypeError:
|
| 85 |
app_logger.warning(f"TypeError, re-try encoding arg:{arr},type:{type(arr)}.")
|
| 86 |
-
hash_fn =
|
| 87 |
elif isinstance(arr, bytes):
|
| 88 |
-
hash_fn =
|
| 89 |
else:
|
| 90 |
raise ValueError(f"variable 'arr':{arr} not yet handled.")
|
| 91 |
return b64encode(hash_fn.digest())
|
|
|
|
| 66 |
Returns:
|
| 67 |
computed hash from input variable
|
| 68 |
"""
|
| 69 |
+
from hashlib import sha256
|
|
|
|
| 70 |
from base64 import b64encode
|
| 71 |
+
from numpy import ndarray as np_ndarray
|
| 72 |
|
| 73 |
+
if isinstance(arr, np_ndarray):
|
| 74 |
+
hash_fn = sha256(arr.data)
|
| 75 |
elif isinstance(arr, dict):
|
| 76 |
import json
|
| 77 |
|
| 78 |
serialized = serialize(arr)
|
| 79 |
variable_to_hash = json.dumps(serialized, sort_keys=True).encode('utf-8')
|
| 80 |
+
hash_fn = sha256(variable_to_hash)
|
| 81 |
elif isinstance(arr, str):
|
| 82 |
try:
|
| 83 |
+
hash_fn = sha256(arr)
|
| 84 |
except TypeError:
|
| 85 |
app_logger.warning(f"TypeError, re-try encoding arg:{arr},type:{type(arr)}.")
|
| 86 |
+
hash_fn = sha256(arr.encode('utf-8'))
|
| 87 |
elif isinstance(arr, bytes):
|
| 88 |
+
hash_fn = sha256(arr)
|
| 89 |
else:
|
| 90 |
raise ValueError(f"variable 'arr':{arr} not yet handled.")
|
| 91 |
return b64encode(hash_fn.digest())
|