aletrn commited on
Commit
ee50e01
·
1 Parent(s): 3d82071

[debug] first test on lambda for tms2geotiff functions

Browse files
dockerfiles/dockerfile-lambda-gdal-runner CHANGED
@@ -11,11 +11,10 @@ RUN echo "ENV RIE: $RIE ..."
11
  WORKDIR ${LAMBDA_TASK_ROOT}
12
  COPY requirements.txt ${LAMBDA_TASK_ROOT}/requirements.txt
13
 
14
- RUN apt update && apt install -y curl python3-pip libgl1
15
  RUN which python
16
  RUN python --version
17
- RUN python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
18
- RUN python -m pip install -r ${LAMBDA_TASK_ROOT}/requirements.txt --target ${LAMBDA_TASK_ROOT}
19
 
20
  RUN curl -Lo /usr/local/bin/aws-lambda-rie ${RIE}
21
  RUN chmod +x /usr/local/bin/aws-lambda-rie
 
11
  WORKDIR ${LAMBDA_TASK_ROOT}
12
  COPY requirements.txt ${LAMBDA_TASK_ROOT}/requirements.txt
13
 
14
+ RUN apt update && apt install -y curl python3-pip
15
  RUN which python
16
  RUN python --version
17
+ RUN python -m pip install pillow awslambdaric aws-lambda-powertools httpx jmespath
 
18
 
19
  RUN curl -Lo /usr/local/bin/aws-lambda-rie ${RIE}
20
  RUN chmod +x /usr/local/bin/aws-lambda-rie
src/app.py CHANGED
@@ -1,42 +1,16 @@
1
  import json
2
  import time
3
  from http import HTTPStatus
4
- from typing import Dict, List
5
 
6
  from aws_lambda_powertools.event_handler import content_types
7
  from aws_lambda_powertools.utilities.typing import LambdaContext
8
- from geojson_pydantic import FeatureCollection, Feature, Polygon
9
- from pydantic import BaseModel, ValidationError
10
 
11
  from src import app_logger
12
- from src.prediction_api.samgeo_predictors import samgeo_fast_predict
13
- from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, MODEL_NAME, ZOOM, SOURCE_TYPE
14
- from src.utilities.utilities import base64_decode
15
 
16
- PolygonFeatureCollectionModel = FeatureCollection[Feature[Polygon, Dict]]
17
 
18
-
19
- class LatLngTupleLeaflet(BaseModel):
20
- lat: float
21
- lng: float
22
-
23
-
24
- class RequestBody(BaseModel):
25
- model: str = MODEL_NAME
26
- ne: LatLngTupleLeaflet
27
- source_type: str = SOURCE_TYPE
28
- sw: LatLngTupleLeaflet
29
- zoom: float = ZOOM
30
-
31
-
32
- class ResponseBody(BaseModel):
33
- duration_run: float = None
34
- geojson: Dict = None
35
- message: str = None
36
- request_id: str = None
37
-
38
-
39
- def get_response(status: int, start_time: float, request_id: str, response_body: ResponseBody = None) -> str:
40
  """
41
  Return a response for frontend clients.
42
 
@@ -51,9 +25,9 @@ def get_response(status: int, start_time: float, request_id: str, response_body:
51
 
52
  """
53
  app_logger.info(f"response_body:{response_body}.")
54
- response_body.duration_run = time.time() - start_time
55
- response_body.message = CUSTOM_RESPONSE_MESSAGES[status]
56
- response_body.request_id = request_id
57
 
58
  response = {
59
  "statusCode": status,
@@ -65,24 +39,6 @@ def get_response(status: int, start_time: float, request_id: str, response_body:
65
  return json.dumps(response)
66
 
67
 
68
- def get_parsed_bbox_points(request_input: RequestBody) -> Dict:
69
- model_name = request_input["model"] if "model" in request_input else MODEL_NAME
70
- zoom = request_input["zoom"] if "zoom" in request_input else ZOOM
71
- source_type = request_input["source_type"] if "zoom" in request_input else SOURCE_TYPE
72
- app_logger.info(f"try to validate input request {request_input}...")
73
- request_body = RequestBody(ne=request_input["ne"], sw=request_input["sw"], model=model_name, zoom=zoom, source_type=source_type)
74
- app_logger.info(f"unpacking {request_body}...")
75
- return {
76
- "bbox": [
77
- request_body.ne.lat, request_body.sw.lat,
78
- request_body.ne.lng, request_body.sw.lng
79
- ],
80
- "model": request_body.model,
81
- "zoom": request_body.zoom,
82
- "source_type": request_body.source_type
83
- }
84
-
85
-
86
  def lambda_handler(event: dict, context: LambdaContext):
87
  app_logger.info(f"start with aws_request_id:{context.aws_request_id}.")
88
  start_time = time.time()
@@ -95,37 +51,17 @@ def lambda_handler(event: dict, context: LambdaContext):
95
  app_logger.info(f"context:{context}...")
96
 
97
  try:
98
- body = event["body"]
99
- except Exception as e_constants1:
100
- app_logger.error(f"e_constants1:{e_constants1}.")
101
- body = event
102
-
103
- app_logger.info(f"body: {type(body)}, {body}...")
104
-
105
- if isinstance(body, str):
106
- body_decoded_str = base64_decode(body)
107
- app_logger.info(f"body_decoded_str: {type(body_decoded_str)}, {body_decoded_str}...")
108
- body = json.loads(body_decoded_str)
109
-
110
- app_logger.info(f"body:{body}...")
111
-
112
- try:
113
- body_request = get_parsed_bbox_points(body)
114
- app_logger.info(f"validation ok - body_request:{body_request}, starting prediction...")
115
- output_geojson_dict = samgeo_fast_predict(
116
- bbox=body_request["bbox"], model_name=body_request["model"], zoom=body_request["zoom"], source_type=body_request["source_type"]
117
- )
118
-
119
- # raise ValidationError in case this is not a valid geojson by GeoJSON specification rfc7946
120
- PolygonFeatureCollectionModel(**output_geojson_dict)
121
- body_response = ResponseBody(geojson=output_geojson_dict)
122
  response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
123
- except ValidationError as ve:
124
  app_logger.error(f"validation error:{ve}.")
125
- response = get_response(HTTPStatus.UNPROCESSABLE_ENTITY.value, start_time, context.aws_request_id, ResponseBody())
126
  except Exception as e:
127
  app_logger.error(f"exception:{e}.")
128
- response = get_response(HTTPStatus.INTERNAL_SERVER_ERROR.value, start_time, context.aws_request_id, ResponseBody())
129
 
130
  app_logger.info(f"response_dumped:{response}...")
131
  return response
 
1
  import json
2
  import time
3
  from http import HTTPStatus
 
4
 
5
  from aws_lambda_powertools.event_handler import content_types
6
  from aws_lambda_powertools.utilities.typing import LambdaContext
 
 
7
 
8
  from src import app_logger
9
+ from src.io.tms2geotiff import download_extent
10
+ from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, DEFAULT_TMS
 
11
 
 
12
 
13
+ def get_response(status: int, start_time: float, request_id: str, response_body = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
  Return a response for frontend clients.
16
 
 
25
 
26
  """
27
  app_logger.info(f"response_body:{response_body}.")
28
+ response_body["duration_run"] = time.time() - start_time
29
+ response_body["message"] = CUSTOM_RESPONSE_MESSAGES[status]
30
+ response_body["request_id"] = request_id
31
 
32
  response = {
33
  "statusCode": status,
 
39
  return json.dumps(response)
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def lambda_handler(event: dict, context: LambdaContext):
43
  app_logger.info(f"start with aws_request_id:{context.aws_request_id}.")
44
  start_time = time.time()
 
51
  app_logger.info(f"context:{context}...")
52
 
53
  try:
54
+ pt0 = 45.699, 127.1
55
+ pt1 = 30.1, 148.492
56
+ img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], 6)
57
+ body_response = {"geojson": {"img_size": img.size}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
59
+ except Exception as ve:
60
  app_logger.error(f"validation error:{ve}.")
61
+ response = get_response(HTTPStatus.UNPROCESSABLE_ENTITY.value, start_time, context.aws_request_id, {})
62
  except Exception as e:
63
  app_logger.error(f"exception:{e}.")
64
+ response = get_response(HTTPStatus.INTERNAL_SERVER_ERROR.value, start_time, context.aws_request_id, {})
65
 
66
  app_logger.info(f"response_dumped:{response}...")
67
  return response
src/io/tms2geotiff.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import io
5
+ import os
6
+ import re
7
+ import math
8
+ import time
9
+ import sqlite3
10
+ import argparse
11
+ import itertools
12
+ import concurrent.futures
13
+
14
+ from PIL import Image
15
+ from PIL import TiffImagePlugin
16
+
17
+ from src.utilities.constants import EARTH_EQUATORIAL_RADIUS, WKT_3857, DEFAULT_TMS
18
+
19
+
20
+ Image.MAX_IMAGE_PIXELS = None
21
+
22
+
23
+ try:
24
+ import httpx
25
+ SESSION = httpx.Client()
26
+ except ImportError:
27
+ import requests
28
+ SESSION = requests.Session()
29
+
30
+
31
+ SESSION.headers.update({
32
+ "Accept": "*/*",
33
+ "Accept-Encoding": "gzip, deflate",
34
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; rv:91.0) Gecko/20100101 Firefox/91.0",
35
+ })
36
+
37
+ re_coords_split = re.compile('[ ,;]+')
38
+
39
+
40
+ def from4326_to3857(lat, lon):
41
+ xtile = math.radians(lon) * EARTH_EQUATORIAL_RADIUS
42
+ ytile = math.log(math.tan(math.radians(45 + lat / 2.0))) * EARTH_EQUATORIAL_RADIUS
43
+ return (xtile, ytile)
44
+
45
+
46
+ def deg2num(lat, lon, zoom):
47
+ n = 2 ** zoom
48
+ xtile = ((lon + 180) / 360 * n)
49
+ ytile = (1 - math.asinh(math.tan(math.radians(lat))) / math.pi) * n / 2
50
+ return (xtile, ytile)
51
+
52
+
53
+ def is_empty(im):
54
+ extrema = im.getextrema()
55
+ if len(extrema) >= 3:
56
+ if len(extrema) > 3 and extrema[-1] == (0, 0):
57
+ return True
58
+ for ext in extrema[:3]:
59
+ if ext != (0, 0):
60
+ return False
61
+ return True
62
+ else:
63
+ return extrema[0] == (0, 0)
64
+
65
+
66
+ def mbtiles_init(dbname):
67
+ db = sqlite3.connect(dbname, isolation_level=None)
68
+ cur = db.cursor()
69
+ cur.execute("BEGIN")
70
+ cur.execute("CREATE TABLE IF NOT EXISTS metadata (name TEXT PRIMARY KEY, value TEXT)")
71
+ cur.execute("CREATE TABLE IF NOT EXISTS tiles ("
72
+ "zoom_level INTEGER NOT NULL, "
73
+ "tile_column INTEGER NOT NULL, "
74
+ "tile_row INTEGER NOT NULL, "
75
+ "tile_data BLOB NOT NULL, "
76
+ "UNIQUE (zoom_level, tile_column, tile_row)"
77
+ ")")
78
+ cur.execute("COMMIT")
79
+ return db
80
+
81
+
82
+ def paste_tile(bigim, base_size, tile, corner_xy, bbox):
83
+ if tile is None:
84
+ return bigim
85
+ im = Image.open(io.BytesIO(tile))
86
+ mode = 'RGB' if im.mode == 'RGB' else 'RGBA'
87
+ size = im.size
88
+ if bigim is None:
89
+ base_size[0] = size[0]
90
+ base_size[1] = size[1]
91
+ newim = Image.new(mode, (
92
+ size[0]*(bbox[2]-bbox[0]), size[1]*(bbox[3]-bbox[1])))
93
+ else:
94
+ newim = bigim
95
+
96
+ dx = abs(corner_xy[0] - bbox[0])
97
+ dy = abs(corner_xy[1] - bbox[1])
98
+ xy0 = (size[0]*dx, size[1]*dy)
99
+ if mode == 'RGB':
100
+ newim.paste(im, xy0)
101
+ else:
102
+ if im.mode != mode:
103
+ im = im.convert(mode)
104
+ if not is_empty(im):
105
+ newim.paste(im, xy0)
106
+ im.close()
107
+ return newim
108
+
109
+
110
+ def get_tile(url):
111
+ retry = 3
112
+ while 1:
113
+ try:
114
+ r = SESSION.get(url, timeout=60)
115
+ break
116
+ except Exception:
117
+ retry -= 1
118
+ if not retry:
119
+ raise
120
+ if r.status_code == 404:
121
+ return None
122
+ elif not r.content:
123
+ return None
124
+ r.raise_for_status()
125
+ return r.content
126
+
127
+
128
+ def print_progress(progress, total, done=False):
129
+ if done:
130
+ print('Downloaded image %d/%d, %.2f%%' % (progress, total, progress*100/total))
131
+
132
+
133
+ class ProgressBar:
134
+ def __init__(self, use_tqdm=True):
135
+ self._tqdm_fn = None
136
+ self.tqdm_bar = None
137
+ self.tqdm_progress = 0
138
+ if use_tqdm:
139
+ try:
140
+ import tqdm
141
+ self._tqdm_fn = lambda total: tqdm.tqdm(
142
+ total=total, unit='img')
143
+ except ImportError:
144
+ pass
145
+
146
+ def print_progress(self, progress, total, done=False):
147
+ if self.tqdm_bar is None and self._tqdm_fn:
148
+ self.tqdm_bar = self._tqdm_fn(total)
149
+ if not done:
150
+ return
151
+ if self.tqdm_bar is None:
152
+ print_progress(progress, total, done)
153
+ elif progress > self.tqdm_progress:
154
+ delta = progress - self.tqdm_progress
155
+ self.tqdm_bar.update(delta)
156
+ self.tqdm_progress = progress
157
+
158
+ def close(self):
159
+ if self.tqdm_bar:
160
+ self.tqdm_bar.close()
161
+ else:
162
+ print('\nDone.')
163
+
164
+
165
+ def mbtiles_save(db, img_data, xy, zoom, img_format):
166
+ if not img_data:
167
+ return
168
+ im = Image.open(io.BytesIO(img_data))
169
+ if im.format == 'PNG':
170
+ current_format = 'png'
171
+ elif im.format == 'JPEG':
172
+ current_format = 'jpg'
173
+ elif im.format == 'WEBP':
174
+ current_format = 'webp'
175
+ else:
176
+ current_format = 'image/' + im.format.lower()
177
+ x, y = xy
178
+ y = 2**zoom - 1 - y
179
+ cur = db.cursor()
180
+ if img_format is None or img_format == current_format:
181
+ cur.execute("REPLACE INTO tiles VALUES (?,?,?,?)", (
182
+ zoom, x, y, img_data))
183
+ return img_format or current_format
184
+ buf = io.BytesIO()
185
+ if img_format == 'png':
186
+ im.save(buf, 'PNG')
187
+ elif img_format == 'jpg':
188
+ im.save(buf, 'JPEG', quality=93)
189
+ elif img_format == 'webp':
190
+ im.save(buf, 'WEBP')
191
+ else:
192
+ im.save(buf, img_format.split('/')[-1].upper())
193
+ cur.execute("REPLACE INTO tiles VALUES (?,?,?,?)", (
194
+ zoom, x, y, buf.getvalue()))
195
+ return img_format
196
+
197
+
198
+ def download_extent(
199
+ source, lat0, lon0, lat1, lon1, zoom,
200
+ mbtiles=None, save_image=True,
201
+ progress_callback=print_progress,
202
+ callback_interval=0.05
203
+ ):
204
+ x0, y0 = deg2num(lat0, lon0, zoom)
205
+ x1, y1 = deg2num(lat1, lon1, zoom)
206
+ if x0 > x1:
207
+ x0, x1 = x1, x0
208
+ if y0 > y1:
209
+ y0, y1 = y1, y0
210
+
211
+ db = None
212
+ mbt_img_format = None
213
+ if mbtiles:
214
+ db = mbtiles_init(mbtiles)
215
+ cur = db.cursor()
216
+ cur.execute("BEGIN")
217
+ cur.execute("REPLACE INTO metadata VALUES ('name', ?)", (source,))
218
+ cur.execute("REPLACE INTO metadata VALUES ('type', 'overlay')")
219
+ cur.execute("REPLACE INTO metadata VALUES ('version', '1.1')")
220
+ cur.execute("REPLACE INTO metadata VALUES ('description', ?)", (source,))
221
+ cur.execute("SELECT value FROM metadata WHERE name='format'")
222
+ row = cur.fetchone()
223
+ if row and row[0]:
224
+ mbt_img_format = row[0]
225
+ else:
226
+ cur.execute("REPLACE INTO metadata VALUES ('format', 'png')")
227
+
228
+ lat_min = min(lat0, lat1)
229
+ lat_max = max(lat0, lat1)
230
+ lon_min = min(lon0, lon1)
231
+ lon_max = max(lon0, lon1)
232
+ bounds = [lon_min, lat_min, lon_max, lat_max]
233
+ cur.execute("SELECT value FROM metadata WHERE name='bounds'")
234
+ row = cur.fetchone()
235
+ if row and row[0]:
236
+ last_bounds = [float(x) for x in row[0].split(',')]
237
+ bounds[0] = min(last_bounds[0], bounds[0])
238
+ bounds[1] = min(last_bounds[1], bounds[1])
239
+ bounds[2] = max(last_bounds[2], bounds[2])
240
+ bounds[3] = max(last_bounds[3], bounds[3])
241
+ cur.execute("REPLACE INTO metadata VALUES ('bounds', ?)", (
242
+ ",".join(map(str, bounds)),))
243
+ cur.execute("REPLACE INTO metadata VALUES ('center', ?)", ("%s,%s,%d" % (
244
+ (lon_max + lon_min)/2, (lat_max + lat_min)/2, zoom),))
245
+ cur.execute("""
246
+ INSERT INTO metadata VALUES ('minzoom', ?)
247
+ ON CONFLICT(name) DO UPDATE SET value=excluded.value
248
+ WHERE CAST(excluded.value AS INTEGER)<CAST(metadata.value AS INTEGER)
249
+ """, (str(zoom),))
250
+ cur.execute("""
251
+ INSERT INTO metadata VALUES ('maxzoom', ?)
252
+ ON CONFLICT(name) DO UPDATE SET value=excluded.value
253
+ WHERE CAST(excluded.value AS INTEGER)>CAST(metadata.value AS INTEGER)
254
+ """, (str(zoom),))
255
+ cur.execute("COMMIT")
256
+
257
+ corners = tuple(itertools.product(
258
+ range(math.floor(x0), math.ceil(x1)),
259
+ range(math.floor(y0), math.ceil(y1))))
260
+ totalnum = len(corners)
261
+ futures = {}
262
+ done_num = 0
263
+ progress_callback(done_num, totalnum, False)
264
+ last_done_num = 0
265
+ last_callback = time.monotonic()
266
+ cancelled = False
267
+ with concurrent.futures.ThreadPoolExecutor(5) as executor:
268
+ for x, y in corners:
269
+ future = executor.submit(get_tile, source.format(z=zoom, x=x, y=y))
270
+ futures[future] = (x, y)
271
+ bbox = (math.floor(x0), math.floor(y0), math.ceil(x1), math.ceil(y1))
272
+ bigim = None
273
+ base_size = [256, 256]
274
+ while futures:
275
+ done, not_done = concurrent.futures.wait(
276
+ futures.keys(), timeout=callback_interval,
277
+ return_when=concurrent.futures.FIRST_COMPLETED
278
+ )
279
+ cur = None
280
+ if mbtiles:
281
+ cur = db.cursor()
282
+ cur.execute("BEGIN")
283
+ for fut in done:
284
+ img_data = fut.result()
285
+ xy = futures[fut]
286
+ if save_image:
287
+ bigim = paste_tile(bigim, base_size, img_data, xy, bbox)
288
+ if mbtiles:
289
+ new_format = mbtiles_save(db, img_data, xy, zoom, mbt_img_format)
290
+ if not mbt_img_format:
291
+ cur.execute(
292
+ "UPDATE metadata SET value=? WHERE name='format'",
293
+ (new_format,))
294
+ mbt_img_format = new_format
295
+ del futures[fut]
296
+ done_num += 1
297
+ if mbtiles:
298
+ cur.execute("COMMIT")
299
+ if time.monotonic() > last_callback + callback_interval:
300
+ try:
301
+ progress_callback(done_num, totalnum, (done_num > last_done_num))
302
+ except TaskCancelled:
303
+ for fut in futures.keys():
304
+ fut.cancel()
305
+ futures.clear()
306
+ cancelled = True
307
+ break
308
+ last_callback = time.monotonic()
309
+ last_done_num = done_num
310
+ if cancelled:
311
+ raise TaskCancelled()
312
+ progress_callback(done_num, totalnum, True)
313
+
314
+ if not save_image:
315
+ return None, None
316
+
317
+ xfrac = x0 - bbox[0]
318
+ yfrac = y0 - bbox[1]
319
+ x2 = round(base_size[0]*xfrac)
320
+ y2 = round(base_size[1]*yfrac)
321
+ imgw = round(base_size[0]*(x1-x0))
322
+ imgh = round(base_size[1]*(y1-y0))
323
+ retim = bigim.crop((x2, y2, x2+imgw, y2+imgh))
324
+ if retim.mode == 'RGBA' and retim.getextrema()[3] == (255, 255):
325
+ retim = retim.convert('RGB')
326
+ bigim.close()
327
+ xp0, yp0 = from4326_to3857(lat0, lon0)
328
+ xp1, yp1 = from4326_to3857(lat1, lon1)
329
+ pwidth = abs(xp1 - xp0) / retim.size[0]
330
+ pheight = abs(yp1 - yp0) / retim.size[1]
331
+ matrix = (min(xp0, xp1), pwidth, 0, max(yp0, yp1), 0, -pheight)
332
+ return retim, matrix
333
+
334
+
335
+ def generate_tiffinfo(matrix):
336
+ ifd = TiffImagePlugin.ImageFileDirectory_v2()
337
+ # GeoKeyDirectoryTag
338
+ gkdt = [
339
+ 1, 1,
340
+ 0, # GeoTIFF 1.0
341
+ 0, # NumberOfKeys
342
+ ]
343
+ # KeyID, TIFFTagLocation, KeyCount, ValueOffset
344
+ geokeys = [
345
+ # GTModelTypeGeoKey
346
+ (1024, 0, 1, 1), # 2D projected coordinate reference system
347
+ # GTRasterTypeGeoKey
348
+ (1025, 0, 1, 1), # PixelIsArea
349
+ # GTCitationGeoKey
350
+ (1026, 34737, 25, 0),
351
+ # GeodeticCitationGeoKey
352
+ (2049, 34737, 7, 25),
353
+ # GeogAngularUnitsGeoKey
354
+ (2054, 0, 1, 9102), # degree
355
+ # ProjectedCRSGeoKey
356
+ (3072, 0, 1, 3857),
357
+ # ProjLinearUnitsGeoKey
358
+ (3076, 0, 1, 9001), # metre
359
+ ]
360
+ gkdt[3] = len(geokeys)
361
+ ifd.tagtype[34735] = 3 # short
362
+ ifd[34735] = tuple(itertools.chain(gkdt, *geokeys))
363
+ # GeoDoubleParamsTag
364
+ ifd.tagtype[34736] = 12 # double
365
+ # GeoAsciiParamsTag
366
+ ifd.tagtype[34737] = 1 # byte
367
+ ifd[34737] = b'WGS 84 / Pseudo-Mercator|WGS 84|\x00'
368
+ a, b, c, d, e, f = matrix
369
+ # ModelPixelScaleTag
370
+ ifd.tagtype[33550] = 12 # double
371
+ # ModelTiepointTag
372
+ ifd.tagtype[33922] = 12 # double
373
+ # ModelTransformationTag
374
+ ifd.tagtype[34264] = 12 # double
375
+ # This matrix tag should not be used
376
+ # if the ModelTiepointTag and the ModelPixelScaleTag are already defined
377
+ if c == 0 and e == 0:
378
+ ifd[33550] = (b, -f, 0.0)
379
+ ifd[33922] = (0.0, 0.0, 0.0, a, d, 0.0)
380
+ else:
381
+ ifd[34264] = (
382
+ b, c, 0.0, a,
383
+ e, f, 0.0, d,
384
+ 0.0, 0.0, 0.0, 0.0,
385
+ 0.0, 0.0, 0.0, 1.0
386
+ )
387
+ return ifd
388
+
389
+
390
+ def img_memorysize(img):
391
+ return img.size[0] * img.size[1] * len(img.getbands())
392
+
393
+
394
+ def save_image_fn(img, filename, matrix, **params):
395
+ wld_ext = {
396
+ '.gif': '.gfw',
397
+ '.jpg': '.jgw',
398
+ '.jpeg': '.jgw',
399
+ '.jp2': '.j2w',
400
+ '.png': '.pgw',
401
+ '.tif': '.tfw',
402
+ '.tiff': '.tfw',
403
+ }
404
+ basename, ext = os.path.splitext(filename)
405
+ ext = ext.lower()
406
+ wld_name = basename + wld_ext.get(ext, '.wld')
407
+ img_params = params.copy()
408
+ if ext == '.jpg':
409
+ img_params['quality'] = 92
410
+ img_params['optimize'] = True
411
+ elif ext == '.png':
412
+ img_params['optimize'] = True
413
+ elif ext.startswith('.tif'):
414
+ if img_memorysize(img) >= 4*1024*1024*1024:
415
+ # BigTIFF
416
+ return save_geotiff_gdal(img, filename, matrix)
417
+ img_params['compression'] = 'tiff_adobe_deflate'
418
+ img_params['tiffinfo'] = generate_tiffinfo(matrix)
419
+ img.save(filename, **img_params)
420
+ if not ext.startswith('.tif'):
421
+ with open(wld_name, 'w', encoding='utf-8') as f_wld:
422
+ a, b, c, d, e, f = matrix
423
+ f_wld.write('\n'.join(map(str, (b, e, c, f, a, d, ''))))
424
+ return img
425
+
426
+
427
+ def save_geotiff_gdal(img, filename, matrix):
428
+ if 'GDAL_DATA' in os.environ:
429
+ del os.environ['GDAL_DATA']
430
+ if 'PROJ_LIB' in os.environ:
431
+ del os.environ['PROJ_LIB']
432
+
433
+ import numpy
434
+ from osgeo import gdal
435
+ gdal.UseExceptions()
436
+
437
+ imgbands = len(img.getbands())
438
+ driver = gdal.GetDriverByName('GTiff')
439
+ gdal_options = ['COMPRESS=DEFLATE', 'PREDICTOR=2', 'ZLEVEL=9', 'TILED=YES']
440
+ if img_memorysize(img) >= 4*1024*1024*1024:
441
+ gdal_options.append('BIGTIFF=YES')
442
+ if img_memorysize(img) >= 50*1024*1024:
443
+ gdal_options.append('NUM_THREADS=%d' % max(1, os.cpu_count()))
444
+
445
+ gtiff = driver.Create(filename, img.size[0], img.size[1],
446
+ imgbands, gdal.GDT_Byte,
447
+ options=gdal_options)
448
+ gtiff.SetGeoTransform(matrix)
449
+ gtiff.SetProjection(WKT_3857)
450
+ for band in range(imgbands):
451
+ array = numpy.array(img.getdata(band), dtype='u8')
452
+ array = array.reshape((img.size[1], img.size[0]))
453
+ band = gtiff.GetRasterBand(band + 1)
454
+ band.WriteArray(array)
455
+ gtiff.FlushCache()
456
+ return img
457
+
458
+
459
+ def save_image_auto(img, filename, matrix, use_gdal=False, **params):
460
+ ext = os.path.splitext(filename)[1].lower()
461
+ if ext in ('.tif', '.tiff') and use_gdal:
462
+ return save_geotiff_gdal(img, filename, matrix)
463
+ else:
464
+ return save_image_fn(img, filename, matrix, **params)
465
+
466
+
467
+ class TaskCancelled(RuntimeError):
468
+ pass
469
+
470
+
471
+ def parse_extent(s):
472
+ try:
473
+ coords_text = re_coords_split.split(s)
474
+ return (float(coords_text[1]), float(coords_text[0]),
475
+ float(coords_text[3]), float(coords_text[2]))
476
+ except (IndexError, ValueError):
477
+ raise ValueError("Invalid extent, should be: min_lon,min_lat,max_lon,max_lat")
478
+
479
+
480
+ def gui():
481
+ import tkinter as tk
482
+ import tkinter.ttk as ttk
483
+ import tkinter.messagebox
484
+
485
+ root_tk = tk.Tk()
486
+
487
+ def cmd_get_save_file():
488
+ result = root_tk.tk.eval("""tk_getSaveFile -filetypes {
489
+ {{GeoTIFF} {.tiff}}
490
+ {{JPG} {.jpg}}
491
+ {{PNG} {.png}}
492
+ {{All Files} *}
493
+ } -defaultextension .tiff""")
494
+ if result:
495
+ v_output.set(result)
496
+
497
+ def cmd_get_save_mbtiles():
498
+ result = root_tk.tk.eval("""tk_getSaveFile -filetypes {
499
+ {{MBTiles} {.mbtiles}}
500
+ {{All Files} *}
501
+ } -defaultextension .tiff""")
502
+ if result:
503
+ v_mbtiles.set(result)
504
+
505
+ frame = ttk.Frame(root_tk, padding=8)
506
+ frame.grid(column=0, row=0, sticky='nsew')
507
+ frame.master.title('Download TMS image')
508
+ frame.master.resizable(0, 0)
509
+ l_url = ttk.Label(frame, width=50, text="URL: (with {x}, {y}, {z})")
510
+ l_url.grid(column=0, row=0, columnspan=3, sticky='w', pady=(0, 2))
511
+ v_url = tk.StringVar()
512
+ e_url = ttk.Entry(frame, textvariable=v_url)
513
+ e_url.grid(column=0, row=1, columnspan=3, sticky='we', pady=(0, 5))
514
+ l_extent = ttk.Label(frame, text="Extent: (min_lon,min_lat,max_lon,max_lat)")
515
+ l_extent.grid(column=0, row=2, columnspan=3, sticky='w', pady=(0, 2))
516
+ v_extent = tk.StringVar()
517
+ e_extent = ttk.Entry(frame, width=50, textvariable=v_extent)
518
+ e_extent.grid(column=0, row=3, columnspan=3, sticky='we', pady=(0, 5))
519
+ l_zoom = ttk.Label(frame, width=5, text="Zoom:")
520
+ l_zoom.grid(column=0, row=4, sticky='w')
521
+ v_zoom = tk.StringVar()
522
+ v_zoom.set('13')
523
+ e_zoom = ttk.Spinbox(frame, width=10, textvariable=v_zoom, **{
524
+ 'from': 1, 'to': 19, 'increment': 1
525
+ })
526
+ e_zoom.grid(column=1, row=4, sticky='w')
527
+ l_output = ttk.Label(frame, width=10, text="Output:")
528
+ l_output.grid(column=0, row=5, sticky='w')
529
+ v_output = tk.StringVar()
530
+ e_output = ttk.Entry(frame, width=30, textvariable=v_output)
531
+ e_output.grid(column=1, row=5, sticky='we')
532
+ b_output = ttk.Button(frame, text='...', width=3, command=cmd_get_save_file)
533
+ b_output.grid(column=2, row=5, sticky='we')
534
+ l_mbtiles = ttk.Label(frame, width=10, text="MBTiles:")
535
+ l_mbtiles.grid(column=0, row=6, sticky='w')
536
+ v_mbtiles = tk.StringVar()
537
+ e_mbtiles = ttk.Entry(frame, width=30, textvariable=v_mbtiles)
538
+ e_mbtiles.grid(column=1, row=6, sticky='we')
539
+ b_mbtiles = ttk.Button(frame, text='...', width=3, command=cmd_get_save_mbtiles)
540
+ b_mbtiles.grid(column=2, row=6, sticky='we')
541
+ p_progress = ttk.Progressbar(frame, mode='determinate')
542
+ p_progress.grid(column=0, row=7, columnspan=3, sticky='we', pady=(5, 2))
543
+
544
+ started = False
545
+ stop_download = False
546
+
547
+ def reset():
548
+ b_download.configure(
549
+ text='Download', state='normal', command=cmd_download)
550
+ root_tk.update()
551
+
552
+ def update_progress(progress, total, done):
553
+ nonlocal started, stop_download
554
+ if not started:
555
+ if done:
556
+ p_progress.configure(maximum=total, value=progress)
557
+ else:
558
+ p_progress.configure(maximum=total)
559
+ started = True
560
+ elif done:
561
+ p_progress.configure(value=progress)
562
+ root_tk.update()
563
+ if stop_download:
564
+ raise TaskCancelled()
565
+
566
+ def cmd_download():
567
+ nonlocal started, stop_download
568
+ started = False
569
+ stop_download = False
570
+ b_download.configure(text='Cancel', command=cmd_cancel)
571
+ root_tk.update()
572
+ try:
573
+ url = v_url.get().strip()
574
+ args = [url]
575
+ args.extend(parse_extent(v_extent.get()))
576
+ args.append(int(v_zoom.get()))
577
+ filename = v_output.get()
578
+ mbtiles = v_mbtiles.get()
579
+ kwargs = {'mbtiles': mbtiles, 'save_image': bool(filename)}
580
+ if not all(args) or not any((filename, mbtiles)):
581
+ raise ValueError("Empty input")
582
+ except (TypeError, ValueError, IndexError) as ex:
583
+ reset()
584
+ tkinter.messagebox.showerror(
585
+ title='tms2geotiff',
586
+ message="Invalid input: %s: %s" % (type(ex).__name__, ex),
587
+ master=frame
588
+ )
589
+ return
590
+ root_tk.update()
591
+ try:
592
+ img, matrix = download_extent(
593
+ *args, progress_callback=update_progress, **kwargs)
594
+ b_download.configure(text='Saving...', state='disabled')
595
+ root_tk.update()
596
+ if filename:
597
+ save_image_auto(img, filename, matrix)
598
+ reset()
599
+ except TaskCancelled:
600
+ reset()
601
+ tkinter.messagebox.showwarning(
602
+ title='tms2geotiff',
603
+ message="Download cancelled.",
604
+ master=frame
605
+ )
606
+ return
607
+ except Exception as ex:
608
+ reset()
609
+ tkinter.messagebox.showerror(
610
+ title='tms2geotiff',
611
+ message="%s: %s" % (type(ex).__name__, ex),
612
+ master=frame
613
+ )
614
+ return
615
+ tkinter.messagebox.showinfo(
616
+ title='tms2geotiff',
617
+ message="Download complete.",
618
+ master=frame
619
+ )
620
+
621
+ def cmd_cancel():
622
+ nonlocal started, stop_download
623
+ started = False
624
+ stop_download = True
625
+ reset()
626
+
627
+ b_download = ttk.Button(
628
+ width=15, text='Download', default='active', command=cmd_download)
629
+ b_download.grid(column=0, row=6, columnspan=3, pady=2)
630
+
631
+ root_tk.mainloop()
632
+
633
+
634
+ def downloader(input_args, input_parser):
635
+
636
+ download_args = [input_args.source]
637
+ try:
638
+ if input_args.extent:
639
+ download_args.extend(parse_extent(input_args.extent))
640
+ else:
641
+ coords0 = tuple(map(float, getattr(input_args, 'from').split(',')))
642
+ print("coords0:", coords0, "#")
643
+ coords1 = tuple(map(float, getattr(input_args, 'to').split(',')))
644
+ print("coords1:", coords1, "#")
645
+ download_args.extend((coords0[0], coords0[1], coords1[0], coords1[1]))
646
+ except Exception as e:
647
+ print(f"e:", e, "#")
648
+ input_parser.print_help()
649
+ return 1
650
+ download_args.append(input_args.zoom)
651
+ download_args.append(input_args.mbtiles)
652
+ download_args.append(bool(input_args.output))
653
+ progress_bar = ProgressBar()
654
+ download_args.append(progress_bar.print_progress)
655
+ img, matrix = download_extent(*download_args)
656
+ progress_bar.close()
657
+ if input_args.output:
658
+ print(f"Saving image to {input_args.output}.")
659
+ save_image_auto(img, input_args.output, matrix)
660
+ return 0
661
+
662
+
663
+ def main():
664
+ parser = argparse.ArgumentParser(
665
+ description="Merge TMS tiles to a big image.",
666
+ epilog="If no parameters are specified, it will open the GUI.")
667
+ parser.add_argument(
668
+ "-s", "--source", metavar='URL', default=DEFAULT_TMS,
669
+ help="TMS server url (default is OpenStreetMap: %s)" % DEFAULT_TMS)
670
+ parser.add_argument("-f", "--from", metavar='LAT,LON', help="one corner")
671
+ parser.add_argument("-t", "--to", metavar='LAT,LON', help="the other corner")
672
+ parser.add_argument("-e", "--extent",
673
+ metavar='min_lon,min_lat,max_lon,max_lat',
674
+ help="extent in one string (use either -e, or -f and -t)")
675
+ parser.add_argument("-z", "--zoom", type=int, help="zoom level")
676
+ parser.add_argument("-m", "--mbtiles", help="save MBTiles file")
677
+ parser.add_argument("-g", "--gui", action='store_true', help="show GUI")
678
+ parser.add_argument("output", nargs='?', help="output image file (can be omitted)")
679
+ args = parser.parse_args()
680
+ if args.gui or not getattr(args, 'zoom', None):
681
+ gui()
682
+ # parser.print_help()
683
+ return 1
684
+
685
+ downloader(args, parser)
686
+
687
+
688
+ if __name__ == '__main__':
689
+ # import sys
690
+ # sys.exit(main())
691
+ pt0 = 45.699, 127.1
692
+ pt1 = 30.1, 148.492
693
+ img_output_filename = "/Users/trincuz/workspace/segment-geospatial/tmp/japan_out_main.png"
694
+ img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], 6)
695
+
696
+ print(f"Saving image to {img_output_filename}.")
697
+ save_image_auto(img, img_output_filename, matrix)
src/utilities/constants.py CHANGED
@@ -25,3 +25,7 @@ CUSTOM_RESPONSE_MESSAGES = {
25
  MODEL_NAME = "FastSAM-s.pt"
26
  ZOOM = 13
27
  SOURCE_TYPE = "Satellite"
 
 
 
 
 
25
  MODEL_NAME = "FastSAM-s.pt"
26
  ZOOM = 13
27
  SOURCE_TYPE = "Satellite"
28
+
29
+ EARTH_EQUATORIAL_RADIUS = 6378137.0
30
+ DEFAULT_TMS = 'https://tile.openstreetmap.org/{z}/{x}/{y}.png'
31
+ WKT_3857 = 'PROJCS["WGS 84 / Pseudo-Mercator",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Mercator_1SP"],PARAMETER["central_meridian",0],PARAMETER["scale_factor",1],PARAMETER["false_easting",0],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["X",EAST],AXIS["Y",NORTH],EXTENSION["PROJ4","+proj=merc +a=6378137 +b=6378137 +lat_ts=0.0 +lon_0=0.0 +x_0=0.0 +y_0=0 +k=1.0 +units=m +nadgrids=@null +wktext +no_defs"],AUTHORITY["EPSG","3857"]]'