|
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
import time

import numpy as np
import requests
import torch

from clip_app_client import ClipAppClient
from clip_retrieval.clip_client import ClipClient, Modality

clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
map_clip_to_clip_retrieval = {
    "ViT-L/14": "laion5B-L-14",
}
|
|
def safe_url(url):
    import urllib.parse
    # Percent-encode everything except the scheme and path separators.
    url = urllib.parse.quote(url, safe=':/')

    # Trim anything that follows the first ".jpg" (query strings, fragments, ...).
    if '.jpg' in url:
        url = url.split('.jpg')[0] + '.jpg'
    return url
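# Example (hypothetical URL, shown for illustration):
#   safe_url("https://example.com/a b/img.jpg?size=large")
#   -> "https://example.com/a%20b/img.jpg"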
|
|
|
def _safe_image_url_to_embedding(url, safe_return):
    # Return safe_return (e.g. a zero vector) when the image cannot be fetched or embedded.
    try:
        return app_client.image_url_to_embedding(url)
    except Exception:
        return safe_return
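# Typical call (mirrors the usage further below): a zero vector of the right
# shape makes failed downloads easy to filter out later, since its dot product
# with anything is 0.
#   emb = _safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0)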
|
|
|
def mean_template(embeddings):
    # Average all embeddings into a single template of shape (1, dim).
    template = torch.mean(embeddings, dim=0, keepdim=True)
    return template
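# Example: mean_template(torch.randn(10, 768)).shape -> torch.Size([1, 768])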
|
|
|
def principal_component_analysis_template(embeddings):
    # Use the first principal direction of the centred embeddings as the template.
    mean = torch.mean(embeddings, dim=0)
    embeddings_centered = embeddings - mean
    u, s, v = torch.svd(embeddings_centered)
    # The principal directions live in the columns of v (embedding space),
    # not u (sample space), so the template is v[:, 0] with shape (dim,).
    template = v[:, 0]
    return template
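# Note: the sign of a singular vector is arbitrary, so callers may want to flip
# the returned template (negate it) if its dot product with the mean embedding
# is negative.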
|
|
|
def clustering_templates(embeddings, n_clusters=5):
    from sklearn.cluster import KMeans

    # Cluster the embeddings with k-means and return one mean vector per cluster.
    kmeans = KMeans(n_clusters=n_clusters)
    # detach/cpu so this also works for GPU or autograd tensors
    embeddings_np = embeddings.detach().cpu().numpy()
    clusters = kmeans.fit_predict(embeddings_np)

    templates = []
    for cluster in np.unique(clusters):
        cluster_mean = np.mean(embeddings_np[clusters == cluster], axis=0)
        templates.append(torch.from_numpy(cluster_mean))
    return templates
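# Sketch of standalone usage (random data, assumed 768-dim embeddings): rank the
# cluster templates against a query embedding by cosine similarity.
#   embs = torch.randn(100, 768)
#   query = torch.randn(768)
#   templates = clustering_templates(embs, n_clusters=5)
#   sims = torch.stack([torch.nn.functional.cosine_similarity(t, query, dim=0)
#                       for t in templates])
#   best = templates[int(sims.argmax())]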
|
|
|
|
|
test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-001.jpeg")

app_client = ClipAppClient()
clip_retrieval_client = ClipClient(
    url=clip_retrieval_service_url,
    indice_name=map_clip_to_clip_retrieval[app_client.clip_model],
)
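# Note: depending on the clip-retrieval version, ClipClient accepts further
# keyword arguments (e.g. num_images or modality=Modality.IMAGE); the defaults
# are used here.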
|
preprocessed_image = app_client.preprocess_image(test_image_path)
preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)

print(f"embeddings: {preprocessed_image_embeddings.shape}")
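# Iterative template refinement: query the LAION KNN service with the current
# template, embed the returned captions and images, cluster the combined
# embeddings, and push the template away from the less relevant cluster centres
# before searching again.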
|
|
|
|
|
template = preprocessed_image_embeddings
for step_num in range(3):
    print(f"\n\n---- Step {step_num} ----")

    # Query the KNN index with the current template embedding.
    embedding_as_list = template[0].tolist()
    results = clip_retrieval_client.query(embedding_input=embedding_as_list)
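    # Each result is expected (by this script) to provide 'caption', 'url' and
    # 'similarity' fields.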
|
|
|
|
|
    # Embed the returned captions and rank them by similarity to the query image.
    image_labels = [r['caption'] for r in results]
    image_label_vectors = [app_client.text_to_embedding(label) for label in image_labels]
    image_label_vectors = torch.cat(image_label_vectors, dim=0)
    dot_product = torch.mm(image_label_vectors, preprocessed_image_embeddings.T)
    similarity_image_label = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
    similarity_image_label.sort(reverse=True)
    for similarity, image_label in similarity_image_label:
        print(f"{similarity} {image_label}")
|
|
|
|
|
    # Embed the returned images; failed downloads fall back to a zero vector.
    image_urls = [safe_url(r['url']) for r in results]
    image_vectors = [_safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0) for url in image_urls]
    image_vectors = torch.cat(image_vectors, dim=0)
    dot_product = torch.mm(image_vectors, preprocessed_image_embeddings.T)
    similarity_image = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
    similarity_image.sort(reverse=True)
    for similarity, image_label in similarity_image:
        print(f"{similarity} {image_label}")

    # Drop the zero-vector placeholders (failed downloads) before clustering.
    # Filter on the unsorted dot products so indices still line up with image_vectors.
    image_vectors = torch.stack([image_vectors[i] for i in range(len(image_vectors)) if dot_product[i][0] > 0.001], dim=0)
|
|
|
|
|
    print("creating templates using clustering")
    # Cluster caption and image embeddings together, then rank the cluster
    # centres by their similarity to the query image.
    merged_embeddings = torch.cat([image_label_vectors, image_vectors], dim=0)
    clusters = clustering_templates(merged_embeddings, n_clusters=5)
    clusters = torch.stack(clusters, dim=0)
    dot_product = torch.mm(clusters, preprocessed_image_embeddings.T)
    cluster_similarity = [(float("{:.4f}".format(dot_product[i][0])), i) for i in range(len(clusters))]
    cluster_similarity.sort(reverse=True)
    for similarity, idx in cluster_similarity:
        print(f"{similarity} {idx}")
|
|
|
|
|
    # Build the next template: start from (n-1) copies of the image embedding
    # and subtract every cluster centre except the most similar one, pushing
    # the template away from the less relevant clusters.
    template = preprocessed_image_embeddings * (len(clusters) - 1)
    for i in range(1, len(clusters)):
        template -= clusters[cluster_similarity[i][1]]

    print("---")
    print("searching based on template")
    results = clip_retrieval_client.query(embedding_input=template[0].tolist())
    hints = ""
    for result in results:
        url = safe_url(result["url"])
        similarity = float("{:.4f}".format(result["similarity"]))
        title = result["caption"]
        print(f"{similarity} \"{title}\" {url}")
        if len(hints) > 0:
            hints += f", \"{title}\""
        else:
            hints += f"\"{title}\""
    print(hints)
|