File size: 3,739 Bytes
7d48fbe
 
 
c311b69
fcb8c81
 
7d48fbe
fcb8c81
 
 
7d48fbe
 
0914710
 
 
c311b69
0914710
 
 
c311b69
 
 
0914710
 
 
 
 
 
 
 
 
 
 
 
 
 
c311b69
0914710
 
 
 
7d48fbe
 
5350122
0914710
 
 
 
 
 
 
 
 
 
 
 
 
7d48fbe
c311b69
7d48fbe
 
 
 
 
 
 
 
 
0914710
c311b69
 
 
0914710
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from datetime import datetime
from typing import Optional

from lisa_on_cuda.utils import app_helpers
from samgis_core.utilities.type_hints import LlistFloat, DictStrInt
from samgis_lisa_on_cuda import app_logger
from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
from samgis_lisa_on_cuda.io.raster_helpers import write_raster_png, write_raster_tiff
from samgis_lisa_on_cuda.io.tms2geotiff import download_extent
from samgis_lisa_on_cuda.prediction_api.global_models import models_dict
from samgis_lisa_on_cuda.utilities.constants import DEFAULT_URL_TILES

# Log message prefix used by lisa_predict when the WRITE_TMP_ON_DISK environment
# variable names a folder, i.e. when the downloaded raster is also dumped to disk.
msg_write_tmp_on_disk = (
    "found option to write images and geojson output..."
)


def _get_inference_fn(inference_function_name_key: str):
    """
    Return the inference function stored under the given key, instantiating it on first use.

    The created function is cached in ``models_dict[inference_function_name_key]["inference"]``
    so that subsequent calls reuse the same model instance.

    Args:
        inference_function_name_key: key of the model entry inside ``models_dict``

    Returns:
        the inference callable stored in ``models_dict``
    """
    if models_dict[inference_function_name_key]["inference"] is None:
        app_logger.info(f"missing inference function {inference_function_name_key}, instantiating it now!")
        # NOTE(review): parse_args([]) presumably yields the parser defaults, so the model built here
        # does not depend on inference_function_name_key - confirm in lisa_on_cuda.utils.app_helpers.
        parsed_args = app_helpers.parse_args([])
        models_dict[inference_function_name_key]["inference"] = app_helpers.get_inference_model_by_args(parsed_args)
    app_logger.debug(f"using a {inference_function_name_key} instance model...")
    return models_dict[inference_function_name_key]["inference"]


def _write_tmp_rasters_on_disk(img, transform, prefix: str, folder: str) -> None:
    """
    Dump the downloaded raster to ``folder`` for debugging purposes.

    A 2D raster is written as a tiff ("raw_tiff"), a 3-channel raster as a png ("raw_img");
    any other shape is only logged.

    Args:
        img: downloaded raster; assumed to be array-like with a ``.shape`` - TODO confirm type
        transform: geo-referencing transform returned together with ``img``
        prefix: coordinates-based file name prefix
        folder: destination folder (value of the WRITE_TMP_ON_DISK environment variable)
    """
    now = datetime.now().strftime('%Y%m%d_%H%M%S')
    n_dims = len(img.shape)
    app_logger.info(msg_write_tmp_on_disk + f" with coords {prefix}, shape:{img.shape}, {n_dims}.")
    file_prefix = f"{prefix}_{now}_"
    if n_dims == 2:
        write_raster_tiff(img, transform, file_prefix, "raw_tiff", folder)
    elif n_dims == 3 and img.shape[2] == 3:
        write_raster_png(img, transform, file_prefix, "raw_img", folder)


def lisa_predict(
        bbox: LlistFloat,
        prompt: str,
        zoom: float,
        inference_function_name_key: str = "lisa",
        source: str = DEFAULT_URL_TILES,
        source_name: Optional[str] = None
) -> DictStrInt:
    """
    Return predictions as a geojson from a geo-referenced image using the given input prompt.

    1. if necessary instantiate the inference function registered under ``inference_function_name_key``
    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
    3. if the WRITE_TMP_ON_DISK environment variable is set, write the raw raster into that folder
    4. run the inference function on the raster with the input prompt, obtaining a mask and a text output
    5. vectorize the mask into a geo-referenced geojson

    Args:
        bbox: coordinates bounding box as two points: ``bbox[0] = [north, east]``,
            ``bbox[1] = [south, west]``
        prompt: machine learning input prompt
        zoom: level of detail used to download the tiles
        inference_function_name_key: key of the machine learning model inside ``models_dict``
        source: tile provider passed to ``download_extent``
        source_name: name of tile provider, used only to build the embedding key (may be None)

    Returns:
        dict with the key "output_string" (text output of the inference function) merged with
        the entries returned by ``get_vectorized_raster_as_geojson(mask, transform)``

    Raises:
        KeyError: presumably, when ``inference_function_name_key`` is not a key of ``models_dict``
    """
    from os import getenv

    app_logger.info("start lisa inference...")
    inference_fn = _get_inference_fn(inference_function_name_key)

    pt0, pt1 = bbox
    north, east = pt0[0], pt0[1]
    south, west = pt1[0], pt1[1]
    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=west, s=south, e=east, n=north, zoom=zoom, source=source)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")

    # this prefix is reused in the embedding key below: keep its format stable
    prefix = f"w{west},s{south},e{east},n{north}_"
    folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
    if folder_write_tmp_on_disk:
        _write_tmp_rasters_on_disk(img, transform, prefix, folder_write_tmp_on_disk)
    else:
        app_logger.info("keep all temp data in memory...")

    app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
    # key handed to the inference function; presumably used to cache image embeddings - confirm
    embedding_key = f"{source_name}_z{zoom}_{prefix}"
    _, mask, output_string = inference_fn(prompt, img, app_logger, embedding_key)
    return {
        "output_string": output_string,
        **get_vectorized_raster_as_geojson(mask, transform)
    }