Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision.transforms as T
|
6 |
+
import pandas as pd
|
7 |
+
from sklearn.neighbors import NearestNeighbors
|
8 |
+
from sklearn.cluster import DBSCAN
|
9 |
+
from shapely.geometry import Point
|
10 |
+
import geopandas as gpd
|
11 |
+
from geopandas import GeoDataFrame
|
12 |
+
|
13 |
+
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to("cuda")
|
14 |
+
model.eval()
|
15 |
+
|
16 |
+
metadata = pd.read_csv("data/streetview_v3/metadatav3.csv")
|
17 |
+
metadata.path = metadata.path.apply(lambda x: x.split("/")[-1])
|
18 |
+
|
19 |
+
PATH = "data/streetview_v3/images/"
|
20 |
+
PATH_TEST = "data/test-competition/images/images/"
|
21 |
+
|
22 |
+
embeddings = np.load("data/embeddings.npy")
|
23 |
+
test_embeddings = np.load("data/test_embeddings.npy")
|
24 |
+
files = open("data/files.txt").read().split("\n")
|
25 |
+
test_files = open("data/test_files.txt").read().split("\n")
|
26 |
+
print(embeddings.shape, test_embeddings.shape, len(files), len(test_files))
|
27 |
+
|
28 |
+
knn = NearestNeighbors(n_neighbors=50, algorithm='kd_tree', n_jobs=8)
|
29 |
+
knn.fit(embeddings)
|
30 |
+
|
31 |
+
# %%
|
32 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
33 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
34 |
+
|
35 |
+
transform = T.Compose([
|
36 |
+
T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
|
37 |
+
T.CenterCrop(224),
|
38 |
+
T.ToTensor(),
|
39 |
+
T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
40 |
+
])
|
41 |
+
|
42 |
+
|
43 |
+
def cluster(df, eps=0.1, min_samples=5, metric="cosine", n_jobs=8, show=False):
|
44 |
+
if len(df) == 1:
|
45 |
+
return df
|
46 |
+
dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
|
47 |
+
dbscan.fit(df[["longitude", "latitude"]])
|
48 |
+
df["cluster"] = dbscan.labels_
|
49 |
+
# Return centroid of the cluster with the most points
|
50 |
+
df = df[df.cluster == df.cluster.value_counts().index[0]]
|
51 |
+
df = df.groupby("cluster").apply(lambda x: x[["longitude", "latitude"]].median()).reset_index()
|
52 |
+
# Return coordinates of the cluster with the most points
|
53 |
+
return df.longitude.iloc[0], df.latitude.iloc[0]
|
54 |
+
|
55 |
+
|
56 |
+
def guess_image(img):
|
57 |
+
# img = Image.open(image_path)
|
58 |
+
# cast as rgb
|
59 |
+
img = img.convert('RGB')
|
60 |
+
print(img)
|
61 |
+
with torch.no_grad():
|
62 |
+
features = model(transform(img).to("cuda").unsqueeze(0))[0].cpu()
|
63 |
+
distances, neighbors = knn.kneighbors(features.unsqueeze(0))
|
64 |
+
|
65 |
+
neighbors = neighbors[0]
|
66 |
+
# Return metadata df rows with neighbors
|
67 |
+
df = pd.DataFrame()
|
68 |
+
for n in neighbors:
|
69 |
+
df = pd.concat([df, metadata[metadata.path == files[n]]])
|
70 |
+
coords = cluster(df, eps=0.005, min_samples=5)
|
71 |
+
|
72 |
+
geometry = [Point(xy) for xy in zip(df['longitude'], df['latitude'])]
|
73 |
+
gdf = GeoDataFrame(df, geometry=geometry)
|
74 |
+
gdf_guess = GeoDataFrame(df[:1], geometry=[Point(coords)])
|
75 |
+
# this is a simple map that goes with geopandas
|
76 |
+
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
|
77 |
+
plot_ = world.plot(figsize=(10, 6))
|
78 |
+
gdf.plot(ax=plot_, marker='o', color='red', markersize=15)
|
79 |
+
gdf_guess.plot(ax=plot_, marker='o', color='blue', markersize=15);
|
80 |
+
return coords, plot_.figure
|
81 |
+
|
82 |
+
|
83 |
+
# Image to image translation
|
84 |
+
def translate_image(input_image):
|
85 |
+
coords, fig = guess_image(Image.fromarray(input_image.astype('uint8'), 'RGB'))
|
86 |
+
fig.savefig("tmp.png")
|
87 |
+
return str(coords), np.array(Image.open("tmp.png").convert("RGB"))
|
88 |
+
|
89 |
+
|
90 |
+
demo = gr.Interface(fn=translate_image, inputs="image", outputs=["text", "image"], title="Street View Location")
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
demo.launch()
|