qubvel-hf's picture
qubvel-hf HF staff
Fix not implemented
e5b9884
raw
history blame
No virus
12 kB
import cv2
import inspect
import numpy as np
import albumentations as A
import gradio as gr
from typing import get_type_hints
from PIL import Image, ImageDraw
import base64
import io
from PIL import Image
from functools import wraps
from copy import deepcopy
DEFAULT_TRANSFORM = "CoarseDropout"
DEFAULT_IMAGE = "images/doctor.webp"
DEFAULT_IMAGE_HEIGHT = 400
DEFAULT_IMAGE_WIDTH = 600
DEFAULT_BOXES = [[265, 121, 326, 177], [192, 169, 401, 395]]
DEFAULT_KEYPOINTS = [
[(x_min + x_max) // 2, (y_min + y_max) // 2]
for x_min, y_min, x_max, y_max in DEFAULT_BOXES
]
CORENERS = [[[x_min, y_min], [x_max, y_max], [x_min, y_max], [x_max, y_min]] for x_min, y_min, x_max, y_max in DEFAULT_BOXES]
for bbox_corners in CORENERS:
DEFAULT_KEYPOINTS += bbox_corners
BASE64_DEFAULT_MASKS = [
{
"label": "Coverall",
# light green color
"color": (144, 238, 144),
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
},
{
"label": "Mask",
# light blue color
"color": (173, 216, 230),
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
},
]
# Get all the transforms from the albumentations library
transforms_map = {
name: cls
for name, cls in vars(A).items()
if inspect.isclass(cls) and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
}
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_keys = list(sorted(transforms_map.keys()))
# Decode the masks
for mask in BASE64_DEFAULT_MASKS:
mask["mask"] = np.array(Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L"))
def run_with_retry(compose):
@wraps(compose)
def wrapper(*args, **kwargs):
processors = deepcopy(compose.processors)
for _ in range(4):
try:
result = compose(*args, **kwargs)
break
except NotImplementedError as e:
print(f"Caught NotImplementedError: {e}")
if "bbox" in str(e):
kwargs.pop("bboxes", None)
kwargs.pop("category_id", None)
compose.processors.pop("bboxes")
if "keypoint" in str(e):
kwargs.pop("keypoints", None)
compose.processors.pop("keypoints")
if "mask" in str(e):
kwargs.pop("mask", None)
except Exception as e:
compose.processors = processors
raise e
compose.processors = processors
return result
return wrapper
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=2) -> np.ndarray:
"""Draw boxes with PIL."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
for box in boxes:
x_min, y_min, x_max, y_max = box
draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
return np.array(pil_image)
def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
"""Draw keypoints with PIL."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
for keypoint in keypoints:
x, y = keypoint
draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
return np.array(pil_image)
def get_rgb_mask(masks):
"""Get the RGB mask from the binary mask."""
rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
for data in masks:
mask = data["mask"]
rgb_mask[mask > 0] = np.array(data["color"])
return rgb_mask
def draw_mask(image, mask):
"""Draw the mask on the image."""
image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
return image_with_mask
def draw_not_implemented_image(image):
"""Draw the image with a text. In the middle."""
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
# align in the centerm, and make bigger font
text = "NOT IMPLEMETED FOR THIS TYPE OF ANNOTATIONS"
length = draw.textlength(text)
draw.text(
(DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
text,
fill=(255, 0, 0),
align="center",
)
return np.array(pil_image)
def get_formatted_signature(function_or_class, indentation=4):
signature = inspect.signature(function_or_class)
type_hints = get_type_hints(function_or_class)
args = []
for param in signature.parameters.values():
if param.name == "p":
str_param = "p=1.0,"
elif param.default == inspect.Parameter.empty:
str_param = f"{param.name}=,"
else:
if isinstance(param.default, str):
str_param = f'{param.name}="{param.default}",'
else:
str_param = f"{param.name}={param.default},"
annotation = type_hints.get(param.name, param.annotation)
if isinstance(param.annotation, type):
str_param += f" # {param.annotation.__name__}"
else:
str_annotation = str(annotation).replace("typing.", "")
str_param += f" # {str_annotation}"
str_param = "\n" + " " * indentation + str_param
args.append(str_param)
result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
return result
def update(image, code):
try:
augmentation = eval(code)
compose = A.Compose(
[augmentation],
bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
keypoint_params=A.KeypointParams(format="xy"),
additional_targets={"not_implemented_image": "image"}
)
compose = run_with_retry(compose) # to prevent NotImplementedError
keypoints = DEFAULT_KEYPOINTS
bboxes = DEFAULT_BOXES
mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
augmented = compose(
image=image,
not_implemented_image=draw_not_implemented_image(image),
mask=mask,
keypoints=keypoints,
bboxes=bboxes,
category_id=range(len(bboxes)),
)
image = augmented["image"]
not_implemented_image = augmented["not_implemented_image"]
mask = augmented.get("mask", None)
bboxes = augmented.get("bboxes", None)
keypoints = augmented.get("keypoints", None)
image_with_mask = draw_mask(image.copy(), mask) if mask is not None else not_implemented_image
image_with_bboxes = draw_boxes(image.copy(), bboxes) if bboxes is not None else not_implemented_image
image_with_keypoints = draw_keypoints(image.copy(), keypoints) if keypoints is not None else not_implemented_image
return [
(image_with_mask, "Mask"),
(image_with_bboxes, "Boxes"),
(image_with_keypoints, "Keypoints"),
]
except Exception as e:
raise e
def update_image_info(image):
h, w = image.shape[:2]
dtype = image.dtype
max_, min_ = image.max(), image.min()
return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"
def get_formatted_transform(transform_number):
transform_name = transforms_keys[transform_number]
transform = transforms_map[transform_name]
return f"A.{transform.__name__}{get_formatted_signature(transform)}"
def get_formatted_transform_docs(transform_number):
transform_name = transforms_keys[transform_number]
transform = transforms_map[transform_name]
return transform.__doc__.strip("\n")
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
with gr.Group():
select = gr.Dropdown(
label="Select a transformation",
choices=transforms_keys,
value=DEFAULT_TRANSFORM,
type="index",
interactive=True,
)
with gr.Accordion("Documentation", open=False):
docs = gr.TextArea(
get_formatted_transform_docs(
transforms_keys.index(DEFAULT_TRANSFORM)
),
show_label=False,
interactive=False,
)
code = gr.Code(
language="python",
value=get_formatted_transform(transforms_keys.index(DEFAULT_TRANSFORM)),
interactive=True,
lines=5,
)
button = gr.Button("Run")
#info = gr.Text(interactive=False, label="Image info", value="")
image = gr.Image(
value=DEFAULT_IMAGE,
type="numpy",
height=500,
width=300,
sources=[],
)
with gr.Row():
augmented_image = gr.Gallery(rows=1, columns=3)
# augmented_image = gr.Image(type="numpy", height=300, width=300)
#image.upload(fn=update_image_info, inputs=[image], outputs=[info])
select.change(fn=get_formatted_transform, inputs=[select], outputs=[code])
button.click(fn=update, inputs=[image, code], outputs=[augmented_image])
if __name__ == "__main__":
demo.launch()