Samhita committed on
Commit
3cc543c
1 Parent(s): 889cf23

add gantry code

Browse files

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

app.py CHANGED
@@ -1,7 +1,7 @@
1
  import json
2
  import mimetypes
3
  import os
4
- from typing import Tuple
5
 
6
  import gradio as gr
7
  import pandas as pd
@@ -9,10 +9,17 @@ import plotly
9
  import plotly.express as px
10
  import requests
11
  from dotenv import load_dotenv
 
 
12
 
13
  load_dotenv()
14
 
15
  URL = os.getenv("ENDPOINT")
 
 
 
 
 
16
 
17
 
18
  def get_plotly_graph(
@@ -21,7 +28,7 @@ def get_plotly_graph(
21
  lat_long_data = [[latitude, longitude, location]]
22
  map_df = pd.DataFrame(lat_long_data, columns=["latitude", "longitude", "location"])
23
 
24
- px.set_mapbox_access_token(os.getenv("MAPBOX_TOKEN"))
25
  fig = px.scatter_mapbox(
26
  map_df,
27
  lat="latitude",
@@ -36,7 +43,29 @@ def get_plotly_graph(
36
  return fig
37
 
38
 
39
- def image_gradio(img_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  data = json.loads(
41
  requests.post(
42
  f"{URL}predict-image",
@@ -50,13 +79,10 @@ def image_gradio(img_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
50
  ).text
51
  )
52
 
53
- location = data["location"]
54
- return data["location"], get_plotly_graph(
55
- latitude=data["latitude"], longitude=data["longitude"], location=location
56
- )
57
 
58
 
59
- def video_gradio(video_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
60
  data = json.loads(
61
  requests.post(
62
  f"{URL}predict-video",
@@ -70,13 +96,10 @@ def video_gradio(video_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
70
  ).text
71
  )
72
 
73
- location = data["location"]
74
- return location, get_plotly_graph(
75
- latitude=data["latitude"], longitude=data["longitude"], location=location
76
- )
77
 
78
 
79
- def url_gradio(url: str) -> Tuple[str, plotly.graph_objects.Figure]:
80
  data = json.loads(
81
  requests.post(
82
  f"{URL}predict-url",
@@ -85,32 +108,31 @@ def url_gradio(url: str) -> Tuple[str, plotly.graph_objects.Figure]:
85
  ).text
86
  )
87
 
88
- location = data["location"]
89
- return location, get_plotly_graph(
90
- latitude=data["latitude"], longitude=data["longitude"], location=location
91
- )
92
 
93
 
94
  with gr.Blocks() as demo:
95
  gr.Markdown("# GeoLocator")
96
  gr.Markdown(
97
- "## An app that guesses the location of an image 🌌, a video 📹 or a YouTube link 🔗."
98
- )
99
- gr.Markdown(
100
- "Find the code powering this application [here](https://github.com/samhita-alla/geolocator)."
101
  )
102
  with gr.Tab("Image"):
103
  with gr.Row():
104
- img_input = gr.Image(type="filepath", label="im")
105
  with gr.Column():
106
  img_text_output = gr.Textbox(label="Location")
 
107
  img_plot = gr.Plot()
108
  img_text_button = gr.Button("Go locate!")
 
 
 
109
  with gr.Tab("Video"):
110
  with gr.Row():
111
- video_input = gr.Video(type="filepath", label="video")
112
  with gr.Column():
113
  video_text_output = gr.Textbox(label="Location")
 
114
  video_plot = gr.Plot()
115
  video_text_button = gr.Button("Go locate!")
116
  with gr.Tab("YouTube Link"):
@@ -118,19 +140,46 @@ with gr.Blocks() as demo:
118
  url_input = gr.Textbox(label="YouTube video link")
119
  with gr.Column():
120
  url_text_output = gr.Textbox(label="Location")
 
121
  url_plot = gr.Plot()
122
  url_text_button = gr.Button("Go locate!")
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  img_text_button.click(
125
- image_gradio, inputs=img_input, outputs=[img_text_output, img_plot]
 
 
126
  )
127
  video_text_button.click(
128
- video_gradio, inputs=video_input, outputs=[video_text_output, video_plot]
 
 
129
  )
130
  url_text_button.click(
131
- url_gradio, inputs=url_input, outputs=[url_text_output, url_plot]
 
 
132
  )
133
 
134
- examples = gr.Examples(".", inputs=[img_input, url_input])
 
 
 
 
135
 
136
  demo.launch()
 
1
  import json
2
  import mimetypes
3
  import os
4
+ from typing import Dict, Tuple, Union
5
 
6
  import gradio as gr
7
  import pandas as pd
 
9
  import plotly.express as px
10
  import requests
11
  from dotenv import load_dotenv
12
+ from gantry_callback.gantry_util import GantryImageToTextLogger
13
+ from gantry_callback.s3_util import make_unique_bucket_name
14
 
15
  load_dotenv()
16
 
17
  URL = os.getenv("ENDPOINT")
18
+ GANTRY_APP_NAME = os.getenv("GANTRY_APP_NAME")
19
+ GANTRY_KEY = os.getenv("GANTRY_API_KEY")
20
+ AWS_KEY = os.getenv("AWS_KEY")
21
+ AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
22
+ MAPBOX_TOKEN = os.getenv("MAPBOX_TOKEN")
23
 
24
 
25
  def get_plotly_graph(
 
28
  lat_long_data = [[latitude, longitude, location]]
29
  map_df = pd.DataFrame(lat_long_data, columns=["latitude", "longitude", "location"])
30
 
31
+ px.set_mapbox_access_token(MAPBOX_TOKEN)
32
  fig = px.scatter_mapbox(
33
  map_df,
34
  lat="latitude",
 
43
  return fig
44
 
45
 
46
def gradio_error():
    """Abort the current Gradio event with a user-visible error.

    Raises:
        gr.Error: Always, with the fixed "Unable to detect the location!"
            message.
    """
    message = "Unable to detect the location!"
    raise gr.Error(message)
48
+
49
+
50
def get_outputs(
    data: Dict[str, Union[str, float, None]]
) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Convert a decoded prediction payload into the values shown in the UI.

    Args:
        data: Mapping read for the ``location``, ``latitude`` and
            ``longitude`` keys.

    Returns:
        A 3-tuple of the location, the coordinates formatted as
        ``"<latitude>,<longitude>"``, and the figure produced by
        ``get_plotly_graph`` for that point.

    Raises:
        gr.Error: Through ``gradio_error`` when the payload carries no
            location (key absent or value ``None``).
    """
    # Use .get() so a payload lacking the keys surfaces as the friendly
    # Gradio error below instead of an unhandled KeyError.
    location = data.get("location")
    latitude = data.get("latitude")
    longitude = data.get("longitude")

    if location is None:
        gradio_error()

    return (
        location,
        f"{latitude},{longitude}",
        get_plotly_graph(latitude=latitude, longitude=longitude, location=location),
    )
66
+
67
+
68
+ def image_gradio(img_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
69
  data = json.loads(
70
  requests.post(
71
  f"{URL}predict-image",
 
79
  ).text
80
  )
81
 
82
+ return get_outputs(data=data)
 
 
 
83
 
84
 
85
+ def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
86
  data = json.loads(
87
  requests.post(
88
  f"{URL}predict-video",
 
96
  ).text
97
  )
98
 
99
+ return get_outputs(data=data)
 
 
 
100
 
101
 
102
+ def url_gradio(url: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
103
  data = json.loads(
104
  requests.post(
105
  f"{URL}predict-url",
 
108
  ).text
109
  )
110
 
111
+ return get_outputs(data=data)
 
 
 
112
 
113
 
114
  with gr.Blocks() as demo:
115
  gr.Markdown("# GeoLocator")
116
  gr.Markdown(
117
+ "### An app that guesses the location of an image 🌌, a video 📹 or a YouTube link 🔗."
 
 
 
118
  )
119
  with gr.Tab("Image"):
120
  with gr.Row():
121
+ img_input = gr.Image(type="filepath", label="Image")
122
  with gr.Column():
123
  img_text_output = gr.Textbox(label="Location")
124
+ img_coordinates = gr.Textbox(label="Coordinates")
125
  img_plot = gr.Plot()
126
  img_text_button = gr.Button("Go locate!")
127
+ with gr.Row():
128
+ # Flag button
129
+ img_flag_button = gr.Button("Flag this output")
130
  with gr.Tab("Video"):
131
  with gr.Row():
132
+ video_input = gr.Video(type="filepath", label="Video")
133
  with gr.Column():
134
  video_text_output = gr.Textbox(label="Location")
135
+ video_coordinates = gr.Textbox(label="Coordinates")
136
  video_plot = gr.Plot()
137
  video_text_button = gr.Button("Go locate!")
138
  with gr.Tab("YouTube Link"):
 
140
  url_input = gr.Textbox(label="YouTube video link")
141
  with gr.Column():
142
  url_text_output = gr.Textbox(label="Location")
143
+ url_coordinates = gr.Textbox(label="Coordinates")
144
  url_plot = gr.Plot()
145
  url_text_button = gr.Button("Go locate!")
146
 
147
+ # Gantry flagging for image #
148
+ callback = GantryImageToTextLogger(application=GANTRY_APP_NAME, api_key=GANTRY_KEY)
149
+
150
+ callback.setup(
151
+ components=[img_input, img_text_output],
152
+ flagging_dir=make_unique_bucket_name(prefix=GANTRY_APP_NAME, seed="420"),
153
+ )
154
+
155
+ img_flag_button.click(
156
+ fn=lambda *args: callback.flag(args),
157
+ inputs=[img_input, img_text_output, img_coordinates],
158
+ outputs=None,
159
+ preprocess=False,
160
+ )
161
+ ###################
162
+
163
  img_text_button.click(
164
+ image_gradio,
165
+ inputs=img_input,
166
+ outputs=[img_text_output, img_coordinates, img_plot],
167
  )
168
  video_text_button.click(
169
+ video_gradio,
170
+ inputs=video_input,
171
+ outputs=[video_text_output, video_coordinates, video_plot],
172
  )
173
  url_text_button.click(
174
+ url_gradio,
175
+ inputs=url_input,
176
+ outputs=[url_text_output, url_coordinates, url_plot],
177
  )
178
 
179
+ examples = gr.Examples(".", inputs=[img_input, video_input, url_input])
180
+
181
+ gr.Markdown(
182
+ "Check out the [GitHub repository](https://github.com/samhita-alla/geolocator) that this demo is based off of."
183
+ )
184
 
185
  demo.launch()
gantry_callback/__init__.py ADDED
File without changes
gantry_callback/gantry_util.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Class to handle flagging in Gradio to Gantry.
3
+
4
+ Originally written by the FSDL educators at https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/main/app_gradio/flagging.py
5
+ that has been adjusted for the geolocator project.
6
+ """
7
+
8
+ import os
9
+ from typing import List, Optional, Union
10
+
11
+ import gantry
12
+ import gradio as gr
13
+ from gradio.components import Component
14
+ from smart_open import open
15
+
16
+ from .s3_util import (
17
+ add_access_policy,
18
+ enable_bucket_versioning,
19
+ get_or_create_bucket,
20
+ get_uri_of,
21
+ make_key,
22
+ )
23
+ from .string_img_util import read_b64_string
24
+
25
+
26
class GantryImageToTextLogger(gr.FlaggingCallback):
    """
    A FlaggingCallback that logs flagged image-to-text data to Gantry via S3.
    """

    def __init__(
        self,
        application: str,
        version: Union[int, str, None] = None,
        api_key: Optional[str] = None,
    ):
        """Logs image-to-text data that was flagged in Gradio to Gantry.

        Images are logged to Amazon Web Services' Simple Storage Service (S3).

        The flagging_dir provided to the Gradio interface is used to set the
        name of the bucket on S3 into which images are logged.

        See the following tutorial by Dan Bader for a quick overview of S3 and the AWS SDK
        for Python, boto3: https://realpython.com/python-boto3-aws-s3/

        See https://gradio.app/docs/#flagging for details on how
        flagging data is handled by Gradio.

        See https://docs.gantry.io for information about logging data to Gantry.

        Parameters
        ----------
        application
            The name of the application on Gantry to which flagged data should be uploaded.
            Gantry validates and monitors data per application.
        version
            The schema version to use during validation by Gantry. If not provided, Gantry
            will use the latest version. A new version will be created if the provided version
            does not exist yet.
        api_key
            Optionally, provide your Gantry API key here. Provided for convenience
            when testing and developing locally or in notebooks. The API key can
            alternatively be provided via the GANTRY_API_KEY environment variable.
        """
        self.application = application
        self.version = version
        gantry.init(api_key=api_key)

    def setup(self, components: List[Component], flagging_dir: str):
        """Sets up the logger by creating or attaching to an S3 bucket.

        Parameters
        ----------
        components
            The Gradio components being flagged; forwarded to the index lookup helper.
        flagging_dir
            Used as the name of the S3 bucket that receives the flagged images.
        """
        self._counter = 0  # number of records flagged since setup
        self.bucket = get_or_create_bucket(flagging_dir)
        enable_bucket_versioning(self.bucket)
        add_access_policy(self.bucket)
        (
            self.image_component_idx,
            self.text_component_idx,
            self.text_component2_idx,
        ) = self._find_image_video_and_text_components(components)

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
        """Sends flagged outputs and feedback to Gantry and image inputs to S3.

        Parameters
        ----------
        flag_data
            Sequence indexed by the positions stored in ``setup``: the image as a
            base64 data string, the predicted location and the predicted coordinates.
        flag_option
            Stored under the ``"flag"`` key of the feedback sent to Gantry.
        flag_index
            Accepted for interface compatibility; not used here.
        username
            When given, stored under the ``"user"`` key of the feedback.

        Returns
        -------
        int
            The number of records flagged since ``setup`` was called.

        Raises
        ------
        gr.Error
            If there is no image to upload.
        """
        image = flag_data[self.image_component_idx]
        text = flag_data[self.text_component_idx]
        text2 = flag_data[self.text_component2_idx]

        # Without this guard a missing image fails deep inside the base64
        # parser with an opaque AttributeError.
        if not image:
            raise gr.Error("Please provide an image before flagging.")

        feedback = {"flag": flag_option}
        if username is not None:
            feedback["user"] = username

        data_type, image_buffer = read_b64_string(image, return_data_type=True)
        image_url = self._to_s3(image_buffer.read(), filetype=data_type)

        self._to_gantry(
            input_image_url=image_url,
            pred_location=text,
            pred_coordinates=text2,
            feedback=feedback,
        )
        self._counter += 1

        return self._counter

    def _to_gantry(self, input_image_url, pred_location, pred_coordinates, feedback):
        """Logs one record (image URI in, location and coordinates out) to Gantry."""
        inputs = {"image": input_image_url}
        outputs = {"location": pred_location, "coordinates": pred_coordinates}

        gantry.log_record(
            self.application,
            self.version,
            inputs=inputs,
            outputs=outputs,
            feedback=feedback,
        )

    def _to_s3(self, image_bytes, key=None, filetype=None):
        """Writes the image bytes to the bucket and returns the object's s3:// URI.

        When no key is given, one is derived from the bytes via ``make_key``.
        """
        if key is None:
            key = make_key(image_bytes, filetype=filetype)

        s3_uri = get_uri_of(self.bucket, key)

        # `open` is smart_open's open (imported at module top), not the builtin.
        with open(s3_uri, "wb") as s3_object:
            s3_object.write(image_bytes)

        return s3_uri

    def _find_image_video_and_text_components(self, components: List[Component]):
        """
        Manual indexing of images and text components.

        The positions are fixed (image, location text, coordinates text) and
        `components` is not inspected.
        """
        image_component_idx = 0
        text_component_idx = 1
        text_component2_idx = 2

        return (
            image_component_idx,
            text_component_idx,
            text_component2_idx,
        )
143
+
144
+
145
def get_api_key() -> Optional[str]:
    """Return the GANTRY_API_KEY environment variable, or None when it is unset."""
    return os.environ.get("GANTRY_API_KEY")
gantry_callback/s3_util.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility that uses boto to create buckets.
3
+ This work is not our own but is entirely written by https://github.com/full-stack-deep-learning.
4
+ """
5
+
6
+ import hashlib
7
+ import json
8
+
9
+ import boto3
10
+ import botocore
11
+
12
+ S3_URL_FORMAT = "https://{bucket}.s3.{region}.amazonaws.com/{key}"
13
+ S3_URI_FORMAT = "s3://{bucket}/{key}"
14
+
15
+ s3 = boto3.resource("s3")
16
+
17
+
18
def get_or_create_bucket(name):
    """Return the boto3 Bucket for *name*, attempting to create it first.

    A 409 response from the create call is taken to mean the bucket already
    exists (sufficient permissions are presumed); any other client error is
    re-raised.
    """
    try:
        name, _ = _create_bucket(name)
    except botocore.exceptions.ClientError as err:
        # status-code based handling, see
        # https://github.com/boto/boto3/issues/1195#issuecomment-495842252
        http_status = err.response["ResponseMetadata"]["HTTPStatusCode"]
        if http_status != 409:
            raise err

    return s3.Bucket(name)
36
+
37
+
38
def _create_bucket(name):
    """Create a bucket called *name* in the default region of a new boto3 session.

    Returns:
        A ``(name, response)`` tuple, where ``response`` is whatever
        ``s3.create_bucket`` returned.
    """
    session = boto3.session.Session()  # sessions hold on to credentials and config
    current_region = session.region_name  # so we can pull the default region

    create_kwargs = {"Bucket": name}
    # us-east-1 is the S3 default region: the API rejects an explicit
    # LocationConstraint of "us-east-1", and a None region (nothing
    # configured) fails parameter validation. Only send the constraint for
    # other regions.
    if current_region not in (None, "us-east-1"):
        create_kwargs["CreateBucketConfiguration"] = {
            "LocationConstraint": current_region
        }

    bucket_response = s3.create_bucket(**create_kwargs)

    return name, bucket_response
49
+
50
+
51
def make_key(fileobj, filetype=None):
    """Build an object key from the content hash, suffixed with ".<filetype>" if one is given."""
    key = make_identifier(fileobj)
    if filetype is not None:
        key = key + "." + filetype
    return key
58
+
59
+
60
def make_unique_bucket_name(prefix, seed):
    """Derive a bucket name as "<prefix>-<first 10 hex chars of sha256(seed)>".

    The result is deterministic: the same prefix and seed always give the
    same name.
    """
    digest = hashlib.sha256(seed.encode("utf-8")).hexdigest()
    suffix = digest[:10]
    return prefix + "-" + suffix
64
+
65
+
66
def get_url_of(bucket, key=None):
    """Return the https URL of a bucket, or of an object in it when *key* is given.

    *bucket* may be a bucket name or an object exposing ``.name``.
    """
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    region = _get_region(bucket_name)
    return _format_url(bucket_name, region, key or "")
75
+
76
+
77
def get_uri_of(bucket, key=None):
    """Return the s3:// URI of a bucket, or of an object in it when *key* is given.

    *bucket* may be a bucket name or an object exposing ``.name``.
    """
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    return _format_uri(bucket_name, key or "")
86
+
87
+
88
def enable_bucket_versioning(bucket):
    """Turn on versioning for the bucket's contents and return the result of the enable call.

    *bucket* may be a bucket name or an object exposing ``.name``.
    """
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    return s3.BucketVersioning(bucket_name).enable()
95
+
96
+
97
def add_access_policy(bucket):
    """Attach the read-access policy built by ``_get_policy`` to the bucket.

    Args:
        bucket: A bucket name or an object exposing ``.name``. Accepting a
            plain name matches the other helpers in this module.
    """
    if not isinstance(bucket, str):
        bucket = bucket.name

    access_policy = json.dumps(_get_policy(bucket))
    s3.meta.client.put_bucket_policy(Bucket=bucket, Policy=access_policy)
101
+
102
+
103
+ def _get_policy(bucket_name):
104
+ """Returns a bucket policy allowing Gantry app access as a JSON-compatible dictionary."""
105
+ return {
106
+ "Version": "2012-10-17",
107
+ "Statement": [
108
+ {
109
+ "Effect": "Allow",
110
+ "Principal": {
111
+ "AWS": [
112
+ "arn:aws:iam::848836713690:root",
113
+ "arn:aws:iam::339325199688:root",
114
+ "arn:aws:iam::665957668247:root",
115
+ ]
116
+ },
117
+ "Action": ["s3:GetObject", "s3:GetObjectVersion"],
118
+ "Resource": f"arn:aws:s3:::{bucket_name}/*",
119
+ },
120
+ {
121
+ "Effect": "Allow",
122
+ "Principal": {
123
+ "AWS": [
124
+ "arn:aws:iam::848836713690:root",
125
+ "arn:aws:iam::339325199688:root",
126
+ "arn:aws:iam::665957668247:root",
127
+ ]
128
+ },
129
+ "Action": "s3:ListBucketVersions",
130
+ "Resource": f"arn:aws:s3:::{bucket_name}",
131
+ },
132
+ ],
133
+ }
134
+
135
+
136
def make_identifier(byte_data):
    """Return the hexadecimal SHA-1 digest of *byte_data* for use as an identifier."""
    # security is not critical here, so SHA-1 is acceptable
    return hashlib.sha1(byte_data).hexdigest()  # noqa: S3
143
+
144
+
145
def _get_region(bucket):
    """Determine the region of an S3 bucket.

    Args:
        bucket: A bucket name or an object exposing ``.name``.

    Returns:
        The region name as a string, usable in a regional S3 URL.
    """
    if not isinstance(bucket, str):
        bucket = bucket.name

    s3_client = boto3.client("s3")
    bucket_location_response = s3_client.get_bucket_location(Bucket=bucket)
    bucket_location = bucket_location_response["LocationConstraint"]

    # S3 reports us-east-1 buckets with a null LocationConstraint and may
    # report the legacy value "EU" for eu-west-1. Normalise both so callers
    # never format "None" or "EU" into a URL.
    if bucket_location is None:
        return "us-east-1"
    if bucket_location == "EU":
        return "eu-west-1"

    return bucket_location
155
+
156
+
157
def _format_url(bucket_name, region, key=None):
    """Fill the S3 URL template; a missing key is rendered as an empty string."""
    return S3_URL_FORMAT.format(bucket=bucket_name, region=region, key=key or "")
161
+
162
+
163
def _format_uri(bucket_name, key=None):
    """Fill the s3:// URI template; a missing key is rendered as an empty string."""
    return S3_URI_FORMAT.format(bucket=bucket_name, key=key or "")
gantry_callback/string_img_util.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+
4
+
5
def read_b64_string(b64_string, return_data_type=False):
    """Decode a base64 data string into an in-memory binary buffer.

    Returns the BytesIO buffer alone, or a ``(filetype, buffer)`` tuple when
    *return_data_type* is true.
    """
    mime_type, payload = split_and_validate_b64_string(b64_string)
    buffer = BytesIO(base64.b64decode(payload))
    if not return_data_type:
        return buffer
    return get_b64_filetype(mime_type), buffer
13
+
14
+
15
def get_b64_filetype(data_header):
    """Return the part after the slash of a "<type>/<subtype>" data type header.

    Raises ValueError unless the header contains exactly one "/".
    """
    _main_type, subtype = data_header.split("/")
    return subtype
19
+
20
+
21
def split_and_validate_b64_string(b64_string):
    """Split a "data:<type>;base64,<data>" string into its data type and data.

    Returns:
        A ``(data_type, data)`` tuple, e.g. ``("image/png", "<base64 text>")``.

    Raises:
        ValueError: If the string has no comma separating header from data,
            or the header does not start with "data:" and end with ";base64".
    """
    header, separator, data = b64_string.partition(",")
    # Explicit raises rather than assert: assertions are stripped under
    # `python -O`, which would silently disable this validation.
    if not separator:
        raise ValueError("expected a 'data:<type>;base64,<data>' string")
    if not header.startswith("data:"):
        raise ValueError("header must start with 'data:'")
    if not header.endswith(";base64"):
        raise ValueError("header must end with ';base64'")

    data_type = header.split(";")[0].split(":")[1]
    return data_type, data