aletrn commited on
Commit
e945b1f
·
1 Parent(s): 26b4f04

[refactor] rename predictor.py, move ROOT folder to /tmp to avoid RO filesystem errors

Browse files
src/app.py CHANGED
@@ -9,8 +9,8 @@ from geojson_pydantic import FeatureCollection, Feature, Polygon
9
  from pydantic import BaseModel, ValidationError
10
 
11
  from src import app_logger
12
- from src.prediction_api.predictor import base_predict
13
- from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, MODEL_NAME, ZOOM
14
  from src.utilities.utilities import base64_decode
15
 
16
  PolygonFeatureCollectionModel = FeatureCollection[Feature[Polygon, Dict]]
@@ -22,17 +22,18 @@ class LatLngTupleLeaflet(BaseModel):
22
 
23
 
24
  class RequestBody(BaseModel):
 
25
  ne: LatLngTupleLeaflet
 
26
  sw: LatLngTupleLeaflet
27
- model: str = MODEL_NAME
28
  zoom: float = ZOOM
29
 
30
 
31
  class ResponseBody(BaseModel):
32
- request_id: str = None
33
  duration_run: float = None
34
- message: str = None
35
  geojson: Dict = None
 
 
36
 
37
 
38
  def get_response(status: int, start_time: float, request_id: str, response_body: ResponseBody = None) -> str:
@@ -67,15 +68,17 @@ def get_response(status: int, start_time: float, request_id: str, response_body:
67
  def get_parsed_bbox_points(request_input: RequestBody) -> Dict:
68
  model_name = request_input["model"] if "model" in request_input else MODEL_NAME
69
  zoom = request_input["zoom"] if "zoom" in request_input else ZOOM
 
70
  app_logger.info(f"try to validate input request {request_input}...")
71
- request_body = RequestBody(ne=request_input["ne"], sw=request_input["sw"], model=model_name, zoom=zoom)
72
  return {
73
  "bbox": [
74
  request_body.ne.lat, request_body.sw.lat,
75
  request_body.ne.lng, request_body.sw.lng
76
  ],
77
  "model": request_body.model,
78
- "zoom": request_body.zoom
 
79
  }
80
 
81
 
@@ -108,7 +111,9 @@ def lambda_handler(event: dict, context: LambdaContext):
108
  try:
109
  body_request = get_parsed_bbox_points(body)
110
  app_logger.info(f"validation ok - body_request:{body_request}, starting prediction...")
111
- output_geojson_dict = base_predict(bbox=body_request["bbox"], model_name=body_request["model"], zoom=body_request["zoom"])
 
 
112
 
113
  # raise ValidationError in case this is not a valid geojson by GeoJSON specification rfc7946
114
  PolygonFeatureCollectionModel(**output_geojson_dict)
 
9
  from pydantic import BaseModel, ValidationError
10
 
11
  from src import app_logger
12
+ from src.prediction_api.samgeo_predictors import samgeo_fast_predict
13
+ from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, MODEL_NAME, ZOOM, SOURCE_TYPE
14
  from src.utilities.utilities import base64_decode
15
 
16
  PolygonFeatureCollectionModel = FeatureCollection[Feature[Polygon, Dict]]
 
22
 
23
 
24
  class RequestBody(BaseModel):
25
+ model: str = MODEL_NAME
26
  ne: LatLngTupleLeaflet
27
+ source_type: str = SOURCE_TYPE
28
  sw: LatLngTupleLeaflet
 
29
  zoom: float = ZOOM
30
 
31
 
32
  class ResponseBody(BaseModel):
 
33
  duration_run: float = None
 
34
  geojson: Dict = None
35
+ message: str = None
36
+ request_id: str = None
37
 
38
 
39
  def get_response(status: int, start_time: float, request_id: str, response_body: ResponseBody = None) -> str:
 
68
  def get_parsed_bbox_points(request_input: RequestBody) -> Dict:
69
  model_name = request_input["model"] if "model" in request_input else MODEL_NAME
70
  zoom = request_input["zoom"] if "zoom" in request_input else ZOOM
71
+ source_type = request_input["source_type"] if "zoom" in request_input else SOURCE_TYPE
72
  app_logger.info(f"try to validate input request {request_input}...")
73
+ request_body = RequestBody(ne=request_input["ne"], sw=request_input["sw"], model=model_name, zoom=zoom, source_type=source_type)
74
  return {
75
  "bbox": [
76
  request_body.ne.lat, request_body.sw.lat,
77
  request_body.ne.lng, request_body.sw.lng
78
  ],
79
  "model": request_body.model,
80
+ "zoom": request_body.zoom,
81
+ "source_type": request_body.source_type
82
  }
83
 
84
 
 
111
  try:
112
  body_request = get_parsed_bbox_points(body)
113
  app_logger.info(f"validation ok - body_request:{body_request}, starting prediction...")
114
+ output_geojson_dict = samgeo_fast_predict(
115
+ bbox=body_request["bbox"], model_name=body_request["model"], zoom=body_request["zoom"], source_type=body_request["source_type"]
116
+ )
117
 
118
  # raise ValidationError in case this is not a valid geojson by GeoJSON specification rfc7946
119
  PolygonFeatureCollectionModel(**output_geojson_dict)
src/prediction_api/{predictor.py → samgeo_predictors.py} RENAMED
@@ -2,23 +2,23 @@
2
  import json
3
 
4
  from src import app_logger
5
- from src.utilities.constants import ROOT, MODEL_NAME, ZOOM
6
  from src.utilities.type_hints import input_floatlist, input_floatlist2
7
 
8
 
9
- def base_predict(bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MODEL_NAME, root_folder: str = ROOT) -> dict:
10
  import tempfile
11
  from samgeo import tms_to_geotiff
12
  from samgeo.fast_sam import SamGeo
13
 
14
- with tempfile.NamedTemporaryFile(prefix="satellite_", suffix=".tif", dir=root_folder) as image_input_tmp:
15
  app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
16
  for coord in bbox:
17
  app_logger.info(f"coord:{coord}, type:{type(coord)}.")
18
 
19
  # bbox: image input coordinate
20
- tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
21
- app_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
22
 
23
  predictor = SamGeo(
24
  model_type=model_name,
@@ -26,7 +26,7 @@ def base_predict(bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MO
26
  automatic=False,
27
  sam_kwargs=None,
28
  )
29
- app_logger.info(f"initialized samgeo instance, start to use SamGeo.set_image({image_input_tmp.name})...")
30
  predictor.set_image(image_input_tmp.name)
31
 
32
  with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
 
2
  import json
3
 
4
  from src import app_logger
5
+ from src.utilities.constants import ROOT, MODEL_NAME, ZOOM, SOURCE_TYPE
6
  from src.utilities.type_hints import input_floatlist, input_floatlist2
7
 
8
 
9
+ def samgeo_fast_predict(bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MODEL_NAME, root_folder: str = ROOT, source_type: str = SOURCE_TYPE) -> dict:
10
  import tempfile
11
  from samgeo import tms_to_geotiff
12
  from samgeo.fast_sam import SamGeo
13
 
14
+ with tempfile.NamedTemporaryFile(prefix=f"{source_type}_", suffix=".tif", dir=root_folder) as image_input_tmp:
15
  app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
16
  for coord in bbox:
17
  app_logger.info(f"coord:{coord}, type:{type(coord)}.")
18
 
19
  # bbox: image input coordinate
20
+ tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source=source_type, overwrite=True)
21
+ app_logger.info(f"geotiff created, start to initialize SamGeo instance (read model {model_name} from {root_folder})...")
22
 
23
  predictor = SamGeo(
24
  model_type=model_name,
 
26
  automatic=False,
27
  sam_kwargs=None,
28
  )
29
+ app_logger.info(f"initialized SamGeo instance, start to use SamGeo.set_image({image_input_tmp.name})...")
30
  predictor.set_image(image_input_tmp.name)
31
 
32
  with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
src/utilities/constants.py CHANGED
@@ -2,7 +2,7 @@
2
  CHANNEL_EXAGGERATIONS_LIST = [2.5, 1.1, 2.0]
3
  INPUT_CRS_STRING = "EPSG:4326"
4
  OUTPUT_CRS_STRING = "EPSG:3857"
5
- ROOT = "/home/user"
6
  NODATA_VALUES = -32768
7
  SKIP_CONDITIONS_LIST = [{"skip_key": "confidence", "skip_value": 0.5, "skip_condition": "major"}]
8
  FEATURE_SQUARE_TEMPLATE = [
@@ -24,3 +24,4 @@ CUSTOM_RESPONSE_MESSAGES = {
24
  }
25
  MODEL_NAME = "FastSAM-s.pt"
26
  ZOOM = 13
 
 
2
  CHANNEL_EXAGGERATIONS_LIST = [2.5, 1.1, 2.0]
3
  INPUT_CRS_STRING = "EPSG:4326"
4
  OUTPUT_CRS_STRING = "EPSG:3857"
5
+ ROOT = "/tmp"
6
  NODATA_VALUES = -32768
7
  SKIP_CONDITIONS_LIST = [{"skip_key": "confidence", "skip_value": 0.5, "skip_condition": "major"}]
8
  FEATURE_SQUARE_TEMPLATE = [
 
24
  }
25
  MODEL_NAME = "FastSAM-s.pt"
26
  ZOOM = 13
27
+ SOURCE_TYPE = "Satellite"