[feat] reformat code, log context request id
- dockerfiles/dockerfile-base-webserver +1 -1
- requirements.txt +1 -0
- src/__init__.py +4 -0
- src/main.py +39 -27
- src/prediction_api/predictor.py +12 -13
- src/utilities/measures.py +0 -76
- src/utilities/serialize.py +0 -86
- src/utilities/type_hints.py +2 -9
- src/utilities/utilities.py +6 -15
dockerfiles/dockerfile-base-webserver
CHANGED
@@ -4,7 +4,7 @@ WORKDIR /code
 
 RUN which python
 RUN python --version
-RUN python -m pip install --no-cache-dir fastapi uvicorn
+RUN python -m pip install --no-cache-dir fastapi uvicorn loguru
 
 RUN useradd -m -u 1000 user
 
requirements.txt
CHANGED
@@ -1,5 +1,6 @@
 fastapi
 bson
+loguru
 python-dotenv
 segment-geospatial
 uvicorn[standard]
src/__init__.py
CHANGED
@@ -0,0 +1,4 @@
+from src.utilities.utilities import setup_logging
+
+
+app_logger = setup_logging(debug=True)
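Note: building the logger once in the package __init__ means every module shares a single pre-configured instance instead of each calling setup_logging() itself. A minimal usage sketch (illustrative, not part of the diff):

from src import app_logger

app_logger.info("any module logs through the one shared loguru sink")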
src/main.py
CHANGED
@@ -1,18 +1,38 @@
 import json
-
+import uuid
+
 from fastapi import FastAPI, HTTPException, Request, status
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import FileResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 
+from src import app_logger
 from src.utilities.type_hints import input_floatlist, input_floatlist2
-from src.utilities.utilities import setup_logging
 
 
 app = FastAPI()
 
 
+@app.middleware("http")
+async def request_middleware(request, call_next):
+    request_id = str(uuid.uuid4())
+    with app_logger.contextualize(request_id=request_id):
+        app_logger.info("Request started")
+
+        try:
+            response = await call_next(request)
+
+        except Exception as ex:
+            app_logger.error(f"Request failed: {ex}")
+            response = JSONResponse(content={"success": False}, status_code=500)
+
+        finally:
+            response.headers["X-Request-ID"] = request_id
+            app_logger.info(f"Request ended")
+            return response
+
+
 class Input(BaseModel):
     x1: float
     y1: float
@@ -39,6 +59,8 @@ def get_parsed_bbox_points(request_input: Input) -> BBoxWithPoint:
 
 @app.post("/post_test")
 async def post_test(request_input: Input) -> JSONResponse:
+    request_body = get_parsed_bbox_points(request_input)
+    app_logger.info(f"request_body:{request_body}.")
     return JSONResponse(
         status_code=200,
         content=get_parsed_bbox_points(request_input)
@@ -47,6 +69,7 @@ async def post_test(request_input: Input) -> JSONResponse:
 
 @app.get("/hello")
 async def hello() -> JSONResponse:
+    app_logger.info(f"hello")
     return JSONResponse(status_code=200, content={"msg": "hello"})
 
 
@@ -56,51 +79,41 @@ def samgeo(request_input: Input):
 
     from src.prediction_api.predictor import base_predict
 
-
-    local_logger.info("starting inference request...")
+    app_logger.info("starting inference request...")
 
     try:
         import time
 
        time_start_run = time.time()
-        # debug = True
-        # local_logger = setup_logging(debug)
        request_body = get_parsed_bbox_points(request_input)
-
+        app_logger.info(f"request_body:{request_body}.")
        try:
            output = base_predict(
                bbox=request_body["bbox"],
                point_coords=request_body["point"]
            )
-
-
-            body = {
-                "duration_run": duration_run,
-                # "request_id": request_id
-            }
-            local_logger.info(f"body:{body}.")
+            body = {"duration_run": time.time() - time_start_run}
+            app_logger.info(f"body:{body}.")
            body["output"] = output
-            # local_logger.info(f"End_request::{request_id}...")
            return JSONResponse(status_code=200, content={"body": json.dumps(body)})
        except Exception as inference_exception:
            home_content = subprocess.run("ls -l /home/user", shell=True, universal_newlines=True, stdout=subprocess.PIPE)
-
-
+            app_logger.error(f"/home/user ls -l: {home_content.stdout}.")
+            app_logger.error(f"inference error:{inference_exception}.")
            return HTTPException(status_code=500, detail="Internal server error on inference")
    except Exception as generic_exception:
-
+        app_logger.error(f"generic error:{generic_exception}.")
        return HTTPException(status_code=500, detail="Generic internal server error")
 
 
 @app.exception_handler(RequestValidationError)
 async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
-
-
-    local_logger.error(f"exception body: {exc.body}.")
+    app_logger.error(f"exception errors: {exc.errors()}.")
+    app_logger.error(f"exception body: {exc.body}.")
     headers = request.headers.items()
-
+    app_logger.error(f'request header: {dict(headers)}.')
     params = request.query_params.items()
-
+    app_logger.error(f'request query params: {dict(params)}.')
     return JSONResponse(
         status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
         content={"msg": "Error - Unprocessable Entity"}
@@ -109,12 +122,11 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
 
 @app.exception_handler(HTTPException)
 async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
-
-    local_logger.error(f"exception: {str(exc)}.")
+    app_logger.error(f"exception: {str(exc)}.")
     headers = request.headers.items()
+    app_logger.error(f'request header: {dict(headers)}.')
     params = request.query_params.items()
+    app_logger.error(f'request query params: {dict(params)}.')
     return JSONResponse(
         status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
         content={"msg": "Error - Internal Server Error"}
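Note: the new request_middleware is loguru's contextualize pattern: a fresh uuid4 is bound to the logger's context for the lifetime of each request, so every record logged while serving that request (including from predictor.py below) carries the same request id, which is also echoed back in the X-Request-ID response header. A minimal, self-contained sketch of the pattern outside FastAPI (names are illustrative, not part of the repository):

import sys
import uuid

from loguru import logger

# Same shape of format string as setup_logging() below: {extra[request_id]}
# is resolved from the values bound with contextualize().
logger.remove()
logger.add(sys.stdout, format="{time} - {level} - ({extra[request_id]}) {message}", level="INFO")


def handle_one_request():
    request_id = str(uuid.uuid4())
    with logger.contextualize(request_id=request_id):
        # Both records print the same request_id without passing it around.
        logger.info("Request started")
        logger.info("Request ended")


handle_one_request()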
src/prediction_api/predictor.py
CHANGED
@@ -1,11 +1,9 @@
 # Press the green button in the gutter to run the script.
 import json
 
+from src import app_logger
 from src.utilities.constants import ROOT
 from src.utilities.type_hints import input_floatlist, input_floatlist2
-from src.utilities.utilities import setup_logging
-
-local_logger = setup_logging()
 
 
 def base_predict(
@@ -14,13 +12,13 @@ def base_predict(
     from samgeo import SamGeo, tms_to_geotiff
 
     image = f"{root_folder}/satellite.tif"
-
+    app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}.")
     for coord in bbox:
-
+        app_logger.info(f"coord:{coord}, type:{type(coord)}.")
         # bbox: image input coordinate
     tms_to_geotiff(output=image, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
 
-
+    app_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
 
     predictor = SamGeo(
         model_type=model_name,
@@ -28,21 +26,22 @@ def base_predict(
         automatic=False,
         sam_kwargs=None,
     )
-
+    app_logger.info(f"initialized samgeo instance, start to set_image {image}...")
     predictor.set_image(image)
     output_name = f"{root_folder}/output.tif"
 
+    app_logger.info(f"done set_image, start prediction...")
     predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=output_name)
 
-
+    app_logger.info(f"done prediction, start tiff to geojson conversion...")
 
     # geotiff to geojson
     vector = f"{root_folder}/feats.geojson"
     predictor.tiff_to_geojson(output_name, vector, bidx=1)
-
+    app_logger.info(f"start reading geojson...")
 
     with open(vector) as out_gdf:
-        out_gdf_str =
-
-
+        out_gdf_str = out_gdf.read()
+        out_gdf_json = json.loads(out_gdf_str)
+        app_logger.info(f"geojson string output length:{len(out_gdf_str)}.")
+        return out_gdf_json
src/utilities/measures.py
DELETED
@@ -1,76 +0,0 @@
-"""helpers for compute measures: hash, time benchmarks"""
-from pathlib import Path
-
-
-def hash_calculate(arr: any, debug: bool = False) -> str or bytes:
-    """
-    Return computed hash from input variable (typically a numpy array).
-
-    Args:
-        arr: input variable
-        debug: logging debug argument
-
-    Returns:
-        str or bytes: computed hash from input variable
-
-    """
-    import hashlib
-    import numpy as np
-    from base64 import b64encode
-
-    from src.utilities.utilities import setup_logging
-    local_logger = setup_logging(debug)
-
-    if isinstance(arr, np.ndarray):
-        hash_fn = hashlib.sha256(arr.data)
-    elif isinstance(arr, dict):
-        import json
-        from src.utilities.serialize import serialize
-
-        serialized = serialize(arr)
-        variable_to_hash = json.dumps(serialized, sort_keys=True).encode('utf-8')
-        hash_fn = hashlib.sha256(variable_to_hash)
-    elif isinstance(arr, str):
-        try:
-            hash_fn = hashlib.sha256(arr)
-        except TypeError:
-            local_logger.warning(f"TypeError, re-try encoding arg:{arr},type:{type(arr)}.")
-            hash_fn = hashlib.sha256(arr.encode('utf-8'))
-    elif isinstance(arr, bytes):
-        hash_fn = hashlib.sha256(arr)
-    else:
-        raise ValueError(f"variable 'arr':{arr} not yet handled.")
-    return b64encode(hash_fn.digest())
-
-
-def sha256sum(filename: Path or str) -> str:
-    """
-    Return computed hash for input file.
-
-    Args:
-        filename: input variable
-
-    Returns:
-        str: computed hash
-
-    """
-    import hashlib
-    import mmap
-
-    h = hashlib.sha256()
-    with open(filename, 'rb') as f:
-        with mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ) as mm:
-            h.update(mm)
-    return h.hexdigest()
-
-
-def perf_counter() -> float:
-    """
-    Performance counter for benchmarking.
-
-    Returns:
-        float: computed time value at execution time
-
-    """
-    import time
-    return time.perf_counter()
src/utilities/serialize.py
DELETED
@@ -1,86 +0,0 @@
-"""Serialize objects"""
-from typing import Mapping
-
-from src.utilities.type_hints import ts_dict_str3, ts_dict_str2
-
-
-def serialize(obj:any, include_none:bool=False) -> object:
-    """
-    Return the input object into a serializable one
-
-    Args:
-        obj: Object to serialize
-        include_none: bool to indicate if include also keys with None values during dict serialization
-
-    Returns:
-        object: serialized object
-
-    """
-    return _serialize(obj, include_none)
-
-
-def _serialize(obj:any, include_none:bool) -> any:
-    import numpy as np
-
-    primitive = (int, float, str, bool)
-    # print(type(obj))
-    try:
-        if obj is None:
-            return None
-        elif isinstance(obj, np.integer):
-            return int(obj)
-        elif isinstance(obj, np.floating):
-            return float(obj)
-        elif isinstance(obj, np.ndarray):
-            return obj.tolist()
-        elif isinstance(obj, primitive):
-            return obj
-        elif type(obj) is list:
-            return _serialize_list(obj, include_none)
-        elif type(obj) is tuple:
-            return list(obj)
-        elif type(obj) is bytes:
-            return _serialize_bytes(obj)
-        elif isinstance(obj, Exception):
-            return _serialize_exception(obj)
-        # elif isinstance(obj, object):
-        #     return _serialize_object(obj, include_none)
-        else:
-            return _serialize_object(obj, include_none)
-    except Exception as e_serialize:
-        from src.utilities.utilities import setup_logging
-        serialize_logger = setup_logging()
-        serialize_logger.error(f"e_serialize::{e_serialize}, type_obj:{type(obj)}, obj:{obj}.")
-        return f"object_name:{str(obj)}__object_type_str:{str(type(obj))}."
-
-
-def _serialize_object(obj:Mapping[any, object], include_none:bool) -> dict[any]:
-    from bson import ObjectId
-
-    res = {}
-    if type(obj) is not dict:
-        keys = [i for i in obj.__dict__.keys() if (getattr(obj, i) is not None) or include_none]
-    else:
-        keys = [i for i in obj.keys() if (obj[i] is not None) or include_none]
-    for key in keys:
-        if type(obj) is not dict:
-            res[key] = _serialize(getattr(obj, key), include_none)
-        elif isinstance(obj[key], ObjectId):
-            continue
-        else:
-            res[key] = _serialize(obj[key], include_none)
-    return res
-
-
-def _serialize_list(ls:list, include_none:bool) -> list:
-    return [_serialize(elem, include_none) for elem in ls]
-
-
-def _serialize_bytes(b:bytes) -> ts_dict_str2:
-    import base64
-    encoded = base64.b64encode(b)
-    return {"value": encoded.decode('ascii'), "type": "bytes"}
-
-
-def _serialize_exception(e: Exception) -> ts_dict_str3:
-    return {"msg": str(e), "type": str(type(e)), **e.__dict__}
src/utilities/type_hints.py
CHANGED
@@ -1,13 +1,6 @@
 """custom type hints"""
-from typing import List
+from typing import List
 
-import numpy as np
-
-# ts_ddict1, ts_float64_1, ts_float64_2, ts_dict_str3, ts_dict_str2
 input_floatlist = List[float]
 input_floatlist2 = List[input_floatlist]
-
-ts_dict_str2 = Dict[str, str]
-ts_dict_str3 = Dict[str, str, any]
-ts_float64_1 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
-ts_float64_2 = Tuple[np.float64, np.float64, np.float64, np.float64, np.float64, np.float64, np.float64]
+
src/utilities/utilities.py
CHANGED
@@ -1,10 +1,10 @@
 """Various utilities (logger, time benchmark, args dump, numerical and stats info)"""
-import
+import loguru
 
 from src.utilities.constants import ROOT
 
 
-def setup_logging(debug: bool = False, formatter: str = '%(asctime)s - %(name)s
+def setup_logging(debug: bool = False, formatter: str = "{time} - {level} - ({extra[request_id]}) {message} ") -> loguru.logger:
     """
     Create a logging instance with log string formatter.
 
@@ -16,21 +16,12 @@ def setup_logging(debug: bool = False, formatter: str = '%(asctime)s - %(name)s
     Logger
 
     """
-    import logging
     import sys
 
-    logger =
-
-
-
-    h = logging.StreamHandler(sys.stdout)
-
-    h.setFormatter(logging.Formatter(formatter))
-    logger.addHandler(h)
-    logger.setLevel(logging.INFO)
-
-    if debug:
-        logger.setLevel(logging.DEBUG)
+    logger = loguru.logger
+    logger.remove()
+    level_logger = "DEBUG" if debug else "INFO"
+    logger.add(sys.stdout, format=formatter, level=level_logger)
     logger.debug(f"type_logger:{type(logger)}.")
     return logger
 
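Note: with this formatter, {extra[request_id]} only resolves when a request id has been bound, so records logged outside the middleware (startup messages, or the logger.debug() call inside setup_logging() itself) have no value for it and loguru will report a formatting error for those records. A hedged sketch of one way to provide a fallback, assuming a default via loguru's configure() is acceptable (not part of this diff):

import sys

from loguru import logger

logger.remove()
# Default request_id used whenever no contextualize() block has bound one.
logger.configure(extra={"request_id": "-"})
logger.add(sys.stdout, format="{time} - {level} - ({extra[request_id]}) {message}", level="INFO")

logger.info("startup message")  # prints with request_id "-"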