Spaces:
Runtime error
Runtime error
""" | |
Class to handle flagging in Gradio to Gantry.

Originally written by the FSDL educators at
https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/main/app_gradio/flagging.py
and adjusted for the geolocator project.
""" | |
import os | |
from typing import List, Optional, Union | |
import gantry | |
import gradio as gr | |
from gradio.components import Component | |
from smart_open import open | |
from .s3_util import ( | |
add_access_policy, | |
enable_bucket_versioning, | |
get_or_create_bucket, | |
get_uri_of, | |
make_key, | |
) | |
from .string_img_util import read_b64_string | |
class GantryImageToTextLogger(gr.FlaggingCallback):
    """Gradio flagging callback that sends flagged image-to-text records to Gantry.

    The flagged input image is written to an S3 bucket; the URI of that object is
    then logged to Gantry together with the two text outputs and the feedback.
    """

    def __init__(
        self,
        application: str,
        version: Union[int, str, None] = None,
        api_key: Optional[str] = None,
    ):
        """Remember the Gantry target and initialise the Gantry client.

        Images are stored on Amazon Web Services' Simple Storage Service (S3).
        The ``flagging_dir`` given to the Gradio interface determines the name
        of the S3 bucket that receives the images (see ``setup``).

        Further reading:
        - S3 and boto3 overview (Dan Bader): https://realpython.com/python-boto3-aws-s3/
        - How Gradio handles flagging data: https://gradio.app/docs/#flagging
        - Logging data to Gantry: https://docs.gantry.io

        Parameters
        ----------
        application
            Name of the Gantry application that flagged data is uploaded to.
            Gantry validates and monitors data per application.
        version
            Schema version Gantry should validate against. When omitted, Gantry
            uses the latest version; a version that does not exist yet is created.
        api_key
            Optional Gantry API key, handy for local development and notebooks.
            The key may instead be supplied through the GANTRY_API_KEY
            environment variable.
        """
        self.application, self.version = application, version
        gantry.init(api_key=api_key)

    def setup(self, components: List[Component], flagging_dir: str):
        """Prepare the logger: attach to (or create) the S3 bucket and index components.

        Parameters
        ----------
        components
            The Gradio components of the interface; forwarded to
            ``_find_image_video_and_text_components``.
        flagging_dir
            Passed to ``get_or_create_bucket`` to obtain the bucket for images.
        """
        # Number of records flagged so far; flag() increments and returns it.
        self._counter = 0

        bucket = get_or_create_bucket(flagging_dir)
        self.bucket = bucket
        enable_bucket_versioning(bucket)
        add_access_policy(bucket)

        indices = self._find_image_video_and_text_components(components)
        (
            self.image_component_idx,
            self.text_component_idx,
            self.text_component2_idx,
        ) = indices

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
        """Upload the flagged image to S3 and log the record, with feedback, to Gantry.

        Parameters
        ----------
        flag_data
            Indexable collection of component values; the image and the two text
            values are read at the indices stored by ``setup``.
        flag_option
            Stored in the feedback under the ``"flag"`` key.
        flag_index
            Accepted for interface compatibility; not used here.
        username
            When not None, stored in the feedback under the ``"user"`` key.

        Returns
        -------
        int
            The running count of records flagged since ``setup`` was called.
        """
        flagged_image = flag_data[self.image_component_idx]
        location_text = flag_data[self.text_component_idx]
        coordinates_text = flag_data[self.text_component2_idx]

        # Only attach the user when one was supplied.
        user_info = {} if username is None else {"user": username}
        feedback = {"flag": flag_option, **user_info}

        data_type, image_buffer = read_b64_string(flagged_image, return_data_type=True)
        raw_image = image_buffer.read()
        image_url = self._to_s3(raw_image, filetype=data_type)

        self._to_gantry(
            input_image_url=image_url,
            pred_location=location_text,
            pred_coordinates=coordinates_text,
            feedback=feedback,
        )

        self._counter += 1
        return self._counter

    def _to_gantry(self, input_image_url, pred_location, pred_coordinates, feedback):
        """Log one record to Gantry for this logger's application and version."""
        gantry.log_record(
            self.application,
            self.version,
            inputs={"image": input_image_url},
            outputs={"location": pred_location, "coordinates": pred_coordinates},
            feedback=feedback,
        )

    def _to_s3(self, image_bytes, key=None, filetype=None):
        """Write ``image_bytes`` to the bucket and return the URI it was written to.

        When ``key`` is None, the object key is produced by ``make_key`` from the
        bytes and ``filetype``.
        """
        object_key = make_key(image_bytes, filetype=filetype) if key is None else key
        s3_uri = get_uri_of(self.bucket, object_key)
        # NOTE: `open` is the one imported from smart_open at the top of the file,
        # not the builtin, so it is handed the URI from get_uri_of.
        with open(s3_uri, "wb") as s3_object:
            s3_object.write(image_bytes)
        return s3_uri

    def _find_image_video_and_text_components(self, components: List[Component]):
        """Return the (image, text, second text) component indices.

        The indices are fixed manually; ``components`` is not inspected.
        """
        # Fixed layout: image first, followed by the two text components.
        return (0, 1, 2)
def get_api_key() -> Optional[str]:
    """Return the value of the GANTRY_API_KEY environment variable, or None if unset."""
    return os.environ.get("GANTRY_API_KEY")