File size: 6,671 Bytes
0914710 fcb8c81 0914710 fcb8c81 0914710 fcb8c81 0914710 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import json
import time
import unittest
from http import HTTPStatus
from unittest.mock import patch
from samgis_lisa_on_cuda.io import wrappers_helpers
from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt, get_parsed_request_body, get_response
from samgis_lisa_on_cuda.utilities.type_hints import ApiRequestBody
from tests import TEST_EVENTS_FOLDER
def _load_test_event(filename):
    """Load and decode a JSON fixture stored under ``TEST_EVENTS_FOLDER``.

    The file is opened with an explicit UTF-8 encoding so the fixtures decode
    the same way regardless of the platform's locale default.

    :param filename: fixture file name, relative to ``TEST_EVENTS_FOLDER``.
    :return: the decoded JSON document.
    """
    with open(TEST_EVENTS_FOLDER / filename, encoding="utf-8") as tst_json:
        return json.load(tst_json)


class WrappersHelpersTest(unittest.TestCase):
    """Tests for the helpers in ``samgis_lisa_on_cuda.io.wrappers_helpers``."""

    @patch.object(time, "time")
    def test_get_response(self, time_mocked):
        """Compare get_response output for statuses 200/400/422/500 with the get_response.json fixture.

        ``time.time`` is patched to return a fixed value; presumably get_response
        derives an elapsed duration from ``start_time``, which this keeps deterministic.
        """
        time_diff = 108
        end_run = 1000
        time_mocked.return_value = end_run
        start_time = end_run - time_diff
        aws_request_id = "test_invoke_id"
        inputs_outputs = _load_test_event("get_response.json")

        response_type = "200"
        body_response = inputs_outputs[response_type]["input"]
        output = get_response(HTTPStatus.OK.value, start_time, aws_request_id, body_response)
        # The success response is JSON-decoded before comparison, while the error
        # responses below are compared exactly as returned.
        # NOTE(review): presumably the fixture stores the decoded 200 body but the
        # raw error outputs — confirm against get_response.json.
        assert json.loads(output) == inputs_outputs[response_type]["output"]

        # Error statuses are all produced with an empty body; subTest reports each
        # failing status separately instead of stopping at the first one.
        for response_type, status in (
            ("400", HTTPStatus.BAD_REQUEST),
            ("422", HTTPStatus.UNPROCESSABLE_ENTITY),
            ("500", HTTPStatus.INTERNAL_SERVER_ERROR),
        ):
            with self.subTest(status=response_type):
                response = get_response(status.value, start_time, aws_request_id, {})
                assert response == inputs_outputs[response_type]["output"]

    @staticmethod
    def test_get_parsed_bbox_points():
        """Run every single-point fixture case through request parsing and bbox/prompt extraction."""
        inputs_outputs = _load_test_event("get_parsed_bbox_prompts_single_point.json")
        for case_name, input_output in inputs_outputs.items():
            # Each case's "input" mapping is unpacked as keyword arguments.
            raw_body = get_parsed_request_body(**input_output["input"])
            output = get_parsed_bbox_points_with_dictlist_prompt(raw_body)
            # The case name travels with the failure instead of being printed for every case.
            assert output == input_output["output"], f"unexpected output for fixture case {case_name!r}"

    @staticmethod
    def test_get_parsed_bbox_other_inputs():
        """Check bbox/prompt extraction for the rectangle and multi-prompt fixtures.

        Here the fixture input is validated directly with ``ApiRequestBody``
        rather than going through get_parsed_request_body.
        """
        for json_filename in ["single_rectangle", "multi_prompt"]:
            inputs_outputs = _load_test_event(f"get_parsed_bbox_prompts_{json_filename}.json")
            parsed_input = ApiRequestBody.model_validate(inputs_outputs["input"])
            output = get_parsed_bbox_points_with_dictlist_prompt(parsed_input)
            assert output == inputs_outputs["output"], f"unexpected output for fixture {json_filename!r}"

    @staticmethod
    def test_get_parsed_request_body():
        """get_parsed_request_body must give the same ApiRequestBody for three encodings of one event.

        The three encodings are a dict, its JSON string, and a ``{"body": <base64 of the JSON string>}`` wrapper.
        """
        from samgis_core.utilities.utilities import base64_encode

        input_event = {
            "event": {
                "bbox": {
                    "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
                    "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068}
                },
                "prompt": [{"type": "point", "data": {"lat": 37.0, "lng": 15.0}, "label": 0}],
                "zoom": 10, "source_type": "OpenStreetMap.Mapnik", "debug": True
            }
        }
        expected_output_dict = {
            "bbox": {
                "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
                "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068}
            },
            "prompt": [{"type": "point", "data": {"lat": 37.0, "lng": 15.0}, "label": 0}],
            "zoom": 10, "source_type": "OpenStreetMap.Mapnik", "debug": True
        }

        # 1) Plain dict input.
        output = get_parsed_request_body(input_event["event"])
        assert output == ApiRequestBody.model_validate(input_event["event"])

        # 2) JSON string input.
        input_event_str = json.dumps(input_event["event"])
        output = get_parsed_request_body(input_event_str)
        assert output == ApiRequestBody.model_validate(expected_output_dict)

        # 3) Base64-encoded JSON wrapped under "body"; base64_encode returns bytes, hence the decode.
        event = {"body": base64_encode(input_event_str).decode("utf-8")}
        output = get_parsed_request_body(event)
        assert output == ApiRequestBody.model_validate(expected_output_dict)

    @patch.object(wrappers_helpers, "providers")
    def test_get_url_tile(self, providers_mocked):
        """get_url_tile should return whatever provider ``providers.query_name`` yields (mocked here).

        The second case checks that extra provider keyword fields (``parameter``) are kept in the result.
        """
        import xyzservices
        from samgis_lisa_on_cuda.io.wrappers_helpers import get_url_tile
        from tests import LOCAL_URL_TILE

        local_tile_provider = xyzservices.TileProvider(name="local_tile_provider", url=LOCAL_URL_TILE, attribution="")
        expected_output = {'name': 'local_tile_provider', 'url': LOCAL_URL_TILE, 'attribution': ''}
        providers_mocked.query_name.return_value = local_tile_provider
        assert get_url_tile("OpenStreetMap") == expected_output

        local_url = 'http://localhost:8000/{parameter}/{z}/{x}/{y}.png'
        local_tile_provider = xyzservices.TileProvider(
            name="local_tile_provider_param", url=local_url, attribution="", parameter="lamda_handler"
        )
        providers_mocked.query_name.return_value = local_tile_provider
        assert get_url_tile("OpenStreetMap.HOT") == {
            "parameter": "lamda_handler", 'name': 'local_tile_provider_param', 'url': local_url, 'attribution': ''
        }

    @staticmethod
    def test_get_url_tile_real():
        """Check get_url_tile against the unmocked provider entries for OpenStreetMap and OpenStreetMap.HOT."""
        from samgis_lisa_on_cuda.io.wrappers_helpers import get_url_tile

        # "OpenStreetMap" is expected to resolve to the "OpenStreetMap.Mapnik" entry.
        assert get_url_tile("OpenStreetMap") == {
            'url': 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_zoom': 19,
            'html_attribution': '© <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a> contributors',
            'attribution': '(C) OpenStreetMap contributors',
            'name': 'OpenStreetMap.Mapnik'}

        html_attribution_hot = '© <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a> contributors, '
        html_attribution_hot += 'Tiles style by <a href="https://www.hotosm.org/" target="_blank">Humanitarian '
        html_attribution_hot += 'OpenStreetMap Team</a> hosted by <a href="https://openstreetmap.fr/" target="_blank">'
        html_attribution_hot += 'OpenStreetMap France</a>'
        attribution_hot = '(C) OpenStreetMap contributors, Tiles style by Humanitarian OpenStreetMap Team hosted by '
        attribution_hot += 'OpenStreetMap France'
        assert get_url_tile("OpenStreetMap.HOT") == {
            'url': 'https://{s}.tile.openstreetmap.fr/hot/{z}/{x}/{y}.png', 'max_zoom': 19,
            'html_attribution': html_attribution_hot, 'attribution': attribution_hot, 'name': 'OpenStreetMap.HOT'
        }
|