tweak how i normalize
app.py CHANGED
@@ -37,7 +37,7 @@ def image_to_embedding(input_im):
     prepro = preprocess(input_im).unsqueeze(0).to(device)
     with torch.no_grad():
         image_embeddings = model.encode_image(prepro)
-
+        image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
     image_embeddings_np = image_embeddings.cpu().to(torch.float32).detach().numpy()
     return image_embeddings_np
 
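The added line L2-normalizes the image embedding along its feature dimension, so every vector returned by image_to_embedding has unit length. A minimal standalone sketch of the same operation (the batch shape and 512-dim width are illustrative assumptions, not the app's actual model output):

import torch

# Stand-in for model.encode_image output: batch of 1, 512 dims (assumed shape).
image_embeddings = torch.randn(1, 512)

# The commit's normalization: divide by the L2 norm along the last dimension.
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

print(image_embeddings.norm(dim=-1))  # tensor([1.0000]) -- unit length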
@@ -45,7 +45,7 @@ def prompt_to_embedding(prompt):
     text = tokenizer([prompt]).to(device)
     with torch.no_grad():
         prompt_embededdings = model.encode_text(text)
-
+        prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
     prompt_embededdings_np = prompt_embededdings.cpu().to(torch.float32).detach().numpy()
     return prompt_embededdings_np
 
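With both the image and text paths normalized identically, comparing the two embeddings reduces to a dot product, which equals cosine similarity for unit vectors. A hedged usage sketch (the shapes and the comparison itself are assumptions about how the app consumes these outputs, not code from the commit):

import numpy as np

# Pretend outputs of image_to_embedding / prompt_to_embedding: (1, 512) unit rows.
img = np.random.randn(1, 512).astype(np.float32)
img /= np.linalg.norm(img, axis=-1, keepdims=True)
txt = np.random.randn(1, 512).astype(np.float32)
txt /= np.linalg.norm(txt, axis=-1, keepdims=True)

# For unit vectors the dot product already is the cosine similarity.
similarity = float(img @ txt.T)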
@@ -90,12 +90,13 @@ def main(
         # dowload image
         import requests
         from io import BytesIO
-        response = requests.get(result["url"])
-        if not response.ok:
-            continue
         try:
+            response = requests.get(result["url"])
+            if not response.ok:
+                continue
             bytes = BytesIO(response.content)
             image = Image.open(bytes)
+            image.title = str(result["similarity"]) + ' ' + result["caption"]
             images.append(image)
         except Exception as e:
             print(e)
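Moving requests.get inside the try matters because the request itself can raise (timeouts, DNS failures), not just the image decode; the new line also tags each PIL image with its similarity score and caption. A sketch of the resulting loop under stated assumptions (results as a list of dicts with url/similarity/caption keys; the timeout argument is an illustrative addition, not in the commit):

import requests
from io import BytesIO
from PIL import Image

def download_images(results):
    # Sketch only: `results` is assumed to hold dicts with url/similarity/caption.
    images = []
    for result in results:
        try:
            response = requests.get(result["url"], timeout=10)  # timeout added here
            if not response.ok:
                continue  # skip 4xx/5xx responses without raising
            image = Image.open(BytesIO(response.content))
            # Tag the image the way the commit does, for display alongside results.
            image.title = str(result["similarity"]) + ' ' + result["caption"]
            images.append(image)
        except Exception as e:
            print(e)  # a failed download/decode drops one image, not the batch
    return images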
@@ -145,7 +146,7 @@ def update_average_embeddings(embedding_base64s_state, embedding_powers):
     # final_embedding = final_embedding / num_embeddings
 
     # normalize embeddings in numpy
-
+    final_embedding /= np.linalg.norm(final_embedding)
 
     embeddings_b64 = embedding_to_base64(final_embedding)
     return embeddings_b64
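np.linalg.norm with no axis argument returns the norm of the flattened array, which matches the per-embedding torch version only while final_embedding is a single vector; an all-zero average (opposing embeddings cancelling out) would divide by zero. A small sketch of both points (the zero-norm guard is an assumption, not part of the commit):

import numpy as np

final_embedding = np.array([3.0, 4.0], dtype=np.float32)

# No axis: norm of the flattened array -- fine for a single 1-D embedding.
norm = np.linalg.norm(final_embedding)  # 5.0

# Hypothetical guard (not in the commit): a zero average has no direction.
if norm > 0:
    final_embedding /= norm

print(final_embedding)  # [0.6 0.8]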
@@ -292,7 +293,7 @@ with gr.Blocks() as demo:
     with gr.Column(scale=5):
         gr.Markdown(
             """
-            # Soho-Clip
+            # Soho-Clip Embedding Explorer
 
             A tool for exploring CLIP embedding spaces.
 