aletrn commited on
Commit
142642e
·
1 Parent(s): 5a246f7

[test] add test cases, add function utility base64_encode()

Browse files
src/utilities/utilities.py CHANGED
@@ -1,17 +1,20 @@
1
  """Various utilities (logger, time benchmark, args dump, numerical and stats info)"""
2
 
3
 
4
- def is_base64(sb):
 
 
 
 
 
 
 
 
 
5
  import base64
6
 
7
  try:
8
- if isinstance(sb, str):
9
- # If there's any unicode here, an exception will be thrown and the function will return false
10
- sb_bytes = bytes(sb, 'ascii')
11
- elif isinstance(sb, bytes):
12
- sb_bytes = sb
13
- else:
14
- raise ValueError("Argument must be string or bytes")
15
  return base64.b64encode(base64.b64decode(sb_bytes, validate=True)) == sb_bytes
16
  except ValueError:
17
  return False
@@ -24,3 +27,10 @@ def base64_decode(s):
24
  return base64.b64decode(s, validate=True).decode("utf-8")
25
 
26
  return s
 
 
 
 
 
 
 
 
1
  """Various utilities (logger, time benchmark, args dump, numerical and stats info)"""
2
 
3
 
4
+ def prepare_base64_input(sb):
5
+ if isinstance(sb, str):
6
+ # If there's any unicode here, an exception will be thrown and the function will return false
7
+ return bytes(sb, 'ascii')
8
+ elif isinstance(sb, bytes):
9
+ return sb
10
+ raise ValueError("Argument must be string or bytes")
11
+
12
+
13
+ def is_base64(sb: str or bytes):
14
  import base64
15
 
16
  try:
17
+ sb_bytes = prepare_base64_input(sb)
 
 
 
 
 
 
18
  return base64.b64encode(base64.b64decode(sb_bytes, validate=True)) == sb_bytes
19
  except ValueError:
20
  return False
 
27
  return base64.b64decode(s, validate=True).decode("utf-8")
28
 
29
  return s
30
+
31
+
32
+ def base64_encode(sb: str or bytes):
33
+ import base64
34
+
35
+ sb_bytes = prepare_base64_input(sb)
36
+ return base64.b64encode(sb_bytes)
tests/io/test_lambda_helpers.py CHANGED
@@ -1,6 +1,8 @@
1
  import json
2
 
3
  from src.io.lambda_helpers import get_parsed_bbox_points, get_parsed_request_body
 
 
4
  from tests import TEST_EVENTS_FOLDER
5
 
6
 
@@ -12,3 +14,34 @@ def test_get_parsed_bbox_points():
12
  raw_body = get_parsed_request_body(**input_output["input"])
13
  output = get_parsed_bbox_points(raw_body)
14
  assert output == input_output["output"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
 
3
  from src.io.lambda_helpers import get_parsed_bbox_points, get_parsed_request_body
4
+ from src.utilities.type_hints import RawRequestInput
5
+ from src.utilities.utilities import base64_encode
6
  from tests import TEST_EVENTS_FOLDER
7
 
8
 
 
14
  raw_body = get_parsed_request_body(**input_output["input"])
15
  output = get_parsed_bbox_points(raw_body)
16
  assert output == input_output["output"]
17
+
18
+
19
+ def test_get_parsed_request_body():
20
+ input_event = {
21
+ "event": {
22
+ "bbox": {
23
+ "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
24
+ "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068}
25
+ },
26
+ "prompt": [{"type": "point", "data": {"lat": 37.0, "lng": 15.0}, "label": 0}],
27
+ "zoom": 10, "source_type": "Satellite", "debug": True
28
+ }
29
+ }
30
+ expected_output_dict = {
31
+ "bbox": {
32
+ "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
33
+ "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068}
34
+ },
35
+ "prompt": [{"type": "point", "data": {"lat": 37.0, "lng": 15.0}, "label": 0}],
36
+ "zoom": 10, "source_type": "Satellite", "debug": True
37
+ }
38
+ output = get_parsed_request_body(input_event["event"])
39
+ assert output == RawRequestInput.model_validate(input_event["event"])
40
+
41
+ input_event_str = json.dumps(input_event["event"])
42
+ output = get_parsed_request_body(input_event_str)
43
+ assert output == RawRequestInput.model_validate(expected_output_dict)
44
+
45
+ event = {"body": base64_encode(input_event_str).decode("utf-8")}
46
+ output = get_parsed_request_body(event)
47
+ assert output == RawRequestInput.model_validate(expected_output_dict)