Spaces:
Runtime error
Runtime error
Bastien Dechamps
commited on
Commit
•
bd2d69d
1
Parent(s):
077dc3f
debug
Browse files
app.py
CHANGED
@@ -2,14 +2,25 @@ import numpy as np
|
|
2 |
import gradio as gr
|
3 |
import plotly.graph_objects as go
|
4 |
|
5 |
-
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr
|
|
|
|
|
6 |
|
7 |
ALL_GUESSR_CLASS = {
|
8 |
"random": RandomGuessr,
|
|
|
9 |
}
|
10 |
|
11 |
ALL_GUESSR_ARGS = {
|
12 |
-
"random": {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
}
|
14 |
|
15 |
# For instantiating guessrs only when needed
|
|
|
2 |
import gradio as gr
|
3 |
import plotly.graph_objects as go
|
4 |
|
5 |
+
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, GlobalEmbedderGuessr
|
6 |
+
from geoguessr_bot.retriever import DinoV2Embedder, Retriever
|
7 |
+
|
8 |
|
9 |
ALL_GUESSR_CLASS = {
|
10 |
"random": RandomGuessr,
|
11 |
+
"globalEmbedder": GlobalEmbedderGuessr
|
12 |
}
|
13 |
|
14 |
ALL_GUESSR_ARGS = {
|
15 |
+
"random": {},
|
16 |
+
"globalEmbedder": {
|
17 |
+
"embedder": DinoV2Embedder(
|
18 |
+
device="cpu"
|
19 |
+
),
|
20 |
+
"retriever": Retriever(
|
21 |
+
embeddings_path="/home/bastiendechamps/geoguessr-bot/data/samples_embedded.npy"
|
22 |
+
)
|
23 |
+
}
|
24 |
}
|
25 |
|
26 |
# For instantiating guessrs only when needed
|
configs/embed_folder.yml
CHANGED
@@ -4,4 +4,4 @@ command:
|
|
4 |
(): geoguessr_bot.retriever.DinoV2Embedder
|
5 |
device: "cpu"
|
6 |
images_folder: !path "../data/samples"
|
7 |
-
output_path: !path "../data/samples_embedded"
|
|
|
4 |
(): geoguessr_bot.retriever.DinoV2Embedder
|
5 |
device: "cpu"
|
6 |
images_folder: !path "../data/samples"
|
7 |
+
output_path: !path "../data/samples_embedded.npy"
|
geoguessr_bot/guessr/global_embedder_guessr.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
from dataclasses import dataclass
|
|
|
|
|
|
|
2 |
from geoguessr_bot.guessr import AbstractGuessr
|
3 |
from geoguessr_bot.interfaces import Coordinate
|
4 |
from geoguessr_bot.retriever import AbstractImageEmbedder
|
|
|
1 |
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
from geoguessr_bot.guessr import AbstractGuessr
|
6 |
from geoguessr_bot.interfaces import Coordinate
|
7 |
from geoguessr_bot.retriever import AbstractImageEmbedder
|
geoguessr_bot/retriever/abstract_embedder.py
CHANGED
@@ -18,9 +18,9 @@ class AbstractImageEmbedder:
|
|
18 |
"""
|
19 |
assert output_path.endswith(".npy"), "`output_path` must end with .npy"
|
20 |
embeddings = {}
|
21 |
-
for
|
22 |
-
image_path = os.path.join(folder_path,
|
23 |
image = Image.open(image_path)
|
24 |
embedding = self.embed(image)
|
25 |
-
embeddings[
|
26 |
np.save(output_path, embeddings)
|
|
|
18 |
"""
|
19 |
assert output_path.endswith(".npy"), "`output_path` must end with .npy"
|
20 |
embeddings = {}
|
21 |
+
for name in tqdm(os.listdir(folder_path)):
|
22 |
+
image_path = os.path.join(folder_path, name)
|
23 |
image = Image.open(image_path)
|
24 |
embedding = self.embed(image)
|
25 |
+
embeddings[name] = embedding
|
26 |
np.save(output_path, embeddings)
|