qubvel-hf's picture
qubvel-hf HF staff
Update app.py
ec8f104 verified
raw
history blame
14.7 kB
import albumentations as A
import base64
import cv2
import logging
import gradio as gr
import inspect
import io
import numpy as np
import os
from dataclasses import dataclass
from copy import deepcopy
from functools import wraps
from PIL import Image, ImageDraw
from typing import get_type_hints, Optional
from mixpanel import Mixpanel
from utils import is_not_supported_transform
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
MIXPANEL_TOKEN = os.getenv("MIXPANEL_TOKEN")
mp = Mixpanel(MIXPANEL_TOKEN)
HEADER = f"""
<div align="center">
<p>
<img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;">
<span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span>
</p>
<p style="margin-top: -15px;">
<a href="https://albumentations.ai/docs/" target="_blank" style="color: grey;">Documentation</a>
&nbsp;
<a href="https://github.com/albumentations-team/albumentations" target="_blank" style="color: grey;">GitHub Repository</a>
</p>
</div>
"""
DEFAULT_TRANSFORM = "Rotate"
DEFAULT_IMAGE_PATH = "images/doctor.webp"
DEFAULT_IMAGE = np.array(Image.open(DEFAULT_IMAGE_PATH))
DEFAULT_IMAGE_HEIGHT = DEFAULT_IMAGE.shape[0]
DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[1]
DEFAULT_BOXES = [
[265, 121, 326, 177], # Mask
[192, 169, 401, 395], # Coverall
]
mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints
BASE64_DEFAULT_MASKS = [
{
"label": "Coverall",
# light green color
"color": (144, 238, 144),
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
},
{
"label": "Mask",
# light blue color
"color": (173, 216, 230),
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
},
]
# Get all the transforms from the albumentations library
transforms_map = {
name: cls
for name, cls in vars(A).items()
if (
inspect.isclass(cls)
and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
and not is_not_supported_transform(cls)
)
}
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_map.pop("ReferenceBasedTransform", None)
transforms_keys = list(sorted(transforms_map.keys()))
# Decode the masks
for mask in BASE64_DEFAULT_MASKS:
mask["mask"] = np.array(
Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
)
@dataclass
class RequestParams:
user_ip: str
transform_name: Optional[str]
def track_event(event_name, user_id="unknown", properties=None):
if properties is None:
properties = {}
mp.track(user_id, event_name, properties)
logger.info(f"Event tracked: {event_name} - {properties}")
def get_params(request: gr.Request) -> RequestParams:
"""Parse input request parameters."""
ip = request.client.host
transform_name = request.query_params.get("transform", None)
params = RequestParams(user_ip=ip, transform_name=transform_name)
track_event("app_opened", user_id=params.user_ip, properties={"transform_name": params.transform_name})
return params
def run_with_retry(compose):
@wraps(compose)
def wrapper(*args, **kwargs):
processors = deepcopy(compose.processors)
for _ in range(4):
try:
result = compose(*args, **kwargs)
break
except NotImplementedError as e:
print(f"Caught NotImplementedError: {e}")
if "bbox" in str(e):
kwargs.pop("bboxes", None)
kwargs.pop("category_id", None)
compose.processors.pop("bboxes")
if "keypoint" in str(e):
kwargs.pop("keypoints", None)
compose.processors.pop("keypoints")
if "mask" in str(e):
kwargs.pop("mask", None)
except Exception as e:
compose.processors = processors
raise e
compose.processors = processors
return result
return wrapper
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1) -> np.ndarray:
"""Draw boxes with PIL."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
for box in boxes:
x_min, y_min, x_max, y_max = box
draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
return np.array(pil_image)
def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
"""Draw keypoints with PIL."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
for keypoint in keypoints:
x, y = keypoint
draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
return np.array(pil_image)
def get_rgb_mask(masks):
"""Get the RGB mask from the binary mask."""
rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
for data in masks:
mask = data["mask"]
rgb_mask[mask > 0] = np.array(data["color"])
return rgb_mask
def draw_mask(image, mask):
"""Draw the mask on the image."""
image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
return image_with_mask
def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
"""Draw the image with a text. In the middle."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
# align in the centerm, and make bigger font
text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
length = draw.textlength(text)
draw.text(
(DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
text,
fill=(255, 0, 0),
align="center",
)
return np.array(pil_image)
def get_formatted_signature(function_or_class, indentation=4):
signature = inspect.signature(function_or_class)
type_hints = get_type_hints(function_or_class)
args = []
for param in signature.parameters.values():
if param.name == "p":
str_param = "p=1.0,"
elif param.default == inspect.Parameter.empty:
str_param = f"{param.name}=,"
else:
if isinstance(param.default, str):
str_param = f'{param.name}="{param.default}",'
else:
str_param = f"{param.name}={param.default},"
annotation = type_hints.get(param.name, param.annotation)
if isinstance(param.annotation, type):
str_param += f" # {param.annotation.__name__}"
else:
str_annotation = str(annotation).replace("typing.", "")
str_param += f" # {str_annotation}"
str_param = "\n" + " " * indentation + str_param
args.append(str_param)
result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
return result
def get_formatted_transform(transform_name):
track_event("transform_selected", properties={"transform_name": transform_name})
transform = transforms_map[transform_name]
return f"A.{transform.__name__}{get_formatted_signature(transform)}"
def get_formatted_transform_docs(transform_name):
transform = transforms_map[transform_name]
return transform.__doc__.strip("\n")
def update_augmented_images(image, code):
augmentation = eval(code)
track_event("transform_applied", properties={"transform_name": augmentation.__class__.__name__, "code": code})
compose = A.Compose(
[augmentation],
bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
keypoint_params=A.KeypointParams(format="xy"),
)
compose = run_with_retry(compose) # to prevent NotImplementedError
keypoints = DEFAULT_KEYPOINTS
bboxes = DEFAULT_BOXES
mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
augmented = compose(
image=image,
not_implemented_image=image.copy(),
mask=mask,
keypoints=keypoints,
bboxes=bboxes,
category_id=range(len(bboxes)),
)
image = augmented["image"]
mask = augmented.get("mask", None)
bboxes = augmented.get("bboxes", None)
keypoints = augmented.get("keypoints", None)
# Draw the augmented images (or replace by placeholder if not implemented)
if mask is not None:
image_with_mask = draw_mask(image.copy(), mask)
else:
image_with_mask = draw_not_implemented_image(image.copy(), "mask")
if bboxes is not None:
image_with_bboxes = draw_boxes(image.copy(), bboxes)
else:
image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")
if keypoints is not None:
image_with_keypoints = draw_keypoints(image.copy(), keypoints)
else:
image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")
return [
(image_with_mask, "Mask"),
(image_with_bboxes, "Boxes"),
(image_with_keypoints, "Keypoints"),
]
def update_image_info(image):
h, w = image.shape[:2]
dtype = image.dtype
max_, min_ = image.max(), image.min()
return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"
def update_code_and_docs(select):
code = get_formatted_transform(select)
docs = get_formatted_transform_docs(select)
return code, docs
def update_code_and_docs_on_start(url_params: gr.Request):
params = get_params(url_params)
transform_name = params.transform_name if params.transform_name is not None else DEFAULT_TRANSFORM
return gr.update(value=transform_name)
with gr.Blocks() as demo:
gr.Markdown(HEADER)
with gr.Row():
with gr.Column():
with gr.Group():
select = gr.Dropdown(
label="Select a transformation",
choices=transforms_keys,
value=DEFAULT_TRANSFORM,
type="value",
interactive=True,
)
with gr.Accordion("Documentation (click to expand)", open=False):
docs = gr.TextArea(
get_formatted_transform_docs(DEFAULT_TRANSFORM),
show_label=False,
interactive=False,
)
code = gr.Code(
label="Code",
language="python",
value=get_formatted_transform(DEFAULT_TRANSFORM),
interactive=True,
lines=5,
)
info = gr.TextArea(
value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
show_label=False,
lines=1,
max_lines=1,
)
button = gr.Button("Apply!")
image = gr.Image(
value=DEFAULT_IMAGE_PATH,
type="numpy",
height=500,
width=300,
sources=[],
)
with gr.Row():
augmented_image = gr.Gallery(
value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"),
rows=1,
columns=3,
show_label=False,
)
select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
button.click(
fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
)
demo.load(
update_code_and_docs_on_start, inputs=None, outputs=[select], queue=False
)
if __name__ == "__main__":
demo.launch()