Bastien Dechamps commited on
Commit
bd2d69d
1 Parent(s): 077dc3f
app.py CHANGED
@@ -2,14 +2,25 @@ import numpy as np
2
  import gradio as gr
3
  import plotly.graph_objects as go
4
 
5
- from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr
 
 
6
 
7
  ALL_GUESSR_CLASS = {
8
  "random": RandomGuessr,
 
9
  }
10
 
11
  ALL_GUESSR_ARGS = {
12
- "random": {}
 
 
 
 
 
 
 
 
13
  }
14
 
15
  # For instantiating guessrs only when needed
 
2
  import gradio as gr
3
  import plotly.graph_objects as go
4
 
5
+ from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, GlobalEmbedderGuessr
6
+ from geoguessr_bot.retriever import DinoV2Embedder, Retriever
7
+
8
 
9
  ALL_GUESSR_CLASS = {
10
  "random": RandomGuessr,
11
+ "globalEmbedder": GlobalEmbedderGuessr
12
  }
13
 
14
  ALL_GUESSR_ARGS = {
15
+ "random": {},
16
+ "globalEmbedder": {
17
+ "embedder": DinoV2Embedder(
18
+ device="cpu"
19
+ ),
20
+ "retriever": Retriever(
21
+ embeddings_path="/home/bastiendechamps/geoguessr-bot/data/samples_embedded.npy"
22
+ )
23
+ }
24
  }
25
 
26
  # For instantiating guessrs only when needed
configs/embed_folder.yml CHANGED
@@ -4,4 +4,4 @@ command:
4
  (): geoguessr_bot.retriever.DinoV2Embedder
5
  device: "cpu"
6
  images_folder: !path "../data/samples"
7
- output_path: !path "../data/samples_embedded"
 
4
  (): geoguessr_bot.retriever.DinoV2Embedder
5
  device: "cpu"
6
  images_folder: !path "../data/samples"
7
+ output_path: !path "../data/samples_embedded.npy"
geoguessr_bot/guessr/global_embedder_guessr.py CHANGED
@@ -1,4 +1,7 @@
1
  from dataclasses import dataclass
 
 
 
2
  from geoguessr_bot.guessr import AbstractGuessr
3
  from geoguessr_bot.interfaces import Coordinate
4
  from geoguessr_bot.retriever import AbstractImageEmbedder
 
1
  from dataclasses import dataclass
2
+
3
+ from PIL import Image
4
+
5
  from geoguessr_bot.guessr import AbstractGuessr
6
  from geoguessr_bot.interfaces import Coordinate
7
  from geoguessr_bot.retriever import AbstractImageEmbedder
geoguessr_bot/retriever/abstract_embedder.py CHANGED
@@ -18,9 +18,9 @@ class AbstractImageEmbedder:
18
  """
19
  assert output_path.endswith(".npy"), "`output_path` must end with .npy"
20
  embeddings = {}
21
- for image in tqdm(os.listdir(folder_path)):
22
- image_path = os.path.join(folder_path, image)
23
  image = Image.open(image_path)
24
  embedding = self.embed(image)
25
- embeddings[image] = embedding
26
  np.save(output_path, embeddings)
 
18
  """
19
  assert output_path.endswith(".npy"), "`output_path` must end with .npy"
20
  embeddings = {}
21
+ for name in tqdm(os.listdir(folder_path)):
22
+ image_path = os.path.join(folder_path, name)
23
  image = Image.open(image_path)
24
  embedding = self.embed(image)
25
+ embeddings[name] = embedding
26
  np.save(output_path, embeddings)