Bastien Dechamps commited on
Commit
1791df2
β€’
1 Parent(s): 9ed0050

[ADD] Average embedder

Browse files
app.py CHANGED
@@ -4,18 +4,20 @@ import numpy as np
4
  import gradio as gr
5
  import plotly.graph_objects as go
6
 
7
- from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, GlobalEmbedderGuessr
 
8
  from geoguessr_bot.retriever import DinoV2Embedder, Retriever
9
 
10
 
11
  ALL_GUESSR_CLASS = {
12
  "random": RandomGuessr,
13
- "globalEmbedder": GlobalEmbedderGuessr
 
14
  }
15
 
16
  ALL_GUESSR_ARGS = {
17
  "random": {},
18
- "globalEmbedder": {
19
  "embedder": DinoV2Embedder(
20
  device="cpu"
21
  ),
@@ -25,6 +27,19 @@ ALL_GUESSR_ARGS = {
25
  ),
26
  "metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
27
  "resources/metadatav3.csv"),
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  }
29
  }
30
 
 
4
  import gradio as gr
5
  import plotly.graph_objects as go
6
 
7
+ from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
8
+ AverageNeighborsEmbedderGuessr
9
  from geoguessr_bot.retriever import DinoV2Embedder, Retriever
10
 
11
 
12
  ALL_GUESSR_CLASS = {
13
  "random": RandomGuessr,
14
+ "nearestNeighborEmbedder": NearestNeighborEmbedderGuessr,
15
+ "averageNeighborsEmbedder": AverageNeighborsEmbedderGuessr,
16
  }
17
 
18
  ALL_GUESSR_ARGS = {
19
  "random": {},
20
+ "nearestNeighborEmbedder": {
21
  "embedder": DinoV2Embedder(
22
  device="cpu"
23
  ),
 
27
  ),
28
  "metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
29
  "resources/metadatav3.csv"),
30
+ },
31
+ "averageNeighborsEmbedder": {
32
+ "embedder": DinoV2Embedder(
33
+ device="cpu"
34
+ ),
35
+ "retriever": Retriever(
36
+ embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
37
+ "resources/embeddings.npy"),
38
+ ),
39
+ "metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
40
+ "resources/metadatav3.csv"),
41
+ "n_neighbors": 2000,
42
+ "dbscan_eps": 0.5
43
  }
44
  }
45
 
geoguessr_bot/guessr/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .abstract_guessr import AbstractGuessr
2
  from .random_guessr import RandomGuessr
3
- from .global_embedder_guessr import GlobalEmbedderGuessr
 
 
1
  from .abstract_guessr import AbstractGuessr
2
  from .random_guessr import RandomGuessr
3
+ from .nearest_neighbor_embedder_guessr import NearestNeighborEmbedderGuessr
4
+ from .average_neighbor_embedder_guessr import AverageNeighborsEmbedderGuessr
geoguessr_bot/guessr/abstract_guessr.py CHANGED
@@ -25,7 +25,7 @@ class AbstractGuessr:
25
  """Create an interactive map showing a coordinate
26
  """
27
  fig = go.Figure(go.Scattermapbox(
28
- customdata=[guess_coordinate.__str__()] if guess_coordinate is not None else None,
29
  lat=[guess_coordinate.latitude] if guess_coordinate is not None else None,
30
  lon=[guess_coordinate.longitude] if guess_coordinate is not None else None,
31
  mode="markers",
 
25
  """Create an interactive map showing a coordinate
26
  """
27
  fig = go.Figure(go.Scattermapbox(
28
+ customdata=[str(guess_coordinate)] if guess_coordinate is not None else None,
29
  lat=[guess_coordinate.latitude] if guess_coordinate is not None else None,
30
  lon=[guess_coordinate.longitude] if guess_coordinate is not None else None,
31
  mode="markers",
geoguessr_bot/guessr/average_neighbor_embedder_guessr.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ from sklearn.cluster import DBSCAN
6
+ from sklearn.metrics.pairwise import haversine_distances
7
+ from PIL import Image
8
+ import pandas as pd
9
+
10
+ from geoguessr_bot.guessr import AbstractGuessr
11
+ from geoguessr_bot.interfaces import Coordinate
12
+ from geoguessr_bot.retriever import AbstractImageEmbedder
13
+ from geoguessr_bot.retriever import Retriever
14
+
15
+
16
+ @dataclass
17
+ class AverageNeighborsEmbedderGuessr(AbstractGuessr):
18
+ """Guesses a coordinate using an Embedder and a retriever followed by NN.
19
+ """
20
+ embedder: AbstractImageEmbedder
21
+ retriever: Retriever
22
+ metadata_path: str
23
+ n_neighbors: int = 1000
24
+ dbscan_eps: float = 0.05
25
+
26
+ def __post_init__(self):
27
+ """Load metadata
28
+ """
29
+ metadata = pd.read_csv(self.metadata_path)
30
+ self.image_to_coordinate = {
31
+ image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
32
+ for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
33
+ }
34
+ # DBSCAN will be used to take the centroid of the biggest cluster among the N neighbors, using Haversine
35
+ self.dbscan = DBSCAN(eps=self.dbscan_eps, metric=haversine_distances)
36
+
37
+ def guess(self, image: Image) -> Coordinate:
38
+ """Guess a coordinate from an image
39
+ """
40
+ # Embed image
41
+ image = Image.fromarray(image)
42
+ image_embedding = self.embedder.embed(image)[None, :]
43
+
44
+ # Retrieve nearest neighbors
45
+ nearest_neighbors, distances = self.retriever.retrieve(image_embedding, self.n_neighbors)
46
+ nearest_neighbors = nearest_neighbors[0]
47
+ distances = distances[0]
48
+
49
+ # Get coordinates of neighbors
50
+ neighbors_coordinates = [self.image_to_coordinate[nn].to_radians() for nn in nearest_neighbors]
51
+ neighbors_coordinates = np.array([[nn.latitude, nn.longitude] for nn in neighbors_coordinates])
52
+
53
+ # Use DBSCAN to find the biggest cluster and potentially remove outliers
54
+ clustering = self.dbscan.fit(neighbors_coordinates)
55
+ labels = clustering.labels_
56
+ biggest_cluster = max(Counter(labels))
57
+ neighbors_coordinates = neighbors_coordinates[labels == biggest_cluster]
58
+ distances = distances[labels == biggest_cluster]
59
+
60
+ # Guess coordinate as the closest image among the cluster regarding retrieving distance
61
+ guess_coordinate = neighbors_coordinates[np.argmin(distances)]
62
+ guess_coordinate = Coordinate.from_radians(guess_coordinate[0], guess_coordinate[1])
63
+ return guess_coordinate
64
+
geoguessr_bot/guessr/{global_embedder_guessr.py β†’ nearest_neighbor_embedder_guessr.py} RENAMED
@@ -10,10 +10,9 @@ from geoguessr_bot.retriever import Retriever
10
 
11
 
12
  @dataclass
13
- class GlobalEmbedderGuessr(AbstractGuessr):
14
- """Guesses a coordinate using an Embedder and a retriever
15
  """
16
-
17
  embedder: AbstractImageEmbedder
18
  retriever: Retriever
19
  metadata_path: str
@@ -35,9 +34,9 @@ class GlobalEmbedderGuessr(AbstractGuessr):
35
  image = Image.fromarray(image)
36
  image_embedding = self.embedder.embed(image)[None, :]
37
 
38
- # Retrieve nearest neighbors
39
  nearest_neighbors = self.retriever.retrieve(image_embedding)
40
- nearest_neighbor = nearest_neighbors[0][0]
41
 
42
  # Guess coordinate
43
  guess_coordinate = self.image_to_coordinate[nearest_neighbor]
 
10
 
11
 
12
  @dataclass
13
+ class NearestNeighborEmbedderGuessr(AbstractGuessr):
14
+ """Guesses a coordinate using an Embedder and a retriever followed by NN.
15
  """
 
16
  embedder: AbstractImageEmbedder
17
  retriever: Retriever
18
  metadata_path: str
 
34
  image = Image.fromarray(image)
35
  image_embedding = self.embedder.embed(image)[None, :]
36
 
37
+ # Retrieve nearest neighbor
38
  nearest_neighbors = self.retriever.retrieve(image_embedding)
39
+ nearest_neighbor = nearest_neighbors[0][0][0]
40
 
41
  # Guess coordinate
42
  guess_coordinate = self.image_to_coordinate[nearest_neighbor]
geoguessr_bot/interfaces.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from pydantic.main import BaseModel
2
 
3
 
@@ -7,3 +8,16 @@ class Coordinate(BaseModel):
7
 
8
  def __str__(self):
9
  return f"({round(self.latitude, 6)}, {round(self.longitude, 6)})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
  from pydantic.main import BaseModel
3
 
4
 
 
8
 
9
  def __str__(self):
10
  return f"({round(self.latitude, 6)}, {round(self.longitude, 6)})"
11
+
12
+ def to_radians(self) -> 'Coordinate':
13
+ return Coordinate(
14
+ latitude=self.latitude * np.pi / 180.,
15
+ longitude=self.longitude * np.pi / 180.
16
+ )
17
+
18
+ @staticmethod
19
+ def from_radians(latitude: float, longitude: float) -> 'Coordinate':
20
+ return Coordinate(
21
+ latitude=latitude * 180. / np.pi,
22
+ longitude=longitude * 180. / np.pi
23
+ )
geoguessr_bot/retriever/retriever.py CHANGED
@@ -1,13 +1,12 @@
1
- from typing import Dict, List
2
 
3
  import numpy as np
4
  import faiss
5
 
6
 
7
  class Retriever:
8
- def __init__(self, embeddings_path: str, n_neighbors: int = 5):
9
  self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
10
- self.n_neighbors = n_neighbors
11
 
12
  # Keep track of image names
13
  self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
@@ -25,8 +24,8 @@ class Retriever:
25
  """
26
  return np.load(embeddings_path, allow_pickle=True).item()
27
 
28
- def retrieve(self, queries: np.ndarray) -> List[List[str]]:
29
  """Retrieve nearest neighbors indexes from queries
30
  """
31
- _, indexes = self.index.search(queries, self.n_neighbors)
32
- return [[self.index_to_image[i] for i in index] for index in indexes]
 
1
+ from typing import Dict, List, Tuple
2
 
3
  import numpy as np
4
  import faiss
5
 
6
 
7
  class Retriever:
8
+ def __init__(self, embeddings_path: str):
9
  self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
 
10
 
11
  # Keep track of image names
12
  self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
 
24
  """
25
  return np.load(embeddings_path, allow_pickle=True).item()
26
 
27
+ def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> Tuple[List[List[str]], List[List[float]]]:
28
  """Retrieve nearest neighbors indexes from queries
29
  """
30
+ distances, indexes = self.index.search(queries, n_neighbors)
31
+ return [[self.index_to_image[i] for i in index] for index in indexes], distances
requirements.txt CHANGED
@@ -12,3 +12,4 @@ torchvision
12
  tqdm
13
  configue
14
  fire
 
 
12
  tqdm
13
  configue
14
  fire
15
+ scikit-learn