File size: 4,839 Bytes
b6476a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04e9b0f
b6476a0
 
 
b685cf0
b6476a0
 
04e9b0f
b6476a0
 
 
 
b685cf0
b6476a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c3d937
 
 
 
b6476a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f19c28
 
 
 
5c3d937
 
 
 
 
 
 
 
 
 
 
4f19c28
 
5c3d937
4f19c28
 
 
 
5c3d937
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gzip
import io
import json
import random
import re
import tempfile
from typing import Dict, List, Optional

from PIL import Image
import requests
import streamlit as st


# Shared HTTP session used by the data loaders and image fetches in this file.
http_session = requests.Session()

# Base URL of the Open Food Facts API (v0) and the product endpoint built on it.
API_URL = "https://world.openfoodfacts.org/api/v0"
PRODUCT_URL = API_URL + "/product"
# Prefix prepended to the path produced by generate_image_path().
OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products"
# Captures three 3-character groups followed by the (possibly empty) remainder;
# it therefore only matches strings of at least 9 characters.
BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")


@st.cache(allow_output_mutation=True)
def load_nn_data(url: str):
    """Download and parse a gzipped JSON file of nearest-neighbor data.

    The payload at *url* must be a gzip-compressed JSON object. Its keys are
    converted to ``int`` (JSON object keys are always strings) and its values
    are kept as parsed.

    Args:
        url: Location of the ``.json.gz`` file.

    Returns:
        A dict mapping each integer key to the corresponding JSON value.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    r = http_session.get(url)
    # Fail with an explicit HTTP error instead of handing an error page to
    # gzip, which would only report "not a gzipped file".
    r.raise_for_status()
    # JSON is UTF-8; do not depend on the locale's default text encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        return {int(key): value for key, value in json.load(f).items()}


@st.cache(allow_output_mutation=True)
def load_logo_data(url: str):
    """Download and parse a gzipped JSON Lines file of logo records.

    Each non-blank line of the decompressed payload is parsed as one JSON
    object, which must contain an ``"id"`` field.

    Args:
        url: Location of the ``.jsonl.gz`` file.

    Returns:
        A dict mapping ``int(item["id"])`` to the full parsed item.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    r = http_session.get(url)
    # Fail with an explicit HTTP error instead of handing an error page to
    # gzip, which would only report "not a gzipped file".
    r.raise_for_status()
    # JSON is UTF-8; do not depend on the locale's default text encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) are skipped rather than being
        # passed to json.loads, which would raise on an empty string.
        records = (json.loads(line) for line in map(str.strip, f) if line)
        return {int(item["id"]): item for item in records}


def get_image_from_url(
    image_url: str,
    error_raise: bool = False,
    session: Optional[requests.Session] = None,
) -> Optional[Image.Image]:
    """Fetch an image over HTTP and open it with Pillow.

    Args:
        image_url: URL of the image to download.
        error_raise: When true, call ``raise_for_status()`` on the response so
            HTTP error statuses raise instead of yielding ``None``.
        session: Session used for the request; when omitted, a plain
            ``requests.get`` call is made.

    Returns:
        The opened image, or ``None`` when the response status is not 200.

    Raises:
        requests.HTTPError: If *error_raise* is true and the response has an
            error status.
    """
    # Use the session the caller passed in (the previous code ignored it and
    # always went through the module-level session).
    http_get = session.get if session is not None else requests.get
    r = http_get(image_url)

    if error_raise:
        r.raise_for_status()

    if r.status_code != 200:
        return None

    # Decode straight from memory. Pillow loads pixel data lazily, so the
    # backing buffer must outlive this function; a temporary file that is
    # deleted on exit (and possibly not flushed when reopened by name) is not
    # a safe backing store.
    return Image.open(io.BytesIO(r.content))


def split_barcode(barcode: str) -> List[str]:
    """Break a numeric barcode into the components matched by the path regex.

    Args:
        barcode: Barcode consisting only of digits.

    Returns:
        The non-empty groups captured by ``BARCODE_PATH_REGEX`` when the
        barcode matches it, otherwise a single-element list holding the
        barcode unchanged.

    Raises:
        ValueError: If *barcode* is not made up exclusively of digits (this
            includes the empty string).
    """
    if barcode.isdigit():
        matched = BARCODE_PATH_REGEX.fullmatch(barcode)
        if matched is None:
            return [barcode]
        # Drop empty captures (the trailing group may match nothing).
        return list(filter(None, matched.groups()))

    raise ValueError("unknown barcode format: {}".format(barcode))


def get_cropped_image(barcode: str, image_id: str, bounding_box):
    """Download a product image and crop it to the given bounding box.

    Args:
        barcode: Product barcode, used to build the image path.
        image_id: Identifier of the image, used to build the image path.
        bounding_box: Sequence unpacked as ``(ymin, xmin, ymax, xmax)``. Each
            value is multiplied by the image height or width, so they are
            presumably fractions of the image size -- verify against the data.

    Returns:
        The cropped image, or ``None`` if the image could not be fetched.
    """
    url = OFF_IMAGE_BASE_URL + generate_image_path(barcode, image_id)
    image = get_image_from_url(url, session=http_session)
    if image is None:
        return None

    ymin, xmin, ymax, xmax = bounding_box
    width, height = image.width, image.height
    # Pillow's crop box order is (left, upper, right, lower).
    crop_box = (xmin * width, ymin * height, xmax * width, ymax * height)
    return image.crop(crop_box)


def generate_image_path(barcode: str, image_id: str) -> str:
    """Build the ``/<barcode parts>/<image_id>.jpg`` path for a product image.

    The barcode is split with ``split_barcode`` and its parts are joined with
    ``/`` to form the directory portion of the path.
    """
    barcode_dir = "/".join(split_barcode(barcode))
    return "/{}/{}.jpg".format(barcode_dir, image_id)


def display_predictions(
    logo_data: Dict,
    nn_data: Dict,
    logo_id: Optional[int] = None,
):
    """Render a logo and its nearest neighbors in the Streamlit page.

    Args:
        logo_data: Mapping from logo ID to a record with ``"barcode"``,
            ``"image_id"`` and ``"bounding_box"`` entries.
        nn_data: Mapping from logo ID to a record with ``"ids"``,
            ``"distances"`` and ``"annotation"`` entries.
        logo_id: Logo to display. When falsy (``None`` or ``0``), a random key
            of *nn_data* is used.

    Returns:
        None. Output is written to the page; nothing is drawn after the
        "Logo ID" line if the logo is unknown or its image cannot be fetched.
    """
    if not logo_id:
        logo_id = random.choice(list(nn_data.keys()))

    st.write(f"Logo ID: {logo_id}")

    # The ID may come straight from user input; report an unknown ID on the
    # page instead of crashing with a KeyError.
    if logo_id not in logo_data or logo_id not in nn_data:
        st.error(f"Unknown logo ID: {logo_id}")
        return

    logo = logo_data[logo_id]
    logo_nn_data = nn_data[logo_id]
    nn_ids = logo_nn_data["ids"]
    nn_distances = logo_nn_data["distances"]
    annotation = logo_nn_data["annotation"]

    cropped_image = get_cropped_image(
        logo["barcode"], logo["image_id"], logo["bounding_box"]
    )

    if cropped_image is None:
        return
    st.image(cropped_image, annotation, width=200)

    cropped_images: List[Image.Image] = []
    captions: List[str] = []
    progress_bar = st.progress(0)
    neighbor_count = len(nn_ids)

    for i, (closest_id, distance) in enumerate(zip(nn_ids, nn_distances)):
        progress_bar.progress((i + 1) / neighbor_count)
        closest_logo = logo_data[closest_id]

        cropped_image = get_cropped_image(
            closest_logo["barcode"],
            closest_logo["image_id"],
            closest_logo["bounding_box"],
        )
        # Neighbors whose image cannot be fetched are skipped.
        if cropped_image is None:
            continue

        if cropped_image.height > cropped_image.width:
            # expand=True resizes the canvas to fit the rotated picture;
            # without it the output keeps the portrait size and the rotated
            # content is cut off.
            cropped_image = cropped_image.rotate(90, expand=True)

        cropped_images.append(cropped_image)
        captions.append(f"distance: {distance}")

    if cropped_images:
        st.image(cropped_images, captions, width=200)


# --- Page script: sidebar controls, data loading, then rendering ---
st.sidebar.title("Logo Nearest Neighbors Demo")
st.sidebar.write(
    "Get first 100 nearest neighbors for a random annotated logo.\n\n"
    "CLIP model is used to generate embeddings, and nearest neighbors "
    "are computed either using a brute-force approach or with ANN."
)
# `or None` turns a falsy input (0) into None, which makes
# display_predictions pick a random logo.
logo_id = st.sidebar.number_input("logo ID", step=1) or None
approximate = (
    st.sidebar.checkbox(
        "ANN (HNSW)",
        value=False,
        help="Display approximate neighbors (instead of real "
        "neighbors computed using brute-force approach)",
    )
    or None
)
# The checkbox selects which precomputed neighbor file is downloaded.
nn_dataset_name = (
    "hnsw_50_closest_neighbours" if approximate else "exact_100_neighbours"
)
nn_data = load_nn_data(
    f"https://static.openfoodfacts.org/data/logos/{nn_dataset_name}.json.gz"
)
logo_data = load_logo_data(
    "https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz"
)
if approximate:
    st.write("Using approximate nearest neighbors method")
else:
    st.write("Using exact (brute-force) nearest neighbors method")
display_predictions(logo_data=logo_data, nn_data=nn_data, logo_id=logo_id)