[refactor] rename predictor.py, move ROOT folder to /tmp to avoid RO filesystem errors
Browse files
src/app.py
CHANGED
@@ -9,8 +9,8 @@ from geojson_pydantic import FeatureCollection, Feature, Polygon
|
|
9 |
from pydantic import BaseModel, ValidationError
|
10 |
|
11 |
from src import app_logger
|
12 |
-
from src.prediction_api.
|
13 |
-
from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, MODEL_NAME, ZOOM
|
14 |
from src.utilities.utilities import base64_decode
|
15 |
|
16 |
PolygonFeatureCollectionModel = FeatureCollection[Feature[Polygon, Dict]]
|
@@ -22,17 +22,18 @@ class LatLngTupleLeaflet(BaseModel):
|
|
22 |
|
23 |
|
24 |
class RequestBody(BaseModel):
|
|
|
25 |
ne: LatLngTupleLeaflet
|
|
|
26 |
sw: LatLngTupleLeaflet
|
27 |
-
model: str = MODEL_NAME
|
28 |
zoom: float = ZOOM
|
29 |
|
30 |
|
31 |
class ResponseBody(BaseModel):
|
32 |
-
request_id: str = None
|
33 |
duration_run: float = None
|
34 |
-
message: str = None
|
35 |
geojson: Dict = None
|
|
|
|
|
36 |
|
37 |
|
38 |
def get_response(status: int, start_time: float, request_id: str, response_body: ResponseBody = None) -> str:
|
@@ -67,15 +68,17 @@ def get_response(status: int, start_time: float, request_id: str, response_body:
|
|
67 |
def get_parsed_bbox_points(request_input: RequestBody) -> Dict:
|
68 |
model_name = request_input["model"] if "model" in request_input else MODEL_NAME
|
69 |
zoom = request_input["zoom"] if "zoom" in request_input else ZOOM
|
|
|
70 |
app_logger.info(f"try to validate input request {request_input}...")
|
71 |
-
request_body = RequestBody(ne=request_input["ne"], sw=request_input["sw"], model=model_name, zoom=zoom)
|
72 |
return {
|
73 |
"bbox": [
|
74 |
request_body.ne.lat, request_body.sw.lat,
|
75 |
request_body.ne.lng, request_body.sw.lng
|
76 |
],
|
77 |
"model": request_body.model,
|
78 |
-
"zoom": request_body.zoom
|
|
|
79 |
}
|
80 |
|
81 |
|
@@ -108,7 +111,9 @@ def lambda_handler(event: dict, context: LambdaContext):
|
|
108 |
try:
|
109 |
body_request = get_parsed_bbox_points(body)
|
110 |
app_logger.info(f"validation ok - body_request:{body_request}, starting prediction...")
|
111 |
-
output_geojson_dict =
|
|
|
|
|
112 |
|
113 |
# raise ValidationError in case this is not a valid geojson by GeoJSON specification rfc7946
|
114 |
PolygonFeatureCollectionModel(**output_geojson_dict)
|
|
|
9 |
from pydantic import BaseModel, ValidationError
|
10 |
|
11 |
from src import app_logger
|
12 |
+
from src.prediction_api.samgeo_predictors import samgeo_fast_predict
|
13 |
+
from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, MODEL_NAME, ZOOM, SOURCE_TYPE
|
14 |
from src.utilities.utilities import base64_decode
|
15 |
|
16 |
PolygonFeatureCollectionModel = FeatureCollection[Feature[Polygon, Dict]]
|
|
|
22 |
|
23 |
|
24 |
class RequestBody(BaseModel):
|
25 |
+
model: str = MODEL_NAME
|
26 |
ne: LatLngTupleLeaflet
|
27 |
+
source_type: str = SOURCE_TYPE
|
28 |
sw: LatLngTupleLeaflet
|
|
|
29 |
zoom: float = ZOOM
|
30 |
|
31 |
|
32 |
class ResponseBody(BaseModel):
|
|
|
33 |
duration_run: float = None
|
|
|
34 |
geojson: Dict = None
|
35 |
+
message: str = None
|
36 |
+
request_id: str = None
|
37 |
|
38 |
|
39 |
def get_response(status: int, start_time: float, request_id: str, response_body: ResponseBody = None) -> str:
|
|
|
68 |
def get_parsed_bbox_points(request_input: RequestBody) -> Dict:
|
69 |
model_name = request_input["model"] if "model" in request_input else MODEL_NAME
|
70 |
zoom = request_input["zoom"] if "zoom" in request_input else ZOOM
|
71 |
+
source_type = request_input["source_type"] if "zoom" in request_input else SOURCE_TYPE
|
72 |
app_logger.info(f"try to validate input request {request_input}...")
|
73 |
+
request_body = RequestBody(ne=request_input["ne"], sw=request_input["sw"], model=model_name, zoom=zoom, source_type=source_type)
|
74 |
return {
|
75 |
"bbox": [
|
76 |
request_body.ne.lat, request_body.sw.lat,
|
77 |
request_body.ne.lng, request_body.sw.lng
|
78 |
],
|
79 |
"model": request_body.model,
|
80 |
+
"zoom": request_body.zoom,
|
81 |
+
"source_type": request_body.source_type
|
82 |
}
|
83 |
|
84 |
|
|
|
111 |
try:
|
112 |
body_request = get_parsed_bbox_points(body)
|
113 |
app_logger.info(f"validation ok - body_request:{body_request}, starting prediction...")
|
114 |
+
output_geojson_dict = samgeo_fast_predict(
|
115 |
+
bbox=body_request["bbox"], model_name=body_request["model"], zoom=body_request["zoom"], source_type=body_request["source_type"]
|
116 |
+
)
|
117 |
|
118 |
# raise ValidationError in case this is not a valid geojson by GeoJSON specification rfc7946
|
119 |
PolygonFeatureCollectionModel(**output_geojson_dict)
|
src/prediction_api/{predictor.py → samgeo_predictors.py}
RENAMED
@@ -2,23 +2,23 @@
|
|
2 |
import json
|
3 |
|
4 |
from src import app_logger
|
5 |
-
from src.utilities.constants import ROOT, MODEL_NAME, ZOOM
|
6 |
from src.utilities.type_hints import input_floatlist, input_floatlist2
|
7 |
|
8 |
|
9 |
-
def
|
10 |
import tempfile
|
11 |
from samgeo import tms_to_geotiff
|
12 |
from samgeo.fast_sam import SamGeo
|
13 |
|
14 |
-
with tempfile.NamedTemporaryFile(prefix="
|
15 |
app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
|
16 |
for coord in bbox:
|
17 |
app_logger.info(f"coord:{coord}, type:{type(coord)}.")
|
18 |
|
19 |
# bbox: image input coordinate
|
20 |
-
tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source=
|
21 |
-
app_logger.info(f"geotiff created, start to initialize
|
22 |
|
23 |
predictor = SamGeo(
|
24 |
model_type=model_name,
|
@@ -26,7 +26,7 @@ def base_predict(bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MO
|
|
26 |
automatic=False,
|
27 |
sam_kwargs=None,
|
28 |
)
|
29 |
-
app_logger.info(f"initialized
|
30 |
predictor.set_image(image_input_tmp.name)
|
31 |
|
32 |
with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
|
|
|
2 |
import json
|
3 |
|
4 |
from src import app_logger
|
5 |
+
from src.utilities.constants import ROOT, MODEL_NAME, ZOOM, SOURCE_TYPE
|
6 |
from src.utilities.type_hints import input_floatlist, input_floatlist2
|
7 |
|
8 |
|
9 |
+
def samgeo_fast_predict(bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MODEL_NAME, root_folder: str = ROOT, source_type: str = SOURCE_TYPE) -> dict:
|
10 |
import tempfile
|
11 |
from samgeo import tms_to_geotiff
|
12 |
from samgeo.fast_sam import SamGeo
|
13 |
|
14 |
+
with tempfile.NamedTemporaryFile(prefix=f"{source_type}_", suffix=".tif", dir=root_folder) as image_input_tmp:
|
15 |
app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
|
16 |
for coord in bbox:
|
17 |
app_logger.info(f"coord:{coord}, type:{type(coord)}.")
|
18 |
|
19 |
# bbox: image input coordinate
|
20 |
+
tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source=source_type, overwrite=True)
|
21 |
+
app_logger.info(f"geotiff created, start to initialize SamGeo instance (read model {model_name} from {root_folder})...")
|
22 |
|
23 |
predictor = SamGeo(
|
24 |
model_type=model_name,
|
|
|
26 |
automatic=False,
|
27 |
sam_kwargs=None,
|
28 |
)
|
29 |
+
app_logger.info(f"initialized SamGeo instance, start to use SamGeo.set_image({image_input_tmp.name})...")
|
30 |
predictor.set_image(image_input_tmp.name)
|
31 |
|
32 |
with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
|
src/utilities/constants.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
CHANNEL_EXAGGERATIONS_LIST = [2.5, 1.1, 2.0]
|
3 |
INPUT_CRS_STRING = "EPSG:4326"
|
4 |
OUTPUT_CRS_STRING = "EPSG:3857"
|
5 |
-
ROOT = "/
|
6 |
NODATA_VALUES = -32768
|
7 |
SKIP_CONDITIONS_LIST = [{"skip_key": "confidence", "skip_value": 0.5, "skip_condition": "major"}]
|
8 |
FEATURE_SQUARE_TEMPLATE = [
|
@@ -24,3 +24,4 @@ CUSTOM_RESPONSE_MESSAGES = {
|
|
24 |
}
|
25 |
MODEL_NAME = "FastSAM-s.pt"
|
26 |
ZOOM = 13
|
|
|
|
2 |
CHANNEL_EXAGGERATIONS_LIST = [2.5, 1.1, 2.0]
|
3 |
INPUT_CRS_STRING = "EPSG:4326"
|
4 |
OUTPUT_CRS_STRING = "EPSG:3857"
|
5 |
+
ROOT = "/tmp"
|
6 |
NODATA_VALUES = -32768
|
7 |
SKIP_CONDITIONS_LIST = [{"skip_key": "confidence", "skip_value": 0.5, "skip_condition": "major"}]
|
8 |
FEATURE_SQUARE_TEMPLATE = [
|
|
|
24 |
}
|
25 |
MODEL_NAME = "FastSAM-s.pt"
|
26 |
ZOOM = 13
|
27 |
+
SOURCE_TYPE = "Satellite"
|