aletrn commited on
Commit
14effdf
·
1 Parent(s): 59ebac7

[fix] reformat code, use input from request

Browse files
.dockerignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ venv/
2
+ *.pyc
3
+ __cache__
4
+ .idea
5
+ tmp/
6
+ .env*
dockerfiles/dockerfile-base-webserver ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11
2
+
3
+ WORKDIR /code
4
+
5
+ RUN which python
6
+ RUN python --version
7
+ RUN python -m pip install --no-cache-dir fastapi uvicorn
8
+
9
+ RUN useradd -m -u 1000 user
10
+
11
+ USER user
12
+
13
+ ENV HOME=/home/user \
14
+ PATH=/home/user/.local/bin:$PATH
15
+
16
+ WORKDIR $HOME/app
17
+
18
+ RUN ls -l ${HOME}/
19
+ COPY --chown=user src $HOME/app/src
20
+ COPY --chown=user static $HOME/app/static
21
+
22
+ CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
Dockerfile → dockerfiles/dockerfile-samgeo RENAMED
File without changes
src/main.py CHANGED
@@ -6,29 +6,42 @@ from fastapi.responses import FileResponse, JSONResponse
6
  from fastapi.staticfiles import StaticFiles
7
  from pydantic import BaseModel
8
 
 
9
  from src.utilities.utilities import setup_logging
10
 
11
 
12
  app = FastAPI()
13
- local_logger = setup_logging()
14
 
15
 
16
  class Input(BaseModel):
17
- name: str
18
- bbox: List[float]
19
- points_coords: List[List[float]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  @app.post("/post_test")
23
- async def post_test(input: Input) -> JSONResponse:
24
- bbox = input.bbox
25
- name = input.name
26
- points_coords = input.points_coords
27
  return JSONResponse(
28
  status_code=200,
29
- content={
30
- "msg": name, "bbox": bbox, "points_coords": points_coords
31
- }
32
  )
33
 
34
 
@@ -38,7 +51,7 @@ async def hello() -> JSONResponse:
38
 
39
 
40
  @app.post("/infer_samgeo")
41
- def samgeo():
42
  import subprocess
43
 
44
  from src.prediction_api.predictor import base_predict
@@ -52,15 +65,16 @@ def samgeo():
52
  time_start_run = time.time()
53
  # debug = True
54
  # local_logger = setup_logging(debug)
55
- message = "point_coords_segmentation"
56
- bbox = [-122.1497, 37.6311, -122.1203, 37.6458]
57
- point_coords = [[-122.1419, 37.6383]]
58
  try:
59
- output = base_predict(bbox=bbox, point_coords=point_coords)
 
 
 
60
 
61
  duration_run = time.time() - time_start_run
62
  body = {
63
- "message": message,
64
  "duration_run": duration_run,
65
  # "request_id": request_id
66
  }
@@ -80,6 +94,7 @@ def samgeo():
80
 
81
  @app.exception_handler(RequestValidationError)
82
  async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
 
83
  local_logger.error(f"exception errors: {exc.errors()}.")
84
  local_logger.error(f"exception body: {exc.body}.")
85
  headers = request.headers.items()
@@ -94,6 +109,7 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
94
 
95
  @app.exception_handler(HTTPException)
96
  async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
 
97
  local_logger.error(f"exception: {str(exc)}.")
98
  headers = request.headers.items()
99
  local_logger.error(f'request header: {dict(headers)}.' )
 
6
  from fastapi.staticfiles import StaticFiles
7
  from pydantic import BaseModel
8
 
9
+ from src.utilities.type_hints import input_floatlist, input_floatlist2
10
  from src.utilities.utilities import setup_logging
11
 
12
 
13
  app = FastAPI()
 
14
 
15
 
16
  class Input(BaseModel):
17
+ x1: float
18
+ y1: float
19
+ x2: float
20
+ y2: float
21
+ x: float
22
+ y: float
23
+
24
+
25
+ class BBoxWithPoint(BaseModel):
26
+ bbox: input_floatlist
27
+ point: input_floatlist2
28
+
29
+
30
+ def get_parsed_bbox_points(request_input: Input) -> BBoxWithPoint:
31
+ return {
32
+ "bbox": [
33
+ request_input.x1, request_input.x2,
34
+ request_input.y1, request_input.y2
35
+ ],
36
+ "point": [[request_input.x, request_input.y]]
37
+ }
38
 
39
 
40
  @app.post("/post_test")
41
+ async def post_test(request_input: Input) -> JSONResponse:
 
 
 
42
  return JSONResponse(
43
  status_code=200,
44
+ content=get_parsed_bbox_points(request_input)
 
 
45
  )
46
 
47
 
 
51
 
52
 
53
  @app.post("/infer_samgeo")
54
+ def samgeo(request_input: Input):
55
  import subprocess
56
 
57
  from src.prediction_api.predictor import base_predict
 
65
  time_start_run = time.time()
66
  # debug = True
67
  # local_logger = setup_logging(debug)
68
+ request_body = get_parsed_bbox_points(request_input)
69
+ local_logger.info(f"request_body:{request_body}.")
 
70
  try:
71
+ output = base_predict(
72
+ bbox=request_body["bbox"],
73
+ point_coords=request_body["point"]
74
+ )
75
 
76
  duration_run = time.time() - time_start_run
77
  body = {
 
78
  "duration_run": duration_run,
79
  # "request_id": request_id
80
  }
 
94
 
95
  @app.exception_handler(RequestValidationError)
96
  async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
97
+ local_logger = setup_logging()
98
  local_logger.error(f"exception errors: {exc.errors()}.")
99
  local_logger.error(f"exception body: {exc.body}.")
100
  headers = request.headers.items()
 
109
 
110
  @app.exception_handler(HTTPException)
111
  async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
112
+ local_logger = setup_logging()
113
  local_logger.error(f"exception: {str(exc)}.")
114
  headers = request.headers.items()
115
  local_logger.error(f'request header: {dict(headers)}.' )
src/prediction_api/predictor.py CHANGED
@@ -2,21 +2,26 @@
2
  import json
3
 
4
  from src.utilities.constants import ROOT
 
5
  from src.utilities.utilities import setup_logging
6
 
7
-
8
  local_logger = setup_logging()
9
 
10
 
11
- def base_predict(bbox, point_coords, point_crs="EPSG:4326", zoom=16, model_name:str="vit_h", root_folder:str=ROOT) -> str:
 
 
12
  from samgeo import SamGeo, tms_to_geotiff
13
 
14
  image = f"{root_folder}/satellite.tif"
15
- local_logger.info("start tms_to_geotiff")
 
 
16
  # bbox: image input coordinate
17
  tms_to_geotiff(output=image, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
18
 
19
  local_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
 
20
  predictor = SamGeo(
21
  model_type=model_name,
22
  checkpoint_dir=root_folder,
@@ -26,10 +31,10 @@ def base_predict(bbox, point_coords, point_crs="EPSG:4326", zoom=16, model_name:
26
  local_logger.info(f"initialized samgeo instance, start to set_image {image}...")
27
  predictor.set_image(image)
28
  output_name = f"{root_folder}/output.tif"
29
-
30
  local_logger.info(f"done set_image, start prediction...")
31
  predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=output_name)
32
-
33
  local_logger.info(f"done prediction, start tiff to geojson conversion...")
34
 
35
  # geotiff to geojson
 
2
  import json
3
 
4
  from src.utilities.constants import ROOT
5
+ from src.utilities.type_hints import input_floatlist, input_floatlist2
6
  from src.utilities.utilities import setup_logging
7
 
 
8
  local_logger = setup_logging()
9
 
10
 
11
+ def base_predict(
12
+ bbox: input_floatlist, point_coords: input_floatlist2, point_crs: str = "EPSG:4326", zoom: float = 16, model_name: str = "vit_h", root_folder: str = ROOT
13
+ ) -> str:
14
  from samgeo import SamGeo, tms_to_geotiff
15
 
16
  image = f"{root_folder}/satellite.tif"
17
+ local_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}.")
18
+ for coord in bbox:
19
+ local_logger.info(f"coord:{coord}, type:{type(coord)}.")
20
  # bbox: image input coordinate
21
  tms_to_geotiff(output=image, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
22
 
23
  local_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
24
+
25
  predictor = SamGeo(
26
  model_type=model_name,
27
  checkpoint_dir=root_folder,
 
31
  local_logger.info(f"initialized samgeo instance, start to set_image {image}...")
32
  predictor.set_image(image)
33
  output_name = f"{root_folder}/output.tif"
34
+
35
  local_logger.info(f"done set_image, start prediction...")
36
  predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=output_name)
37
+
38
  local_logger.info(f"done prediction, start tiff to geojson conversion...")
39
 
40
  # geotiff to geojson
src/utilities/type_hints.py CHANGED
@@ -11,20 +11,3 @@ ts_dict_str2 = Dict[str, str]
11
  ts_dict_str3 = Dict[str, str, any]
12
  ts_float64_1 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
13
  ts_float64_2 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
14
-
15
- """
16
- ts_list_str1 = List[str]
17
- ts_http2 = Tuple[ts_list_str1, ts_list_str1]
18
- ts_list_float2 = List[float, float]
19
- ts_llist_float2 = List[ts_list_float2, ts_list_float2]
20
- ts_geojson = Dict[str, str, Dict[str, Dict[str]], List[str, Dict[int], Dict[str, List]]]
21
- ts_dict_str2b = Dict[str, any]
22
- ts_ddict2 = Dict[str, Dict, Dict[str, List]]
23
- ts_tuple_str2 = Tuple[str, str]
24
- ts_tuple_arr2 = Tuple[np.ndarray, np.ndarray]
25
- ts_tuple_flat2 = Tuple[float, float]
26
- ts_tuple_flat4 = Tuple[float, float, float, float]
27
- ts_list_float4 = List[float, float, float, float]
28
- ts_tuple_int4 = Tuple[int, int, int, int]
29
- ts_ddict3 = Dict[List[Dict[float | int | str]], Dict[float | int]]
30
- """
 
11
  ts_dict_str3 = Dict[str, str, any]
12
  ts_float64_1 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
13
  ts_float64_2 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
static/app.js ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const jsonBtn = document.getElementById("getJson");
2
+ const apiBtn = document.getElementById("getApi");
3
+ const output = document.getElementById("output");
4
+ const coordsForm = document.getElementById("coords-form");
5
+
6
+ function formData2json(dataId, newObj={}) {
7
+ const formData = new FormData(dataId);
8
+ formData.forEach(function(value, key){
9
+ newObj[key] = value;
10
+ });
11
+ return JSON.stringify(newObj);
12
+ }
13
+
14
+ coordsForm.addEventListener('submit', event => {
15
+ event.preventDefault();
16
+ const inputJson = formData2json(coordsForm)
17
+ console.log("inputJson", inputJson, "#");
18
+
19
+ fetch('/infer_samgeo', {
20
+ method: 'POST', // or 'PUT'
21
+ body: inputJson, // a FormData will automatically set the 'Content-Type',
22
+ headers: {"Content-Type": "application/json"},
23
+ }).then(function (response) {
24
+ return response.json();
25
+ }).then(function (data) {
26
+ console.log("data:", data, "#")
27
+ output.innerHTML = JSON.stringify(data)
28
+ }).catch(function (err) {
29
+ console.log("err:", err, "#")
30
+ output.innerHTML = `err:${JSON.stringify(err)}.`;
31
+ });
32
+ event.preventDefault();
33
+ });
static/index.html CHANGED
@@ -1,24 +1,41 @@
1
- <main>
2
- <section id="coords-gen">
3
- <h2>Segment Geospatial: Instance segmentation within images based on Segment Anything</h2>
4
- <p>
5
- Model:
6
- <a
7
- href="https://github.com/opengeos/segment-geospatial/"
8
- rel="noreferrer"
9
- target="_blank"
10
- >opengeos/segment-geospatial
11
- </a>
12
- </p>
13
- <form class="coords-gen-form">
14
- <label for="coords-gen-input">Text prompt</label>
15
- <input
16
- id="coords-gen-input"
17
- type="coords"
18
- value=""
19
- />
20
- <button id="coords-gen-submit">Submit</button>
21
- <p class="coords-gen-output"></p>
22
- </form>
23
- </section>
24
- </main>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <meta http-equiv="X-UA-Compatible" content="ie=edge">
7
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.min.css" />
8
+ <title>Fetch API</title>
9
+ </head>
10
+ <body>
11
+ <div class="container">
12
+ <h1>Fetch API</h1>
13
+ <form id="coords-form">
14
+ <label>
15
+ bbox: x1 <input type="number" id="x-center-form" name="x1" value="-122.1497"/>
16
+ </label>
17
+ <label>
18
+ bbox: x2 <input type="number" id="y-center-form" name="x2" value="37.6311"/>
19
+ </label>
20
+ <label>
21
+ bbox: y1 <input type="number" id="y-center-form" name="y1" value="-122.1203"/>
22
+ </label>
23
+ <label>
24
+ bbox: y2 <input type="number" id="y-center-form" name="y2" value="37.6458"/>
25
+ </label>
26
+ <br/><br/>
27
+ <label>
28
+ x point: <input type="number" id="x-form" name="x" value="-122.1419"/>
29
+ </label>
30
+ <label>
31
+ y point: <input type="number" id="y-form" name="y" value="37.6383"/>
32
+ </label>
33
+ <br/><br/>
34
+ <button type="submit" id="submit-btn">submit form</button>
35
+ </form>
36
+ <br><br>
37
+ <div id="output"></div>
38
+ </div>
39
+ <script src="app.js"></script>
40
+ </body>
41
+ </html>
static/script.js DELETED
@@ -1,17 +0,0 @@
1
- const coordsGenForm = document.querySelector(".coords-gen-form");
2
-
3
- const segmentCoords = async (coords) => {
4
- const inferResponse = await fetch(`infer_samgeo?input=${coords}`);
5
- const inferJson = await inferResponse.json();
6
-
7
- return inferJson.output;
8
- };
9
-
10
- coordsGenForm.addEventListener("submit", async (event) => {
11
- event.preventDefault();
12
-
13
- const coordsGenInput = document.getElementById("coords-gen-input");
14
- const coordsGenParagraph = document.querySelector(".coords-gen-output");
15
-
16
- coordsGenParagraph.coordsContent = await segmentCoords(coordsGenInput.value);
17
- });