File size: 5,349 Bytes
b6654d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from PIL import ImageDraw
import numpy as np
import re


# Palette cycled through (by box index) when drawing bounding boxes.
# Entries are kept unique so consecutive boxes never share a color
# just because a value was listed twice.
colormap = [
    "#0000FF",
    "#FFA500",
    "#008000",
    "#800080",
    "#A52A2A",
    "#FFC0CB",
    "#808080",
    "#808000",
    "#00FFFF",
    "#FF0000",
    "#00FF00",
    "#4B0082",
    "#EE82EE",
    "#FF00FF",
    "#FF7F50",
    "#FFD700",
    "#87CEEB",
]


# Text cleaning function
def clean_text(text):
    """
    Clean raw model output into a single, normally spaced line of text.

    Removes the ``<pad>`` and ``</s>`` tokens, collapses every run of
    whitespace (including newlines) into a single space, trims both ends,
    and inserts a space after a period that is glued to the next character.
    Periods between two digits (``3.99``, ``v1.2.3``) and periods followed
    by another period (``...``) are left untouched.

    Args:
        text (str): The input text to be cleaned.

    Returns:
        str: The cleaned and properly formatted text.
    """
    # Remove unwanted tokens
    text = text.replace("<pad>", "").replace("</s>", "")

    # Collapse all whitespace runs (spaces, tabs, newlines) to one space;
    # this also merges the original lines and drops blank ones.
    cleaned_text = re.sub(r"\s+", " ", text).strip()

    # Add a space after a period directly followed by a non-space character,
    # except when the period sits between two digits (decimal/version number)
    # or is followed by another period (part of an ellipsis).
    cleaned_text = re.sub(
        r"(?<!\d)\.(?=[^\s.])|\.(?=[^\s.\d])", ". ", cleaned_text
    )

    return cleaned_text


# Translate a "#RRGGBB" string into an (r, g, b, alpha) tuple
def hex_to_rgba(hex_color, alpha):
    """
    Translate a hexadecimal color string into an RGBA tuple.

    Args:
        hex_color (str): Color in ``#RRGGBB`` form; any leading ``#``
            characters are stripped before parsing.
        alpha (int): Alpha channel value (0-255), returned unchanged.

    Returns:
        tuple: ``(red, green, blue, alpha)`` with each color channel as an int.
    """
    digits = hex_color.lstrip("#")
    # Each channel is two hex digits at offsets 0, 2 and 4.
    red, green, blue = (int(digits[offset:offset + 2], 16) for offset in (0, 2, 4))
    return red, green, blue, alpha


# Draw OCR bounding boxes with enhanced visual elements
def draw_ocr_bboxes(image, prediction):
    """
    Draw labelled quadrilateral OCR boxes onto ``image`` in place.

    Each box is traced as a closed outline with rounded joints, colored by
    cycling through ``colormap`` by box index. Its label is rendered in
    black on a semi-transparent background of the same color, placed just
    above the box.

    Args:
        image (PIL.Image.Image): The image to draw on; it is modified in place.
        prediction (dict): The OCR prediction containing 'quad_boxes' and
            'labels'. Each quad box is a flat sequence ``[x1, y1, ..., x4, y4]``;
            a nested ``[[x, y], ...]`` point list is accepted as well.

    Returns:
        PIL.Image.Image: The same image object, with the boxes drawn.
    """
    # RGBA mode makes the semi-transparent label backgrounds blend with the
    # underlying image instead of overwriting it.
    draw = ImageDraw.Draw(image, "RGBA")

    box_outline_width = 3
    label_padding = 5  # gap between the label text and its background edge
    label_offset = 20  # how far above the box's top edge the label is placed

    # Extract bounding boxes and labels from the prediction
    bboxes, labels = prediction["quad_boxes"], prediction["labels"]

    for i, (box, label) in enumerate(zip(bboxes, labels)):
        # Cycle through the shared palette by box index
        color = colormap[i % len(colormap)]

        # Normalise the box into a list of (x, y) vertices; a flat list of
        # coordinates and a nested list of points both reshape to (N, 2).
        points = [
            (x, y) for x, y in np.asarray(box, dtype=float).reshape(-1, 2).tolist()
        ]
        if not points:
            continue

        # Trace the closed outline. Repeating the first two vertices makes the
        # starting vertex an interior joint too, so every corner gets the
        # rounded "curve" join rather than only the middle ones.
        draw.line(
            points + points[:2],
            fill=color,
            width=box_outline_width,
            joint="curve",
        )

        # Anchor the label just above the top-left extent of the box
        text_x = min(x for x, _ in points)
        text_y = min(y for _, y in points) - label_offset
        text_pos = (text_x + label_padding, text_y + label_padding)

        # Measure the rendered text. ImageDraw.textlength returns only a single
        # float width, so the bounding box is used to get both dimensions.
        left, top, right, bottom = draw.textbbox(text_pos, label)

        # Semi-transparent background behind the label
        draw.rectangle(
            [
                left - label_padding,
                top - label_padding,
                right + label_padding,
                bottom + label_padding,
            ],
            fill=hex_to_rgba(color, 200),
        )

        # Draw the text label
        draw.text(text_pos, label, fill=(0, 0, 0, 255))

    # Return the image with the OCR boxes drawn
    return image