File size: 13,677 Bytes
68bc627
 
 
 
f2d4b7d
68bc627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
import streamlit as st
import pandas as pd
from PIL import Image
import torch
from diff_plonk.pipe import PlonkPipeline
from pathlib import Path
from streamlit_extras.colored_header import colored_header
import plotly.express as px
import requests
from io import BytesIO

# Streamlit page setup: browser-tab title, icon and a wide layout.
st.set_page_config(
    layout="wide",
    page_icon="πŸ—ΊοΈ",
    page_title="Around the World in 80 Timesteps",
)

# "cuda" when torch reports an available GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory two levels above this file, and the checkpoints folder inside it.
PROJECT_ROOT = Path(__file__).parent.parent.absolute()
CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"

# Display name -> model path string handed to PlonkPipeline(model_path=...).
# Every variant lives under the same "nicolas-dufour/" prefix.
MODEL_NAMES = {
    variant: f"nicolas-dufour/{variant}"
    for variant in ("PLONK_YFCC", "PLONK_OSV_5M", "PLONK_iNaturalist")
}


@st.cache_resource
def load_model(model_name):
    """Build a ``PlonkPipeline`` for ``model_name``.

    The function is wrapped in ``st.cache_resource``. If constructing the
    pipeline raises, the error text is shown with ``st.error`` and the
    script run is halted with ``st.stop``.
    """
    try:
        return PlonkPipeline(model_path=model_name)
    except Exception as exc:
        st.error(f"Error loading model: {exc!s}")
        st.stop()


# One loaded pipeline per entry of MODEL_NAMES, keyed by the display name.
PIPES = {label: load_model(model_path) for label, model_path in MODEL_NAMES.items()}


def predict_location(image, model_name, cfg=0.0, num_samples=256):
    """Sample GPS predictions for one image with the selected pipeline.

    Args:
        image: A ``PIL.Image.Image``, or anything ``Image.open`` accepts
            (for example a file path or a file-like object).
        model_name: Key into ``PIPES`` selecting which pipeline to run.
        cfg: Guidance scale used for the ``num_samples`` regular samples.
        num_samples: Number of regular samples to draw (the batch size).

    Returns:
        A dict with ``"lat"`` / ``"lon"`` lists of floats for the regular
        samples and ``"high_conf_lat"`` / ``"high_conf_lon"`` floats for a
        single extra sample drawn with a fixed guidance scale of 2.0.
    """
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        # convert() returns a new, fully loaded image, so the handle opened
        # here can be released as soon as the conversion is done.
        with Image.open(image) as opened:
            img = opened.convert("RGB")

    pipe = PIPES[model_name]

    with torch.no_grad():
        # Regular samples at the user-chosen guidance scale.
        # NOTE(review): the indexing below assumes the pipeline returns a
        # NumPy-like array with latitude in column 0 and longitude in column 1.
        predicted_gps = pipe(img, batch_size=num_samples, cfg=cfg, num_steps=32)

        # One extra sample at guidance scale 2.0, used as the "best guess".
        high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=32)

    return {
        "lat": predicted_gps[:, 0].astype(float).tolist(),
        "lon": predicted_gps[:, 1].astype(float).tolist(),
        # Plain Python floats, matching the element type of the lists above.
        "high_conf_lat": float(high_conf_gps[0, 0]),
        "high_conf_lon": float(high_conf_gps[0, 1]),
    }


def load_example_images():
    """Collect the example images that sit next to this file.

    Looks for ``*.jpg`` files in the ``examples`` directory beside this
    module.

    Returns:
        A dict mapping a display name (file stem with underscores replaced by
        spaces, title-cased) to the image path as a string. The dict is empty
        when the directory is missing (an error is shown) or contains no
        ``.jpg`` files (a warning is shown).
    """
    examples_dir = Path(__file__).parent / "examples"
    if not examples_dir.exists():
        st.error(
            """
            Examples directory not found. Please create the following structure:
            demo/
            └── examples/
                β”œβ”€β”€ eiffel_tower.jpg
                β”œβ”€β”€ colosseum.jpg
                β”œβ”€β”€ taj_mahal.jpg
                β”œβ”€β”€ statue_liberty.jpg
                └── sydney_opera.jpg
            """
        )
        return {}

    # glob() yields entries in arbitrary order; sorting keeps the example
    # layout stable from one run (and one machine) to the next.
    examples = {
        img_path.stem.replace("_", " ").title(): str(img_path)
        for img_path in sorted(examples_dir.glob("*.jpg"))
    }

    if not examples:
        st.warning("No example images found in the examples directory.")

    return examples


def resize_image_for_display(image, max_size=400):
    """Shrink ``image`` so its longest side is at most ``max_size`` pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the same object, not a copy).

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method, such as a PIL image.
        max_size: Maximum allowed length of the longest side, in pixels.

    Returns:
        The original image if no resize is needed, otherwise the result of
        ``image.resize`` with LANCZOS resampling.
    """
    width, height = image.size
    longest_side = max(width, height)

    # Nothing to do when the image already fits.
    if longest_side <= max_size:
        return image

    ratio = max_size / longest_side
    # Clamp the scaled side to at least 1 pixel: for very elongated images
    # int() would truncate it to 0, which resize() rejects.
    if width > height:
        new_size = (max_size, max(1, int(height * ratio)))
    else:
        new_size = (max(1, int(width * ratio)), max_size)

    return image.resize(new_size, Image.Resampling.LANCZOS)


def load_image_from_url(url, timeout=10):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image to fetch (user supplied).
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would block the script indefinitely.

    Returns:
        The opened ``PIL.Image.Image``, or ``None`` if the download or the
        decoding failed (the error is reported with ``st.error``).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raise an exception for bad status codes
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberately broad: this is a UI boundary, and any failure should
        # surface as a message instead of crashing the page.
        st.error(f"Error loading image from URL: {str(e)}")
        return None


def _show_preview(image, caption):
    """Render a downsized copy of ``image`` inside the upload container."""
    st.markdown('<div class="upload-container">', unsafe_allow_html=True)
    st.image(
        resize_image_for_display(image.copy(), max_size=300),
        caption=caption,
        use_container_width=True,
    )
    st.markdown("</div>", unsafe_allow_html=True)


def _predict_and_store(image, model_name, cfg_value, num_samples):
    """Run the prediction under a spinner and keep it in the session state."""
    with st.spinner("🌍 Analyzing image and predicting locations..."):
        st.session_state["predictions"] = predict_location(
            image,
            model_name=model_name,
            cfg=cfg_value,
            num_samples=num_samples,
        )


def main():
    """Build the Streamlit page.

    The left column holds the model selector, the sampling controls and three
    image sources (upload, URL, bundled examples). The right column plots the
    predictions stored in ``st.session_state["predictions"]`` on a map, or a
    placeholder when nothing has been predicted yet.
    """
    # Custom CSS
    st.markdown(
        """
        <style>
        .main {
            padding: 0rem 1rem;
        }
        .stButton>button {
            width: 100%;
            background-color: #FF4B4B;
            color: white;
            border: none;
            padding: 0.5rem 1rem;
            border-radius: 0.5rem;
        }
        .stButton>button:hover {
            background-color: #FF6B6B;
        }
        .prediction-box {
            background-color: #f0f2f6;
            padding: 1.5rem;
            border-radius: 0.5rem;
            margin: 1rem 0;
        }
        /* New styles for image containers */
        .upload-container {
            max-height: 300px;
            overflow-y: auto;
            margin-bottom: 1rem;
        }
        .examples-container {
            max-height: 200px;
            display: flex;
            gap: 10px;
        }
        .stTabs [data-baseweb="tab-panel"] {
            padding-top: 1rem;
        }
        </style>
    """,
        unsafe_allow_html=True,
    )

    # Header with custom styling
    colored_header(
        label="πŸ—ΊοΈ Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation",
        description="Upload an image and our model, PLONK, will predict possible locations! In red we will sample one point with guidance scale 2.0 for the best guess. <br> <br> Project page: https://nicolas-dufour.github.io/plonk",
        color_name="red-70",
    )

    # Adjust column ratio to give 2/3 of the space to the map
    col1, col2 = st.columns([1, 2], gap="large")

    with col1:
        # Add model selection before the sliders
        model_name = st.selectbox(
            "πŸ€– Select Model",
            options=list(MODEL_NAMES),
            index=0,  # Default to YFCC
            help="Choose which PLONK model variant to use for prediction.",
        )

        # Modify the slider columns to accommodate both controls
        col_slider1, col_slider2 = st.columns([0.5, 0.5])
        with col_slider1:
            cfg_value = st.slider(
                "🎯 Guidance scale",
                min_value=0.0,
                max_value=5.0,
                value=0.0,
                step=0.1,
                help="Scale for classifier-free guidance during sampling. A small value makes the model predictions display the diversity of the model, while a large value makes the model predictions more conservative but potentially more accurate.",
            )

        with col_slider2:
            num_samples = st.number_input(
                "🎲 Number of samples",
                min_value=1,
                max_value=5000,
                value=1000,
                step=1,
                help="Number of location predictions to generate. More samples give better coverage but take longer to compute.",
            )

        st.markdown("### πŸ“Έ Choose your image")
        tab1, tab2, tab3 = st.tabs(["Upload", "URL", "Examples"])

        with tab1:
            uploaded_file = st.file_uploader(
                "Choose an image...",
                type=["png", "jpg", "jpeg"],
                help="Supported formats: PNG, JPG, JPEG",
            )

            if uploaded_file is not None:
                original_image = Image.open(uploaded_file)
                _show_preview(original_image, "Uploaded Image")

                if st.button("πŸ” Predict Location", key="predict_upload"):
                    _predict_and_store(
                        original_image, model_name, cfg_value, num_samples
                    )

        with tab2:
            url = st.text_input("Enter image URL:", key="image_url")

            if url:
                image = load_image_from_url(url)
                # load_image_from_url returns None after reporting a failure.
                if image is not None:
                    _show_preview(image, "Image from URL")

                    if st.button("πŸ” Predict Location", key="predict_url"):
                        _predict_and_store(image, model_name, cfg_value, num_samples)

        with tab3:
            examples = load_example_images()
            # st.columns() rejects a column count of 0, so only lay out the
            # gallery when there is at least one example; the loader has
            # already shown an error or warning otherwise.
            if examples:
                st.markdown('<div class="examples-container">', unsafe_allow_html=True)
                example_cols = st.columns(len(examples))

                for column, (name, path) in zip(example_cols, examples.items()):
                    with column:
                        original_image = Image.open(path)
                        display_image = resize_image_for_display(
                            original_image.copy(), max_size=150
                        )

                        if st.container().button(
                            "πŸ“Έ",
                            key=f"img_{name}",
                            help=f"Click to predict location for {name}",
                            use_container_width=True,
                        ):
                            _predict_and_store(
                                original_image, model_name, cfg_value, num_samples
                            )
                            st.rerun()

                        st.image(display_image, caption=name, use_container_width=True)
                st.markdown("</div>", unsafe_allow_html=True)

    with col2:
        st.markdown("### 🌍 Predicted Locations")

        if "predictions" in st.session_state:
            pred = st.session_state["predictions"]

            # Create DataFrame for all predictions
            df = pd.DataFrame(
                {
                    "lat": pred["lat"],
                    "lon": pred["lon"],
                    "type": ["Sample"] * len(pred["lat"]),
                }
            )

            # Add high-confidence prediction
            df = pd.concat(
                [
                    df,
                    pd.DataFrame(
                        {
                            "lat": [pred["high_conf_lat"]],
                            "lon": [pred["high_conf_lon"]],
                            "type": ["Best Guess"],
                        }
                    ),
                ]
            )

            # Create a more interactive map using Plotly
            fig = px.scatter_mapbox(
                df,
                lat="lat",
                lon="lon",
                zoom=2,
                opacity=0.6,
                color="type",
                color_discrete_map={"Sample": "blue", "Best Guess": "red"},
                mapbox_style="carto-positron",
            )

            # Enlarge only the "Best Guess" marker so it stands out.
            fig.update_traces(selector=dict(name="Best Guess"), marker_size=15)

            fig.update_layout(
                margin={"r": 0, "t": 0, "l": 0, "b": 0},
                height=500,
                showlegend=True,
                legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
            )

            # Display map in a container
            with st.container():
                st.plotly_chart(fig, use_container_width=True)

            # Display stats in a styled container
            with st.container():
                st.markdown(
                    f"""
                    <div class="prediction-box">
                        <h4>πŸ“Š Prediction Statistics</h4>
                        <p>Number of sampled locations: {len(pred["lat"])}</p>
                        <p>Best guess location: {pred["high_conf_lat"]:.2f}Β°, {pred["high_conf_lon"]:.2f}Β°</p>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )
        else:
            # Empty state with better styling
            st.markdown(
                """
                <div class="prediction-box" style="text-align: center;">
                    <h4>πŸ‘† Upload an image and click 'Predict Location'</h4>
                    <p>The predicted locations will appear here on an interactive map.</p>
                </div>
                """,
                unsafe_allow_html=True,
            )


# Call main() only when this module is run as the entry script.
if __name__ == "__main__":
    main()