aletrn committed
Commit fa76f5f · 1 Parent(s): ddb71db

[debug] first test on lambda for fastsam prediction

README.md CHANGED
@@ -1,5 +1,16 @@
  # Segment Geospatial

+ ## todo
+
+ 1. export output to mask: OK local, OK aws lambda
+ 2. resolve model paths: OK local
+ 3. inference:
+ 4. from mask to json (rasterio + geopandas, check for re-projection to EPSG:4326; see the sketch after this diff)
+ 5. check mandatory dependencies
+ 6. check for alternative python interpreters
+
+ ## Build instructions
+
  Build the docker image:

  ```bash
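
Regarding todo item 4 above: a minimal sketch of the mask-to-GeoJSON step, assuming a binary `np.uint8` mask (255 = foreground), the affine transform of the downloaded extent, and tiles in Web Mercator (EPSG:3857). `mask_to_geojson` is a hypothetical helper, not part of this commit:

```python
# Sketch: polygonize a binary mask and re-project to EPSG:4326.
import geopandas as gpd
import numpy as np
from rasterio.features import shapes
from shapely.geometry import shape


def mask_to_geojson(mask: np.ndarray, transform, source_crs: str = "EPSG:3857") -> str:
    # shapes() yields (geometry, value) pairs for connected regions of equal value
    polygons = [
        shape(geom)
        for geom, value in shapes(mask, mask=mask > 0, transform=transform)
        if value > 0
    ]
    gdf = gpd.GeoDataFrame(geometry=polygons, crs=source_crs)
    # GeoJSON consumers generally expect WGS84 coordinates
    return gdf.to_crs("EPSG:4326").to_json()
```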
dockerfiles/dockerfile-lambda-gdal-runner CHANGED
@@ -9,12 +9,15 @@ RUN echo "ENV RIE: $RIE ..."

  # Set working directory to function root directory
  WORKDIR ${LAMBDA_TASK_ROOT}
- COPY requirements.txt ${LAMBDA_TASK_ROOT}/requirements.txt
+ COPY requirements_dev.txt ${LAMBDA_TASK_ROOT}/requirements_dev.txt

- RUN apt update && apt install -y curl python3-pip
+ # avoid segment-geospatial exception caused by missing libGL.so.1 library
+ RUN apt update && apt install -y libgl1 curl python3-pip
+ RUN ls -ld /usr/lib/*linux-gnu/libGL.so* || echo "libGL.so* not found..."
  RUN which python
  RUN python --version
- RUN python -m pip install pillow awslambdaric aws-lambda-powertools httpx jmespath
+ RUN python -m pip install -r ${LAMBDA_TASK_ROOT}/requirements_dev.txt --target ${LAMBDA_TASK_ROOT}
+ # RUN python -m pip install pillow awslambdaric aws-lambda-powertools httpx jmespath --target ${LAMBDA_TASK_ROOT}

  RUN curl -Lo /usr/local/bin/aws-lambda-rie ${RIE}
  RUN chmod +x /usr/local/bin/aws-lambda-rie
dockerfiles/dockerfile-lambda-samgeo-api CHANGED
@@ -7,6 +7,7 @@ ARG PYTHONPATH="${LAMBDA_TASK_ROOT}:${PYTHONPATH}:/usr/local/lib/python3/dist-pa
  # Set working directory to function root directory
  WORKDIR ${LAMBDA_TASK_ROOT}
  COPY ./src ${LAMBDA_TASK_ROOT}/src
+ COPY ./models ${LAMBDA_TASK_ROOT}/models

  RUN ls -l /usr/bin/which
  RUN /usr/bin/which python
@@ -16,8 +17,11 @@ RUN echo "PATH: ${PATH}."
  RUN echo "LAMBDA_TASK_ROOT: ${LAMBDA_TASK_ROOT}."
  RUN ls -l ${LAMBDA_TASK_ROOT}
  RUN ls -ld ${LAMBDA_TASK_ROOT}
+ RUN ls -l ${LAMBDA_TASK_ROOT}/models
  RUN python -c "import sys; print(sys.path)"
  RUN python -c "import osgeo"
+ RUN python -c "import cv2"
+ RUN python -c "import onnxruntime"
  # RUN python -c "import rasterio"
  RUN python -c "import awslambdaric"
  RUN python -m pip list
models/.gitignore ADDED
(empty file)
requirements.txt CHANGED
@@ -1,8 +1,9 @@
  aws-lambda-powertools
  awslambdaric
  bson
- geojson-pydantic
+ httpx
  jmespath
- python-dotenv
- segment-anything-fast
- segment-geospatial
+ numpy
+ onnxruntime
+ opencv-python
+ pillow
requirements_dev.txt CHANGED
@@ -1,7 +1,9 @@
+ aws-lambda-powertools
  awslambdaric
- aws_lambda_powertools
- fastjsonschema
- geojson-pydantic
+ bson
+ httpx
  jmespath
- pydantic
- requests
+ numpy
+ onnxruntime
+ opencv-python
+ pillow
src/__init__.py CHANGED
@@ -1,4 +1,8 @@
  from aws_lambda_powertools import Logger
+ import os
+ from pathlib import Path


+ PROJECT_ROOT_FOLDER = Path(globals().get("__file__", "./_")).absolute().parent.parent
+ MODEL_FOLDER = Path(os.path.join(PROJECT_ROOT_FOLDER, "models"))
  app_logger = Logger()
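
A quick check of the path logic above, as a sketch (concrete values depend on where the package lives): `PROJECT_ROOT_FOLDER` climbs two levels from `src/__init__.py`, so `MODEL_FOLDER` lines up with the `COPY ./models ${LAMBDA_TASK_ROOT}/models` step in the samgeo-api dockerfile.

```python
# Sketch: with src/__init__.py at <root>/src/__init__.py,
# PROJECT_ROOT_FOLDER resolves to <root> (/var/task inside the Lambda image)
# and MODEL_FOLDER to <root>/models.
from src import MODEL_FOLDER, PROJECT_ROOT_FOLDER

print(PROJECT_ROOT_FOLDER)                       # e.g. /var/task
print(MODEL_FOLDER / "mobile_sam.encoder.onnx")  # expected encoder location
```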
src/app.py CHANGED
@@ -1,16 +1,16 @@
  import json
  import time
  from http import HTTPStatus
-
+ from typing import Dict
  from aws_lambda_powertools.event_handler import content_types
  from aws_lambda_powertools.utilities.typing import LambdaContext

  from src import app_logger
- from src.io.tms2geotiff import download_extent
- from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, DEFAULT_TMS
+ from src.prediction_api.predictors import samexporter_predict
+ from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES


- def get_response(status: int, start_time: float, request_id: str, response_body = None) -> str:
+ def get_response(status: int, start_time: float, request_id: str, response_body: Dict = None) -> str:
      """
      Return a response for frontend clients.

@@ -51,10 +51,20 @@ def lambda_handler(event: dict, context: LambdaContext):
      app_logger.info(f"context:{context}...")

      try:
+         """
+         img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], 6)
+         model_path = Path(MODEL_FOLDER) / "mobile_sam.encoder.onnx"
+         model_path_isfile = model_path.is_file()
+         model_path_stats = model_path.stat()
+         app_logger.info(f"model_path:{model_path_isfile}, {model_path_stats}.")
+         """
          pt0 = 45.699, 127.1
          pt1 = 30.1, 148.492
-         img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], 6)
-         body_response = {"geojson": {"img_size": img.size}}
+         bbox = [pt0, pt1]
+         zoom = 6
+         prompt = [{"type": "rectangle", "data": [400, 460, 524, 628]}]
+         body_response = {"geojson": samexporter_predict(bbox, prompt, zoom)}
+         app_logger.info(f"body_response::output:{body_response}.")
          response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
      except Exception as ve:
          app_logger.error(f"validation error:{ve}.")
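
Since the handler currently hard-codes `bbox`, `zoom` and `prompt`, it can be smoke-tested locally with an empty event and a stub context; a sketch, assuming the visible handler body is the only place the event and context are read (`aws_request_id` is the one attribute `get_response` needs):

```python
# Sketch: local smoke test for the handler, outside Lambda/RIE.
from types import SimpleNamespace

from src.app import lambda_handler

stub_context = SimpleNamespace(aws_request_id="local-test")
print(lambda_handler(event={}, context=stub_context))
```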
src/prediction_api/predictors.py ADDED
@@ -0,0 +1,82 @@
+ """Prediction helpers: download a TMS extent and run SAM ONNX inference."""
+ import os
+ import numpy as np
+
+ from src import app_logger, MODEL_FOLDER
+ from src.io.tms2geotiff import download_extent
+ from src.prediction_api.sam_onnx import SegmentAnythingONNX
+ from src.utilities.constants import ROOT, MODEL_ENCODER_NAME, ZOOM, SOURCE_TYPE, DEFAULT_TMS, MODEL_DECODER_NAME
+ from src.utilities.serialize import serialize
+ from src.utilities.type_hints import input_float_tuples
+ from src.utilities.utilities import get_system_info
+
+
+ def zip_arrays(arr1, arr2):
+     arr1_list = arr1.tolist()
+     arr2_list = arr2.tolist()
+     # return {serialize(k): serialize(v) for k, v in zip(arr1_list, arr2_list)}
+     d = {}
+     for n1, n2 in zip(arr1_list, arr2_list):
+         app_logger.info(f"n1:{n1}, type {type(n1)}, n2:{n2}, type {type(n2)}.")
+         n1f = str(n1)
+         n2f = str(n2)
+         app_logger.info(f"n1:{n1}=>{n1f}, n2:{n2}=>{n2f}.")
+         d[n1f] = n2f
+     app_logger.info(f"zipped dict:{d}.")
+     return d
+
+
+ def samexporter_predict(bbox: input_float_tuples, prompt: list[dict], zoom: float = ZOOM) -> dict:
+     import tempfile
+
+     try:
+         os.environ['MPLCONFIGDIR'] = ROOT
+         get_system_info()
+     except Exception as e:
+         app_logger.error(f"Error while setting 'MPLCONFIGDIR':{e}.")
+
+     with tempfile.NamedTemporaryFile(prefix=f"{SOURCE_TYPE}_", suffix=".tif", dir=ROOT) as image_input_tmp:
+         for coord in bbox:
+             app_logger.info(f"bbox coord:{coord}, type:{type(coord)}.")
+         app_logger.info(f"start download_extent using bbox:{bbox}, type:{type(bbox)}, download image...")
+
+         pt0 = bbox[0]
+         pt1 = bbox[1]
+         img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
+
+         app_logger.info(f"img type {type(img)}, matrix type {type(matrix)}.")
+         np_img = np.array(img)
+         app_logger.info(f"np_img type {type(np_img)}.")
+         app_logger.info(f"np_img dtype {np_img.dtype}, shape {np_img.shape}.")
+         app_logger.info(f"geotiff created with size/shape {img.size} and transform matrix {str(matrix)}, start to initialize SamGeo instance:")
+         app_logger.info(f"use ENCODER model {MODEL_ENCODER_NAME} from {MODEL_FOLDER}...")
+         app_logger.info(f"use DECODER model {MODEL_DECODER_NAME} from {MODEL_FOLDER}...")
+
+         model = SegmentAnythingONNX(
+             encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
+             decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
+         )
+         app_logger.info("model instantiated, creating embedding...")
+         embedding = model.encode(np_img)
+         app_logger.info("embedding created, running predict_masks...")
+         prediction_masks = model.predict_masks(embedding, prompt)
+         app_logger.info("predict_masks terminated.")
+         app_logger.info(f"prediction masks shape:{prediction_masks.shape}, {prediction_masks.dtype}.")
+
+         mask = np.zeros((prediction_masks.shape[2], prediction_masks.shape[3]), dtype=np.uint8)
+         for m in prediction_masks[0, :, :, :]:
+             mask[m > 0.0] = 255
+
+         mask_unique_values, mask_unique_values_count = serialize(np.unique(mask, return_counts=True))
+         app_logger.info(f"mask_unique_values:{mask_unique_values}.")
+         app_logger.info(f"mask_unique_values_count:{mask_unique_values_count}.")
+
+         output = {
+             "img_size": serialize(img.size),
+             "mask_unique_values_count": zip_arrays(mask_unique_values, mask_unique_values_count),
+             "masks_dtype": serialize(prediction_masks.dtype),
+             "masks_shape": serialize(prediction_masks.shape),
+             "matrix": serialize(matrix)
+         }
+         app_logger.info(f"output:{output}.")
+         return output
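
The flattening loop above ORs all predicted mask channels into one binary uint8 mask; a tiny self-contained check of that behavior (the shapes here are illustrative, the real masks come from the decoder):

```python
# Sketch: any pixel positive in at least one mask channel becomes 255.
import numpy as np

prediction_masks = np.random.randn(1, 4, 8, 8).astype(np.float32)  # (batch, channels, H, W)
mask = np.zeros(prediction_masks.shape[2:], dtype=np.uint8)
for m in prediction_masks[0, :, :, :]:
    mask[m > 0.0] = 255
# the flattened mask is binary: only 0 and 255 survive
assert set(np.unique(mask).tolist()) <= {0, 255}
```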
src/prediction_api/sam_onnx.py ADDED
@@ -0,0 +1,203 @@
+ from copy import deepcopy
+
+ import cv2
+ import numpy as np
+ import onnxruntime
+
+ from src import app_logger
+
+
+ class SegmentAnythingONNX:
+     """Segmentation model using SegmentAnything"""
+
+     def __init__(self, encoder_model_path, decoder_model_path) -> None:
+         self.target_size = 1024
+         self.input_size = (684, 1024)
+
+         # Load models
+         providers = onnxruntime.get_available_providers()
+
+         # Pop TensorRT Runtime due to crashing issues
+         # TODO: Add back when TensorRT backend is stable
+         providers = [p for p in providers if p != "TensorrtExecutionProvider"]
+
+         if providers:
+             app_logger.info(
+                 "Available providers for ONNXRuntime: %s", ", ".join(providers)
+             )
+         else:
+             app_logger.warning("No available providers for ONNXRuntime")
+         self.encoder_session = onnxruntime.InferenceSession(
+             encoder_model_path, providers=providers
+         )
+         self.encoder_input_name = self.encoder_session.get_inputs()[0].name
+         self.decoder_session = onnxruntime.InferenceSession(
+             decoder_model_path, providers=providers
+         )
+
+     @staticmethod
+     def get_input_points(prompt):
+         """Get input points"""
+         points = []
+         labels = []
+         for mark in prompt:
+             if mark["type"] == "point":
+                 points.append(mark["data"])
+                 labels.append(mark["label"])
+             elif mark["type"] == "rectangle":
+                 points.append([mark["data"][0], mark["data"][1]])  # top left
+                 points.append(
+                     [mark["data"][2], mark["data"][3]]
+                 )  # bottom right
+                 labels.append(2)
+                 labels.append(3)
+         points, labels = np.array(points), np.array(labels)
+         return points, labels
+
+     def run_encoder(self, encoder_inputs):
+         """Run encoder"""
+         output = self.encoder_session.run(None, encoder_inputs)
+         image_embedding = output[0]
+         return image_embedding
+
+     @staticmethod
+     def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int):
+         """
+         Compute the output size given input size and target long side length.
+         """
+         scale = long_side_length * 1.0 / max(old_h, old_w)
+         new_h, new_w = old_h * scale, old_w * scale
+         new_w = int(new_w + 0.5)
+         new_h = int(new_h + 0.5)
+         return new_h, new_w
+
+     def apply_coords(self, coords: np.ndarray, original_size, target_length):
+         """
+         Expects a numpy array of length 2 in the final dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], target_length
+         )
+         coords = deepcopy(coords).astype(float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+
+     def run_decoder(
+         self, image_embedding, original_size, transform_matrix, prompt
+     ):
+         """Run decoder"""
+         input_points, input_labels = self.get_input_points(prompt)
+
+         # Add a batch index, concatenate a padding point, and transform.
+         onnx_coord = np.concatenate(
+             [input_points, np.array([[0.0, 0.0]])], axis=0
+         )[None, :, :]
+         onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
+             None, :
+         ].astype(np.float32)
+         onnx_coord = self.apply_coords(
+             onnx_coord, self.input_size, self.target_size
+         ).astype(np.float32)
+
+         # Apply the transformation matrix to the coordinates.
+         onnx_coord = np.concatenate(
+             [
+                 onnx_coord,
+                 np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
+             ],
+             axis=2,
+         )
+         onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
+         onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
+
+         # Create an empty mask input and an indicator for no mask.
+         onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+         onnx_has_mask_input = np.zeros(1, dtype=np.float32)
+
+         decoder_inputs = {
+             "image_embeddings": image_embedding,
+             "point_coords": onnx_coord,
+             "point_labels": onnx_label,
+             "mask_input": onnx_mask_input,
+             "has_mask_input": onnx_has_mask_input,
+             "orig_im_size": np.array(self.input_size, dtype=np.float32),
+         }
+         masks, _, _ = self.decoder_session.run(None, decoder_inputs)
+
+         # Transform the masks back to the original image size.
+         inv_transform_matrix = np.linalg.inv(transform_matrix)
+         transformed_masks = self.transform_masks(
+             masks, original_size, inv_transform_matrix
+         )
+
+         return transformed_masks
+
+     @staticmethod
+     def transform_masks(masks, original_size, transform_matrix):
+         """Transform masks
+         Transform the masks back to the original image size.
+         """
+         output_masks = []
+         for batch in range(masks.shape[0]):
+             batch_masks = []
+             for mask_id in range(masks.shape[1]):
+                 mask = masks[batch, mask_id]
+                 mask = cv2.warpAffine(
+                     mask,
+                     transform_matrix[:2],
+                     (original_size[1], original_size[0]),
+                     flags=cv2.INTER_LINEAR,
+                 )
+                 batch_masks.append(mask)
+             output_masks.append(batch_masks)
+         return np.array(output_masks)
+
+     def encode(self, cv_image):
+         """
+         Calculate embedding and metadata for a single image.
+         """
+         original_size = cv_image.shape[:2]
+
+         # Calculate a transformation matrix to convert to self.input_size
+         scale_x = self.input_size[1] / cv_image.shape[1]
+         scale_y = self.input_size[0] / cv_image.shape[0]
+         scale = min(scale_x, scale_y)
+         transform_matrix = np.array(
+             [
+                 [scale, 0, 0],
+                 [0, scale, 0],
+                 [0, 0, 1],
+             ]
+         )
+         cv_image = cv2.warpAffine(
+             cv_image,
+             transform_matrix[:2],
+             (self.input_size[1], self.input_size[0]),
+             flags=cv2.INTER_LINEAR,
+         )
+
+         encoder_inputs = {
+             self.encoder_input_name: cv_image.astype(np.float32),
+         }
+         image_embedding = self.run_encoder(encoder_inputs)
+         return {
+             "image_embedding": image_embedding,
+             "original_size": original_size,
+             "transform_matrix": transform_matrix,
+         }
+
+     def predict_masks(self, embedding, prompt):
+         """
+         Predict masks for a single image.
+         """
+         masks = self.run_decoder(
+             embedding["image_embedding"],
+             embedding["original_size"],
+             embedding["transform_matrix"],
+             prompt,
+         )
+
+         return masks
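
For reference, the prompt format consumed by `get_input_points` above, shown as a sketch (point coordinates are in image pixels; foreground/background labels follow the usual SAM convention, and a rectangle is expanded internally into two points carrying the special box labels 2 and 3):

```python
# Prompt marks accepted by SegmentAnythingONNX.get_input_points:
prompt = [
    # a clicked point; label 1 = foreground, 0 = background
    {"type": "point", "data": [512, 300], "label": 1},
    # a box as [x0, y0, x1, y1]; expanded to top-left (label 2)
    # and bottom-right (label 3) points
    {"type": "rectangle", "data": [400, 460, 524, 628]},
]
```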
src/prediction_api/samgeo_predictors.py DELETED
@@ -1,56 +0,0 @@
- # Press the green button in the gutter to run the script.
- import json
- import os
-
- from src import app_logger
- from src.utilities.constants import ROOT, MODEL_NAME, ZOOM, SOURCE_TYPE
- from src.utilities.type_hints import input_floatlist
- from src.utilities.utilities import get_system_info
-
-
- def samgeo_fast_predict(
-         bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MODEL_NAME, root_folder: str = ROOT, source_type: str = SOURCE_TYPE, crs="EPSG:4326"
- ) -> dict:
-     import tempfile
-     from samgeo import tms_to_geotiff
-     from samgeo.fast_sam import SamGeo
-
-     try:
-         os.environ['MPLCONFIGDIR'] = root_folder
-         get_system_info()
-     except Exception as e:
-         app_logger.error(f"Error while setting 'MPLCONFIGDIR':{e}.")
-
-     with tempfile.NamedTemporaryFile(prefix=f"{source_type}_", suffix=".tif", dir=root_folder) as image_input_tmp:
-         app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
-         for coord in bbox:
-             app_logger.info(f"coord:{coord}, type:{type(coord)}.")
-
-         # bbox: image input coordinate
-         tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source=source_type, overwrite=True, crs=crs)
-         app_logger.info(f"geotiff created, start to initialize SamGeo instance (read model {model_name} from {root_folder})...")
-
-         predictor = SamGeo(
-             model_type=model_name,
-             checkpoint_dir=root_folder,
-             automatic=False,
-             sam_kwargs=None,
-         )
-         app_logger.info(f"initialized SamGeo instance, start to use SamGeo.set_image({image_input_tmp.name})...")
-         predictor.set_image(image_input_tmp.name)
-
-         with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
-             app_logger.info(f"done set_image, start prediction using {image_output_tmp.name} as output...")
-             predictor.everything_prompt(output=image_output_tmp.name)
-
-             # geotiff to geojson
-             with tempfile.NamedTemporaryFile(prefix="feats_", suffix=".geojson", dir=root_folder) as vector_tmp:
-                 app_logger.info(f"done prediction, start conversion SamGeo.tiff_to_geojson({image_output_tmp.name}) => {vector_tmp.name}.")
-                 predictor.tiff_to_geojson(image_output_tmp.name, vector_tmp.name, bidx=1)
-
-                 app_logger.info(f"start reading geojson {vector_tmp.name}...")
-                 with open(vector_tmp.name) as out_gdf:
-                     out_gdf_str = out_gdf.read()
-                     out_gdf_json = json.loads(out_gdf_str)
-                     app_logger.info(f"geojson {vector_tmp.name} string has length: {len(out_gdf_str)}.")
-                     return out_gdf_json
src/utilities/constants.py CHANGED
@@ -22,7 +22,8 @@ CUSTOM_RESPONSE_MESSAGES = {
      422: "Missing required parameter",
      500: "Internal server error"
  }
- MODEL_NAME = "FastSAM-s.pt"
+ MODEL_ENCODER_NAME = "mobile_sam.encoder.onnx"
+ MODEL_DECODER_NAME = "sam_vit_h_4b8939.decoder.onnx"
  ZOOM = 13
  SOURCE_TYPE = "Satellite"

src/utilities/serialize.py ADDED
@@ -0,0 +1,85 @@
+ """Serialize objects"""
+ from typing import Any, Mapping
+
+ from src import app_logger
+ from src.utilities.type_hints import ts_dict_str2, ts_dict_str3
+
+
+ def serialize(obj: Any, include_none: bool = False) -> object:
+     """
+     Convert the input object into a serializable one.
+
+     Args:
+         obj: object to serialize
+         include_none: whether to also keep keys with None values during dict serialization
+
+     Returns:
+         object: serialized object
+
+     """
+     return _serialize(obj, include_none)
+
+
+ def _serialize(obj: Any, include_none: bool) -> Any:
+     import numpy as np
+
+     primitive = (int, float, str, bool)
+     # print(type(obj))
+     try:
+         if obj is None:
+             return None
+         elif isinstance(obj, np.integer):
+             return int(obj)
+         elif isinstance(obj, np.floating):
+             return float(obj)
+         elif isinstance(obj, np.ndarray):
+             return obj.tolist()
+         elif isinstance(obj, primitive):
+             return obj
+         elif type(obj) is list:
+             return _serialize_list(obj, include_none)
+         elif type(obj) is tuple:
+             return list(obj)
+         elif type(obj) is bytes:
+             return _serialize_bytes(obj)
+         elif isinstance(obj, Exception):
+             return _serialize_exception(obj)
+         # elif isinstance(obj, object):
+         #     return _serialize_object(obj, include_none)
+         else:
+             return _serialize_object(obj, include_none)
+     except Exception as e_serialize:
+         app_logger.error(f"e_serialize::{e_serialize}, type_obj:{type(obj)}, obj:{obj}.")
+         return f"object_name:{str(obj)}__object_type_str:{str(type(obj))}."
+
+
+ def _serialize_object(obj: Mapping[Any, object], include_none: bool) -> dict:
+     from bson import ObjectId
+
+     res = {}
+     if type(obj) is not dict:
+         keys = [i for i in obj.__dict__.keys() if (getattr(obj, i) is not None) or include_none]
+     else:
+         keys = [i for i in obj.keys() if (obj[i] is not None) or include_none]
+     for key in keys:
+         if type(obj) is not dict:
+             res[key] = _serialize(getattr(obj, key), include_none)
+         elif isinstance(obj[key], ObjectId):
+             continue
+         else:
+             res[key] = _serialize(obj[key], include_none)
+     return res
+
+
+ def _serialize_list(ls: list, include_none: bool) -> list:
+     return [_serialize(elem, include_none) for elem in ls]
+
+
+ def _serialize_bytes(b: bytes) -> ts_dict_str2:
+     import base64
+     encoded = base64.b64encode(b)
+     return {"value": encoded.decode('ascii'), "type": "bytes"}
+
+
+ def _serialize_exception(e: Exception) -> ts_dict_str3:
+     return {"msg": str(e), "type": str(type(e)), **e.__dict__}
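
A usage sketch for `serialize`, exercising the numpy, bytes and None branches above (the values shown in comments are the expected conversions):

```python
import json

import numpy as np

from src.utilities.serialize import serialize

payload = {
    "arr": np.arange(3),        # np.ndarray -> [0, 1, 2]
    "scalar": np.float32(0.5),  # np.floating -> 0.5
    "blob": b"\x00\x01",        # bytes -> {"value": "AAE=", "type": "bytes"}
    "skipped": None,            # dropped unless include_none=True
}
print(json.dumps(serialize(payload)))
```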