|
import json |
|
import os |
|
from pathlib import Path |
|
from typing import Callable, NoReturn |
|
|
|
from asgi_correlation_id import CorrelationIdMiddleware |
|
import gradio as gr |
|
import spaces |
|
from starlette.responses import JSONResponse |
|
import structlog |
|
import uvicorn |
|
from dotenv import load_dotenv |
|
from fastapi import FastAPI, HTTPException, Request, status |
|
from fastapi.exceptions import RequestValidationError |
|
from fastapi.responses import FileResponse, HTMLResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.templating import Jinja2Templates |
|
from pydantic import ValidationError |
|
from samgis_core.utilities import create_folders_if_not_exists |
|
from samgis_core.utilities import frontend_builder |
|
from samgis_core.utilities.session_logger import setup_logging |
|
from samgis_web.utilities.constants import GRADIO_EXAMPLES_TEXT_LIST, GRADIO_MARKDOWN, GRADIO_EXAMPLE_BODY_STRING_PROMPT |
|
from samgis_web.utilities.type_hints import StringPromptApiRequestBody |
|
|
|
|
|
# Load variables from a local .env file (if present) before reading the environment.
load_dotenv()

# "./_" is the fallback used when __file__ is not defined in globals().
project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent
workdir = Path(os.environ.get("WORKDIR", project_root_folder))
model_folder = project_root_folder / "machine_learning_models"

# Configure logging, then create the module-level logger.
log_level = os.environ.get("LOG_LEVEL", "INFO")
setup_logging(log_level=log_level)
app_logger = structlog.stdlib.get_logger()
app_logger.info(f"PROJECT_ROOT_FOLDER:{project_root_folder}, WORKDIR:{workdir}.")

# Optional settings read from the environment.
folders_map = os.environ.get("FOLDERS_MAP", "{}")
markdown_text = os.environ.get("MARKDOWN_TEXT", "")
examples_text_list = os.environ.get("EXAMPLES_TEXT_LIST", "").split("\n")
example_body = json.loads(os.environ.get("EXAMPLE_BODY", "{}"))
# NOTE(review): any non-empty value (including "false" or "0") evaluates to True here.
mount_gradio_app = bool(os.environ.get("MOUNT_GRADIO_APP", ""))

# Frontend build output folder and the URL prefixes used by the routes below.
static_dist_folder = workdir / "static" / "dist"
input_css_path = os.environ.get("INPUT_CSS_PATH", "src/input.css")
vite_gradio_url = os.environ.get("VITE_GRADIO_URL", "/gradio")
vite_index_url = os.environ.get("VITE_INDEX_URL", "/")
vite_samgis_url = os.environ.get("VITE_SAMGIS_URL", "/samgis")
vite_lisa_url = os.environ.get("VITE_LISA_URL", "/lisa")

fastapi_title = "samgis-lisa-on-zero2"
app = FastAPI(title=fastapi_title, version="1.0")
|
|
|
|
|
@app.middleware("http")
async def request_middleware(request, call_next):
    """Forward every HTTP request to the samgis_web logging middleware.

    The samgis_web module is imported at call time, not at module import time.

    Args:
        request: the incoming HTTP request.
        call_next: the callable that passes the request to the next handler.

    Returns:
        Whatever ``logging_middleware`` returns for this request.
    """
    from samgis_web.web import middlewares

    response = await middlewares.logging_middleware(request, call_next)
    return response
|
|
|
@spaces.GPU
def gpu_initialization() -> None:
    """Log a message from within a function wrapped by the ``spaces.GPU`` decorator."""
    message = "GPU initialization..."
    app_logger.info(message)
|
|
|
|
|
def get_example_complete(example_text):
    """Return a JSON string of the example request body with the given string prompt.

    Args:
        example_text: value stored under the "string_prompt" key.

    Returns:
        The JSON-serialized copy of GRADIO_EXAMPLE_BODY_STRING_PROMPT with "string_prompt" replaced.
    """
    # Work on a shallow copy so the shared constant is never mutated.
    payload = dict(**GRADIO_EXAMPLE_BODY_STRING_PROMPT)
    payload.update(string_prompt=example_text)
    return json.dumps(payload)
|
|
|
|
|
def get_gradio_interface_geojson(fn_inference: Callable):
    """Build the gradio Blocks interface used to submit a payload to an inference function.

    The UI contains a markdown header, a "Payload input" textbox with a "Submit" button,
    a "Geojson Output" textbox and a list of ready-made example payloads.

    Args:
        fn_inference: callable wired to the button click; it receives the content of the
            input textbox and its result is written to the output textbox.

    Returns:
        The gradio Blocks application.
    """
    # Each example is a complete JSON body built from one example prompt.
    example_payloads = [get_example_complete(prompt) for prompt in GRADIO_EXAMPLES_TEXT_LIST]

    with gr.Blocks() as gradio_app:
        gr.Markdown(GRADIO_MARKDOWN)

        with gr.Row():
            with gr.Column():
                payload_box = gr.Textbox(lines=1, placeholder=None, label="Payload input")
                submit_button = gr.Button(value="Submit")
            with gr.Column():
                geojson_box = gr.Textbox(lines=1, placeholder=None, label="Geojson Output")

        gr.Examples(examples=example_payloads, inputs=[payload_box])
        submit_button.click(fn_inference, inputs=[payload_box], outputs=[geojson_box])
    return gradio_app
|
|
|
|
|
def handle_exception_response(exception: Exception) -> NoReturn:
    """Log diagnostic folder listings and the given error, then raise an HTTP 500.

    The content of the project root folder and of the workdir folder (as reported by
    ``ls -l``) is logged at error level to help diagnosing a failed inference.

    Args:
        exception: the error that triggered this handler.

    Raises:
        HTTPException: always, with status 500; the given exception is chained as its cause.
    """
    import subprocess

    def _list_folder(folder):
        """Return the (stdout, stderr) pair of ``ls -l <folder>/``; a failure to spawn ``ls`` is reported as stderr."""
        try:
            # Argument list instead of a shell string: the path is never parsed by a shell.
            completed = subprocess.run(
                ["ls", "-l", f"{folder}/"],
                universal_newlines=True,
                stdout=subprocess.PIPE,
                # stderr must be captured explicitly, otherwise CompletedProcess.stderr is None
                stderr=subprocess.PIPE,
                check=False,
            )
        except OSError as os_error:
            # best-effort diagnostics: a missing `ls` executable must not hide the original error
            return "", str(os_error)
        return completed.stdout, completed.stderr

    project_root_stdout, _ = _list_folder(project_root_folder)
    app_logger.error(f"project_root folder 'ls -l' command output: {project_root_stdout}.")
    workdir_stdout, workdir_stderr = _list_folder(workdir)
    app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_stdout}.")
    app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_stderr}.")
    app_logger.error(f"inference error:{exception}.")
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
    ) from exception
|
|
|
|
|
@app.get("/health")
async def health() -> JSONResponse:
    """Health check: log the versions of the samgis/lisa packages and answer with a 200 JSON message.

    Returns:
        A JSONResponse with status 200 and the body ``{"msg": "still alive..."}``.
    """
    from samgis_web.__version__ import __version__ as version_web
    from samgis_core.__version__ import __version__ as version_core
    from lisa_on_cuda.__version__ import __version__ as version_lisa_on_cuda
    from samgis_lisa.__version__ import __version__ as version_samgis_lisa

    samgis_versions_msg = f"still alive, version_web:{version_web}, version_core:{version_core}."
    lisa_versions_msg = (
        f"still alive, version_lisa_on_cuda:{version_lisa_on_cuda}, version_samgis_lisa:{version_samgis_lisa}."
    )
    for message in (samgis_versions_msg, lisa_versions_msg):
        app_logger.info(message)

    payload = {"msg": "still alive..."}
    return JSONResponse(status_code=200, content=payload)
|
|
|
|
|
def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
    """Run a LISA inference for the given request and return the JSON-serialized result.

    The request is parsed with ``get_parsed_bbox_points_with_string_prompt``, then
    ``lisa.lisa_predict`` is called with the parsed bbox, prompt, zoom and tile source.

    Args:
        request_input: the request body describing the area and the string prompt.

    Returns:
        A JSON string with two keys: "duration_run" (seconds elapsed since the start of
        the request handling) and "output" (the value returned by ``lisa.lisa_predict``).

    Raises:
        RequestValidationError: when parsing the request raises a pydantic ValidationError.
        HTTPException: with status 500 when anything fails after the request was parsed.
    """
    import time

    from samgis_lisa.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
    from samgis_lisa.prediction_api import lisa
    from samgis_lisa.utilities.constants import LISA_INFERENCE_FN

    app_logger.info("starting lisa inference request...")
    time_start_run = time.time()

    # Only the parsing step is expected to raise ValidationError, so the try body is kept to it.
    try:
        body_request = get_parsed_bbox_points_with_string_prompt(request_input)
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        app_logger.error(f"ValidationError, request_input:{request_input}.")
        # chain the cause so the original traceback is not lost
        raise RequestValidationError("Unprocessable Entity") from va1
    app_logger.info(f"lisa body_request:{body_request}.")

    try:
        source = body_request["source"]
        source_name = body_request["source_name"]
        app_logger.debug(f"body_request:type(source):{type(source)}, source:{source}.")
        app_logger.debug(f"body_request:type(source_name):{type(source_name)}, source_name:{source_name}.")
        app_logger.debug(f"lisa module:{lisa}.")
        gpu_initialization()
        output = lisa.lisa_predict(
            bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
            source=source, source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
        )
        duration_run = time.time() - time_start_run
        app_logger.info(f"duration_run:{duration_run}.")
        body = {
            "duration_run": duration_run,
            "output": output
        }
        dumped = json.dumps(body)
        app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
        app_logger.debug(f"complete json.dumps(body):{dumped}.")
        return dumped
    except Exception as inference_exception:
        # Boundary handler: every failure after parsing is reported to the client as a generic 500.
        app_logger.error(f"inference_exception:{inference_exception}.")
        app_logger.error(f"inference_exception, request_input:{request_input}.")
        raise HTTPException(status_code=500, detail="Internal Server Error") from inference_exception
|
|
|
|
|
@app.post("/infer_lisa")
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """HTTP endpoint wrapping ``infer_lisa_gradio``.

    Args:
        request_input: the request body forwarded unchanged to ``infer_lisa_gradio``.

    Returns:
        A JSONResponse with status 200 whose "body" key holds the JSON string
        produced by ``infer_lisa_gradio``.
    """
    serialized_body = infer_lisa_gradio(request_input=request_input)
    body_type, body_length = type(serialized_body), len(serialized_body)
    app_logger.info(f"json.dumps(body) type:{body_type}, len:{body_length}.")
    app_logger.debug(f"complete json.dumps(body):{serialized_body}.")
    payload = {"body": serialized_body}
    return JSONResponse(status_code=200, content=payload)
|
|
|
|
|
@app.exception_handler(RequestValidationError)
def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Delegate RequestValidationError handling to the samgis_web handler of the same name.

    Args:
        request: the request that failed validation.
        exc: the validation error raised for it.

    Returns:
        The response built by ``samgis_web.web.exception_handlers.request_validation_exception_handler``.
    """
    from samgis_web.web.exception_handlers import request_validation_exception_handler as delegate_handler

    return delegate_handler(request, exc)
|
|
|
|
|
@app.exception_handler(HTTPException)
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Delegate HTTPException handling to the samgis_web handler of the same name.

    Args:
        request: the request being processed when the exception was raised.
        exc: the raised HTTP exception.

    Returns:
        The response built by ``samgis_web.web.exception_handlers.http_exception_handler``.
    """
    from samgis_web.web.exception_handlers import http_exception_handler as delegate_handler

    return delegate_handler(request, exc)
|
|
|
|
|
create_folders_if_not_exists.folders_creation(folders_map)
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
# Expose the folder with temporary outputs only when WRITE_TMP_ON_DISK is set.
if write_tmp_on_disk:
    try:
        if not Path(write_tmp_on_disk).is_dir():
            # Explicit raise instead of `assert`: assertions are stripped when python runs with -O.
            # AssertionError is kept so the exception type seen by callers does not change.
            raise AssertionError(f"'{write_tmp_on_disk}' is not a directory")
        app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
        # Dedicated name: the module rebinds the global `templates` to another directory later on,
        # and list_files() resolves the name only at request time.
        vis_output_templates = Jinja2Templates(directory=str(project_root_folder / "static"))

        @app.get("/vis_output", response_class=HTMLResponse)
        def list_files(request: Request):
            """Render the "list_files.html" template with the URLs of the files in write_tmp_on_disk.

            Args:
                request: the incoming request; its URL is the prefix of every listed file URL.

            Returns:
                The template response listing the sorted file URLs.
            """
            base_url = str(request.url)  # public equivalent of the private `request.url._url`
            files_paths = sorted(f"{base_url}/{f}" for f in os.listdir(write_tmp_on_disk))
            app_logger.debug(f"files_paths:{files_paths}.")
            return vis_output_templates.TemplateResponse(
                "list_files.html", {"request": request, "files": files_paths}
            )
    except (AssertionError, RuntimeError) as rerr:
        app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
        # bare raise keeps the original traceback untouched
        raise
|
|
|
# Build the frontend into the static dist folder before any route serves it.
frontend_builder.build_frontend(
    project_root_folder=workdir, input_css_path=input_css_path, output_dist_folder=static_dist_folder
)
app_logger.info("build_frontend ok!")

# Jinja2 templates loaded from the relative "templates" directory.
templates = Jinja2Templates(directory="templates")

# Serve the built frontend files under /static.
app.mount(
    "/static",
    StaticFiles(directory=static_dist_folder, html=True),
    name="static",
)
|
|
|
|
|
# Serve the built frontend files under the samgis URL prefix.
app.mount(
    vite_samgis_url,
    StaticFiles(directory=static_dist_folder, html=True),
    name="samgis",
)


@app.get(vite_samgis_url)
async def samgis() -> FileResponse:
    """Return the ``samgis.html`` page from the static dist folder."""
    page_path = static_dist_folder / "samgis.html"
    return FileResponse(path=str(page_path), media_type="text/html")
|
|
|
|
|
|
|
# Serve the built frontend files under the lisa URL prefix.
app.mount(
    vite_lisa_url,
    StaticFiles(directory=static_dist_folder, html=True),
    name="lisa",
)


@app.get(vite_lisa_url)
async def lisa() -> FileResponse:
    """Return the ``lisa.html`` page from the static dist folder."""
    page_path = static_dist_folder / "lisa.html"
    return FileResponse(path=str(page_path), media_type="text/html")
|
|
|
|
|
|
|
# Serve the built frontend files under the index URL prefix.
# NOTE(review): with the default prefix "/" this mount may match paths of routes
# registered after it - confirm the intended route precedence.
app.mount(
    vite_index_url,
    StaticFiles(directory=static_dist_folder, html=True),
    name="index",
)


@app.get(vite_index_url)
async def index() -> FileResponse:
    """Return the ``index.html`` page from the static dist folder."""
    page_path = static_dist_folder / "index.html"
    return FileResponse(path=str(page_path), media_type="text/html")
|
|
|
|
|
# Build the gradio UI around the inference function and mount it within the FastAPI app.
app_logger.info("creating gradio interface...")
gr_interface = get_gradio_interface_geojson(infer_lisa_gradio)
app_logger.info(
    f"gradio interface created, mounting gradio app on url {vite_gradio_url} within FastAPI..."
)
app = gr.mount_gradio_app(app, gr_interface, path=vite_gradio_url)
app_logger.info("mounted gradio app within fastapi")

# Register the correlation id middleware on the resulting app.
app.add_middleware(CorrelationIdMiddleware)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: serve the app with uvicorn on all interfaces, port 7860.
    try:
        uvicorn.run(app=app, host="0.0.0.0", port=7860)
    except Exception as ex:
        # Report the failure both through the logger and on stdout, then re-raise it.
        failure_message = f"fastapi/gradio application {fastapi_title}, exception:{ex}."
        app_logger.error(failure_message)
        print(failure_message)
        raise ex
|
|