[fix] reformat code, use input from request
Browse files- .dockerignore +6 -0
- dockerfiles/dockerfile-base-webserver +22 -0
- Dockerfile → dockerfiles/dockerfile-samgeo +0 -0
- src/main.py +33 -17
- src/prediction_api/predictor.py +10 -5
- src/utilities/type_hints.py +0 -17
- static/app.js +33 -0
- static/index.html +41 -24
- static/script.js +0 -17
.dockerignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
venv/
|
2 |
+
*.pyc
|
3 |
+
__cache__
|
4 |
+
.idea
|
5 |
+
tmp/
|
6 |
+
.env*
|
dockerfiles/dockerfile-base-webserver
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.11
|
2 |
+
|
3 |
+
WORKDIR /code
|
4 |
+
|
5 |
+
RUN which python
|
6 |
+
RUN python --version
|
7 |
+
RUN python -m pip install --no-cache-dir fastapi uvicorn
|
8 |
+
|
9 |
+
RUN useradd -m -u 1000 user
|
10 |
+
|
11 |
+
USER user
|
12 |
+
|
13 |
+
ENV HOME=/home/user \
|
14 |
+
PATH=/home/user/.local/bin:$PATH
|
15 |
+
|
16 |
+
WORKDIR $HOME/app
|
17 |
+
|
18 |
+
RUN ls -l ${HOME}/
|
19 |
+
COPY --chown=user src $HOME/app/src
|
20 |
+
COPY --chown=user static $HOME/app/static
|
21 |
+
|
22 |
+
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
Dockerfile → dockerfiles/dockerfile-samgeo
RENAMED
File without changes
|
src/main.py
CHANGED
@@ -6,29 +6,42 @@ from fastapi.responses import FileResponse, JSONResponse
|
|
6 |
from fastapi.staticfiles import StaticFiles
|
7 |
from pydantic import BaseModel
|
8 |
|
|
|
9 |
from src.utilities.utilities import setup_logging
|
10 |
|
11 |
|
12 |
app = FastAPI()
|
13 |
-
local_logger = setup_logging()
|
14 |
|
15 |
|
16 |
class Input(BaseModel):
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
@app.post("/post_test")
|
23 |
-
async def post_test(
|
24 |
-
bbox = input.bbox
|
25 |
-
name = input.name
|
26 |
-
points_coords = input.points_coords
|
27 |
return JSONResponse(
|
28 |
status_code=200,
|
29 |
-
content=
|
30 |
-
"msg": name, "bbox": bbox, "points_coords": points_coords
|
31 |
-
}
|
32 |
)
|
33 |
|
34 |
|
@@ -38,7 +51,7 @@ async def hello() -> JSONResponse:
|
|
38 |
|
39 |
|
40 |
@app.post("/infer_samgeo")
|
41 |
-
def samgeo():
|
42 |
import subprocess
|
43 |
|
44 |
from src.prediction_api.predictor import base_predict
|
@@ -52,15 +65,16 @@ def samgeo():
|
|
52 |
time_start_run = time.time()
|
53 |
# debug = True
|
54 |
# local_logger = setup_logging(debug)
|
55 |
-
|
56 |
-
|
57 |
-
point_coords = [[-122.1419, 37.6383]]
|
58 |
try:
|
59 |
-
output = base_predict(
|
|
|
|
|
|
|
60 |
|
61 |
duration_run = time.time() - time_start_run
|
62 |
body = {
|
63 |
-
"message": message,
|
64 |
"duration_run": duration_run,
|
65 |
# "request_id": request_id
|
66 |
}
|
@@ -80,6 +94,7 @@ def samgeo():
|
|
80 |
|
81 |
@app.exception_handler(RequestValidationError)
|
82 |
async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
|
|
83 |
local_logger.error(f"exception errors: {exc.errors()}.")
|
84 |
local_logger.error(f"exception body: {exc.body}.")
|
85 |
headers = request.headers.items()
|
@@ -94,6 +109,7 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
|
|
94 |
|
95 |
@app.exception_handler(HTTPException)
|
96 |
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
|
|
97 |
local_logger.error(f"exception: {str(exc)}.")
|
98 |
headers = request.headers.items()
|
99 |
local_logger.error(f'request header: {dict(headers)}.' )
|
|
|
6 |
from fastapi.staticfiles import StaticFiles
|
7 |
from pydantic import BaseModel
|
8 |
|
9 |
+
from src.utilities.type_hints import input_floatlist, input_floatlist2
|
10 |
from src.utilities.utilities import setup_logging
|
11 |
|
12 |
|
13 |
app = FastAPI()
|
|
|
14 |
|
15 |
|
16 |
class Input(BaseModel):
|
17 |
+
x1: float
|
18 |
+
y1: float
|
19 |
+
x2: float
|
20 |
+
y2: float
|
21 |
+
x: float
|
22 |
+
y: float
|
23 |
+
|
24 |
+
|
25 |
+
class BBoxWithPoint(BaseModel):
|
26 |
+
bbox: input_floatlist
|
27 |
+
point: input_floatlist2
|
28 |
+
|
29 |
+
|
30 |
+
def get_parsed_bbox_points(request_input: Input) -> BBoxWithPoint:
|
31 |
+
return {
|
32 |
+
"bbox": [
|
33 |
+
request_input.x1, request_input.x2,
|
34 |
+
request_input.y1, request_input.y2
|
35 |
+
],
|
36 |
+
"point": [[request_input.x, request_input.y]]
|
37 |
+
}
|
38 |
|
39 |
|
40 |
@app.post("/post_test")
|
41 |
+
async def post_test(request_input: Input) -> JSONResponse:
|
|
|
|
|
|
|
42 |
return JSONResponse(
|
43 |
status_code=200,
|
44 |
+
content=get_parsed_bbox_points(request_input)
|
|
|
|
|
45 |
)
|
46 |
|
47 |
|
|
|
51 |
|
52 |
|
53 |
@app.post("/infer_samgeo")
|
54 |
+
def samgeo(request_input: Input):
|
55 |
import subprocess
|
56 |
|
57 |
from src.prediction_api.predictor import base_predict
|
|
|
65 |
time_start_run = time.time()
|
66 |
# debug = True
|
67 |
# local_logger = setup_logging(debug)
|
68 |
+
request_body = get_parsed_bbox_points(request_input)
|
69 |
+
local_logger.info(f"request_body:{request_body}.")
|
|
|
70 |
try:
|
71 |
+
output = base_predict(
|
72 |
+
bbox=request_body["bbox"],
|
73 |
+
point_coords=request_body["point"]
|
74 |
+
)
|
75 |
|
76 |
duration_run = time.time() - time_start_run
|
77 |
body = {
|
|
|
78 |
"duration_run": duration_run,
|
79 |
# "request_id": request_id
|
80 |
}
|
|
|
94 |
|
95 |
@app.exception_handler(RequestValidationError)
|
96 |
async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
97 |
+
local_logger = setup_logging()
|
98 |
local_logger.error(f"exception errors: {exc.errors()}.")
|
99 |
local_logger.error(f"exception body: {exc.body}.")
|
100 |
headers = request.headers.items()
|
|
|
109 |
|
110 |
@app.exception_handler(HTTPException)
|
111 |
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
112 |
+
local_logger = setup_logging()
|
113 |
local_logger.error(f"exception: {str(exc)}.")
|
114 |
headers = request.headers.items()
|
115 |
local_logger.error(f'request header: {dict(headers)}.' )
|
src/prediction_api/predictor.py
CHANGED
@@ -2,21 +2,26 @@
|
|
2 |
import json
|
3 |
|
4 |
from src.utilities.constants import ROOT
|
|
|
5 |
from src.utilities.utilities import setup_logging
|
6 |
|
7 |
-
|
8 |
local_logger = setup_logging()
|
9 |
|
10 |
|
11 |
-
def base_predict(
|
|
|
|
|
12 |
from samgeo import SamGeo, tms_to_geotiff
|
13 |
|
14 |
image = f"{root_folder}/satellite.tif"
|
15 |
-
local_logger.info("start tms_to_geotiff")
|
|
|
|
|
16 |
# bbox: image input coordinate
|
17 |
tms_to_geotiff(output=image, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
|
18 |
|
19 |
local_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
|
|
|
20 |
predictor = SamGeo(
|
21 |
model_type=model_name,
|
22 |
checkpoint_dir=root_folder,
|
@@ -26,10 +31,10 @@ def base_predict(bbox, point_coords, point_crs="EPSG:4326", zoom=16, model_name:
|
|
26 |
local_logger.info(f"initialized samgeo instance, start to set_image {image}...")
|
27 |
predictor.set_image(image)
|
28 |
output_name = f"{root_folder}/output.tif"
|
29 |
-
|
30 |
local_logger.info(f"done set_image, start prediction...")
|
31 |
predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=output_name)
|
32 |
-
|
33 |
local_logger.info(f"done prediction, start tiff to geojson conversion...")
|
34 |
|
35 |
# geotiff to geojson
|
|
|
2 |
import json
|
3 |
|
4 |
from src.utilities.constants import ROOT
|
5 |
+
from src.utilities.type_hints import input_floatlist, input_floatlist2
|
6 |
from src.utilities.utilities import setup_logging
|
7 |
|
|
|
8 |
local_logger = setup_logging()
|
9 |
|
10 |
|
11 |
+
def base_predict(
|
12 |
+
bbox: input_floatlist, point_coords: input_floatlist2, point_crs: str = "EPSG:4326", zoom: float = 16, model_name: str = "vit_h", root_folder: str = ROOT
|
13 |
+
) -> str:
|
14 |
from samgeo import SamGeo, tms_to_geotiff
|
15 |
|
16 |
image = f"{root_folder}/satellite.tif"
|
17 |
+
local_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}.")
|
18 |
+
for coord in bbox:
|
19 |
+
local_logger.info(f"coord:{coord}, type:{type(coord)}.")
|
20 |
# bbox: image input coordinate
|
21 |
tms_to_geotiff(output=image, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
|
22 |
|
23 |
local_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
|
24 |
+
|
25 |
predictor = SamGeo(
|
26 |
model_type=model_name,
|
27 |
checkpoint_dir=root_folder,
|
|
|
31 |
local_logger.info(f"initialized samgeo instance, start to set_image {image}...")
|
32 |
predictor.set_image(image)
|
33 |
output_name = f"{root_folder}/output.tif"
|
34 |
+
|
35 |
local_logger.info(f"done set_image, start prediction...")
|
36 |
predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=output_name)
|
37 |
+
|
38 |
local_logger.info(f"done prediction, start tiff to geojson conversion...")
|
39 |
|
40 |
# geotiff to geojson
|
src/utilities/type_hints.py
CHANGED
@@ -11,20 +11,3 @@ ts_dict_str2 = Dict[str, str]
|
|
11 |
ts_dict_str3 = Dict[str, str, any]
|
12 |
ts_float64_1 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
|
13 |
ts_float64_2 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
|
14 |
-
|
15 |
-
"""
|
16 |
-
ts_list_str1 = List[str]
|
17 |
-
ts_http2 = Tuple[ts_list_str1, ts_list_str1]
|
18 |
-
ts_list_float2 = List[float, float]
|
19 |
-
ts_llist_float2 = List[ts_list_float2, ts_list_float2]
|
20 |
-
ts_geojson = Dict[str, str, Dict[str, Dict[str]], List[str, Dict[int], Dict[str, List]]]
|
21 |
-
ts_dict_str2b = Dict[str, any]
|
22 |
-
ts_ddict2 = Dict[str, Dict, Dict[str, List]]
|
23 |
-
ts_tuple_str2 = Tuple[str, str]
|
24 |
-
ts_tuple_arr2 = Tuple[np.ndarray, np.ndarray]
|
25 |
-
ts_tuple_flat2 = Tuple[float, float]
|
26 |
-
ts_tuple_flat4 = Tuple[float, float, float, float]
|
27 |
-
ts_list_float4 = List[float, float, float, float]
|
28 |
-
ts_tuple_int4 = Tuple[int, int, int, int]
|
29 |
-
ts_ddict3 = Dict[List[Dict[float | int | str]], Dict[float | int]]
|
30 |
-
"""
|
|
|
11 |
ts_dict_str3 = Dict[str, str, any]
|
12 |
ts_float64_1 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
|
13 |
ts_float64_2 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static/app.js
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
const jsonBtn = document.getElementById("getJson");
|
2 |
+
const apiBtn = document.getElementById("getApi");
|
3 |
+
const output = document.getElementById("output");
|
4 |
+
const coordsForm = document.getElementById("coords-form");
|
5 |
+
|
6 |
+
function formData2json(dataId, newObj={}) {
|
7 |
+
const formData = new FormData(dataId);
|
8 |
+
formData.forEach(function(value, key){
|
9 |
+
newObj[key] = value;
|
10 |
+
});
|
11 |
+
return JSON.stringify(newObj);
|
12 |
+
}
|
13 |
+
|
14 |
+
coordsForm.addEventListener('submit', event => {
|
15 |
+
event.preventDefault();
|
16 |
+
const inputJson = formData2json(coordsForm)
|
17 |
+
console.log("inputJson", inputJson, "#");
|
18 |
+
|
19 |
+
fetch('/infer_samgeo', {
|
20 |
+
method: 'POST', // or 'PUT'
|
21 |
+
body: inputJson, // a FormData will automatically set the 'Content-Type',
|
22 |
+
headers: {"Content-Type": "application/json"},
|
23 |
+
}).then(function (response) {
|
24 |
+
return response.json();
|
25 |
+
}).then(function (data) {
|
26 |
+
console.log("data:", data, "#")
|
27 |
+
output.innerHTML = JSON.stringify(data)
|
28 |
+
}).catch(function (err) {
|
29 |
+
console.log("err:", err, "#")
|
30 |
+
output.innerHTML = `err:${JSON.stringify(err)}.`;
|
31 |
+
});
|
32 |
+
event.preventDefault();
|
33 |
+
});
|
static/index.html
CHANGED
@@ -1,24 +1,41 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<meta http-equiv="X-UA-Compatible" content="ie=edge">
|
7 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.min.css" />
|
8 |
+
<title>Fetch API</title>
|
9 |
+
</head>
|
10 |
+
<body>
|
11 |
+
<div class="container">
|
12 |
+
<h1>Fetch API</h1>
|
13 |
+
<form id="coords-form">
|
14 |
+
<label>
|
15 |
+
bbox: x1 <input type="number" id="x-center-form" name="x1" value="-122.1497"/>
|
16 |
+
</label>
|
17 |
+
<label>
|
18 |
+
bbox: x2 <input type="number" id="y-center-form" name="x2" value="37.6311"/>
|
19 |
+
</label>
|
20 |
+
<label>
|
21 |
+
bbox: y1 <input type="number" id="y-center-form" name="y1" value="-122.1203"/>
|
22 |
+
</label>
|
23 |
+
<label>
|
24 |
+
bbox: y2 <input type="number" id="y-center-form" name="y2" value="37.6458"/>
|
25 |
+
</label>
|
26 |
+
<br/><br/>
|
27 |
+
<label>
|
28 |
+
x point: <input type="number" id="x-form" name="x" value="-122.1419"/>
|
29 |
+
</label>
|
30 |
+
<label>
|
31 |
+
y point: <input type="number" id="y-form" name="y" value="37.6383"/>
|
32 |
+
</label>
|
33 |
+
<br/><br/>
|
34 |
+
<button type="submit" id="submit-btn">submit form</button>
|
35 |
+
</form>
|
36 |
+
<br><br>
|
37 |
+
<div id="output"></div>
|
38 |
+
</div>
|
39 |
+
<script src="app.js"></script>
|
40 |
+
</body>
|
41 |
+
</html>
|
static/script.js
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
const coordsGenForm = document.querySelector(".coords-gen-form");
|
2 |
-
|
3 |
-
const segmentCoords = async (coords) => {
|
4 |
-
const inferResponse = await fetch(`infer_samgeo?input=${coords}`);
|
5 |
-
const inferJson = await inferResponse.json();
|
6 |
-
|
7 |
-
return inferJson.output;
|
8 |
-
};
|
9 |
-
|
10 |
-
coordsGenForm.addEventListener("submit", async (event) => {
|
11 |
-
event.preventDefault();
|
12 |
-
|
13 |
-
const coordsGenInput = document.getElementById("coords-gen-input");
|
14 |
-
const coordsGenParagraph = document.querySelector(".coords-gen-output");
|
15 |
-
|
16 |
-
coordsGenParagraph.coordsContent = await segmentCoords(coordsGenInput.value);
|
17 |
-
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|