aletrn commited on
Commit
73bb7b5
·
1 Parent(s): 25c63a5

[feat] use temp files on SamGeo prediction

Browse files
Files changed (2) hide show
  1. src/main.py +6 -3
  2. src/prediction_api/predictor.py +34 -34
src/main.py CHANGED
@@ -92,9 +92,12 @@ def samgeo(request_input: Input):
92
  bbox=request_body["bbox"],
93
  point_coords=request_body["point"]
94
  )
95
- body = {"duration_run": time.time() - time_start_run}
96
- app_logger.info(f"body:{body}.")
97
- body["output"] = output
 
 
 
98
  return JSONResponse(status_code=200, content={"body": json.dumps(body)})
99
  except Exception as inference_exception:
100
  home_content = subprocess.run("ls -l /home/user", shell=True, universal_newlines=True, stdout=subprocess.PIPE)
 
92
  bbox=request_body["bbox"],
93
  point_coords=request_body["point"]
94
  )
95
+ duration_run = time.time() - time_start_run
96
+ app_logger.info(f"duration_run:{duration_run}.")
97
+ body = {
98
+ "duration_run": duration_run,
99
+ "output": output
100
+ }
101
  return JSONResponse(status_code=200, content={"body": json.dumps(body)})
102
  except Exception as inference_exception:
103
  home_content = subprocess.run("ls -l /home/user", shell=True, universal_newlines=True, stdout=subprocess.PIPE)
src/prediction_api/predictor.py CHANGED
@@ -9,39 +9,39 @@ from src.utilities.type_hints import input_floatlist, input_floatlist2
9
  def base_predict(
10
  bbox: input_floatlist, point_coords: input_floatlist2, point_crs: str = "EPSG:4326", zoom: float = 16, model_name: str = "vit_h", root_folder: str = ROOT
11
  ) -> str:
 
12
  from samgeo import SamGeo, tms_to_geotiff
13
 
14
- image = f"{root_folder}/satellite.tif"
15
- app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}.")
16
- for coord in bbox:
17
- app_logger.info(f"coord:{coord}, type:{type(coord)}.")
18
- # bbox: image input coordinate
19
- tms_to_geotiff(output=image, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
20
-
21
- app_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
22
-
23
- predictor = SamGeo(
24
- model_type=model_name,
25
- checkpoint_dir=root_folder,
26
- automatic=False,
27
- sam_kwargs=None,
28
- )
29
- app_logger.info(f"initialized samgeo instance, start to set_image {image}...")
30
- predictor.set_image(image)
31
- output_name = f"{root_folder}/output.tif"
32
-
33
- app_logger.info(f"done set_image, start prediction...")
34
- predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=output_name)
35
-
36
- app_logger.info(f"done prediction, start tiff to geojson conversion...")
37
-
38
- # geotiff to geojson
39
- vector = f"{root_folder}/feats.geojson"
40
- predictor.tiff_to_geojson(output_name, vector, bidx=1)
41
- app_logger.info(f"start reading geojson...")
42
-
43
- with open(vector) as out_gdf:
44
- out_gdf_str = out_gdf.read()
45
- out_gdf_json = json.loads(out_gdf_str)
46
- app_logger.info(f"geojson string output length:{len(out_gdf_str)}.")
47
- return out_gdf_json
 
9
  def base_predict(
10
  bbox: input_floatlist, point_coords: input_floatlist2, point_crs: str = "EPSG:4326", zoom: float = 16, model_name: str = "vit_h", root_folder: str = ROOT
11
  ) -> str:
12
+ import tempfile
13
  from samgeo import SamGeo, tms_to_geotiff
14
 
15
+ with tempfile.NamedTemporaryFile(prefix="satellite_", suffix=".tif", dir=root_folder) as image_input_tmp:
16
+ app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
17
+ for coord in bbox:
18
+ app_logger.info(f"coord:{coord}, type:{type(coord)}.")
19
+
20
+ # bbox: image input coordinate
21
+ tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
22
+ app_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
23
+
24
+ predictor = SamGeo(
25
+ model_type=model_name,
26
+ checkpoint_dir=root_folder,
27
+ automatic=False,
28
+ sam_kwargs=None,
29
+ )
30
+ app_logger.info(f"initialized samgeo instance, start to use SamGeo.set_image({image_input_tmp.name})...")
31
+ predictor.set_image(image_input_tmp.name)
32
+
33
+ with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
34
+ app_logger.info(f"done set_image, start prediction using {image_output_tmp.name} as output...")
35
+ predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=image_output_tmp.name)
36
+
37
+ # geotiff to geojson
38
+ with tempfile.NamedTemporaryFile(prefix="feats_", suffix=".geojson", dir=root_folder) as vector_tmp:
39
+ app_logger.info(f"done prediction, start conversion SamGeo.tiff_to_geojson({image_output_tmp.name}) => {vector_tmp.name}.")
40
+ predictor.tiff_to_geojson(image_output_tmp.name, vector_tmp.name, bidx=1)
41
+
42
+ app_logger.info(f"start reading geojson {vector_tmp.name}...")
43
+ with open(vector_tmp.name) as out_gdf:
44
+ out_gdf_str = out_gdf.read()
45
+ out_gdf_json = json.loads(out_gdf_str)
46
+ app_logger.info(f"geojson {vector_tmp.name} string has length: {len(out_gdf_str)}.")
47
+ return out_gdf_json