Spaces:
Runtime error
Runtime error
| import os.path | |
| import numpy as np | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \ | |
| AverageNeighborsEmbedderGuessr | |
| from geoguessr_bot.retriever import DinoV2Embedder, Retriever | |
# Registry of the selectable guessr implementations, keyed by the name used
# to look them up (the same keys index ALL_GUESSR_ARGS).
ALL_GUESSR_CLASS = {
    "random": RandomGuessr,
    "nearestNeighborEmbedder": NearestNeighborEmbedderGuessr,
    "averageNeighborsEmbedder": AverageNeighborsEmbedderGuessr,
}
def _embedder_guessr_args(**extra_args):
    """Build the constructor kwargs common to the embedder-based guessrs.

    Each call creates its own ``DinoV2Embedder`` (on CPU) and its own
    ``Retriever`` reading ``resources/embeddings.npy`` next to this file, and
    adds the path of ``resources/metadatav3.csv``. Any ``extra_args`` are
    appended after those shared entries.
    """
    app_dir = os.path.dirname(os.path.abspath(__file__))
    return {
        "embedder": DinoV2Embedder(device="cpu"),
        "retriever": Retriever(
            embeddings_path=os.path.join(app_dir, "resources/embeddings.npy"),
        ),
        "metadata_path": os.path.join(app_dir, "resources/metadatav3.csv"),
        **extra_args,
    }


# Constructor keyword arguments for each guessr, keyed like ALL_GUESSR_CLASS.
# NOTE(review): these are built eagerly at import time, so two embedders and
# two retrievers are created even if only one guessr is ever selected.
ALL_GUESSR_ARGS = {
    "random": {},
    "nearestNeighborEmbedder": _embedder_guessr_args(),
    "averageNeighborsEmbedder": _embedder_guessr_args(
        n_neighbors=2000,
        dbscan_eps=0.5,
    ),
}
# Cache of guessr instances, keyed by guessr name. It starts empty and is
# filled on first use so that a guessr is only instantiated when needed.
ALL_GUESSR = {}
def create_map(guessr: str) -> go.Figure:
    """Return the interactive map shown before any guess has been made.

    As a side effect, the guessr named ``guessr`` is instantiated and stored
    in ``ALL_GUESSR`` if it is not there yet.

    Args:
        guessr: Key of the guessr in ``ALL_GUESSR_CLASS`` / ``ALL_GUESSR_ARGS``.

    Returns:
        The figure produced by ``AbstractGuessr.create_map()`` called without
        a coordinate.

    Raises:
        KeyError: If ``guessr`` is not a registered guessr name.
    """
    if guessr not in ALL_GUESSR:
        # First request for this guessr: build it once and keep it around.
        guessr_class = ALL_GUESSR_CLASS[guessr]
        constructor_kwargs = ALL_GUESSR_ARGS[guessr]
        ALL_GUESSR[guessr] = guessr_class(**constructor_kwargs)
    initial_map = AbstractGuessr.create_map()
    return initial_map
def guess(guessr: str, uploaded_image) -> go.Figure:
    """Guess a coordinate from an image uploaded in the Gradio interface.

    Args:
        guessr: Key of the guessr in ``ALL_GUESSR_CLASS`` / ``ALL_GUESSR_ARGS``.
        uploaded_image: Image received from the Gradio image component; it is
            converted with ``np.array`` before being handed to the guessr.
            ``None`` means no image was provided.

    Returns:
        The map figure built by the guessr for the guessed coordinate.

    Raises:
        gr.Error: If no image was uploaded.
        KeyError: If ``guessr`` is not a registered guessr name.
    """
    # Without this check, np.array(None) would yield a 0-d object array and
    # the guessr would fail later with an error unrelated to the real cause.
    if uploaded_image is None:
        raise gr.Error("Please upload an image before guessing.")

    # Instantiate guessr if not already done
    if guessr not in ALL_GUESSR:
        ALL_GUESSR[guessr] = ALL_GUESSR_CLASS[guessr](**ALL_GUESSR_ARGS[guessr])
    selected_guessr = ALL_GUESSR[guessr]

    # Convert image to numpy array
    image_array = np.array(uploaded_image)

    # Guess coordinate, then draw it on the map
    guess_coordinate = selected_guessr.guess(image_array)
    return selected_guessr.create_map(guess_coordinate)
if __name__ == "__main__":
    # Create & launch Gradio interface
    # NOTE(review): the widget nesting below was reconstructed (controls in a
    # column, map beside it in the same row) -- confirm against the intended
    # layout.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                guessr_names = list(ALL_GUESSR_CLASS)
                guessr_dropdown = gr.Dropdown(
                    guessr_names,
                    # The default must be one of the registered names: it is
                    # passed to create_map() on page load, which looks it up
                    # in ALL_GUESSR_CLASS ("globalEmbedder" is not a key there).
                    value=guessr_names[0],
                    label="Guessr type",
                    info="More Guessr types will be added soon!"
                )
                image = gr.Image(shape=(800, 800))
                # The button label is the `value` argument of gr.Button.
                button = gr.Button("Guess")
            interactive_map = gr.Plot()
        # Show the initial map on page load, then update it on every guess.
        demo.load(create_map, [guessr_dropdown], interactive_map)
        button.click(guess, [guessr_dropdown, image], interactive_map)
    # Launch demo
    demo.launch()