sohojoe committed on
Commit
a14ceae
·
1 Parent(s): dcd6afb

prompt, from image working

Browse files
Files changed (1) hide show
  1. app.py +14 -63
app.py CHANGED
@@ -33,50 +33,21 @@ embedding_base64s = [None for i in range(max_tabs)]
33
 
34
 
35
  def image_to_embedding(input_im):
36
- # approch A:
37
- tform = transforms.Compose([
38
- transforms.ToTensor(),
39
- transforms.Resize(
40
- (336, 336),
41
- interpolation=transforms.InterpolationMode.BICUBIC,
42
- antialias=False,
43
- ),
44
- transforms.Normalize(
45
- [0.48145466, 0.4578275, 0.40821073],
46
- [0.26862954, 0.26130258, 0.27577711]),
47
- ])
48
- input = tform(input_im).to(device)
49
-
50
- # approch B: convert input_im to torch
51
- # inp = torch.from_numpy(np.array(input_im)).to(device)
52
- # inp = torch.from_numpy(np.array(input_im)).permute(2, 0, 1).to(device)
53
-
54
- # dtype = torch.float32
55
- # input = input.to(device=device, dtype=dtype)
56
- input = input.unsqueeze(0)
57
- # image_embeddings = pipe.image_encoder(image).image_embeds
58
- # image_embeddings = image_embeddings[0]
59
-
60
  with torch.no_grad():
61
- # image_embeddings_np = model.get_text_features(prompt_tokens.to(device))
62
- image_embeddings = model.get_image_features(input)
63
-
64
  # image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
65
- image_embeddings_np = image_embeddings.cpu().detach().numpy()
66
  return image_embeddings_np
67
 
68
  def prompt_to_embedding(prompt):
69
- # inputs = processor(prompt, images=imgs, return_tensors="pt", padding=True)
70
- inputs = processor(prompt, return_tensors="pt", padding='max_length', max_length=77)
71
- # labels = torch.tensor(labels)
72
- # prompt_tokens = inputs.input_ids[0]
73
- prompt_tokens = inputs.input_ids
74
- # image = inputs.pixel_values
75
  with torch.no_grad():
76
- prompt_embededdings = model.get_text_features(prompt_tokens.to(device))
77
  # prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
78
- prompt_embededdings = prompt_embededdings[0].cpu().detach().numpy()
79
- return prompt_embededdings
80
 
81
  def embedding_to_image(embeddings):
82
  size = math.ceil(math.sqrt(embeddings.shape[0]))
@@ -87,15 +58,15 @@ def embedding_to_image(embeddings):
87
 
88
  def embedding_to_base64(embeddings):
89
  import base64
90
- # ensure float16
91
- embeddings = embeddings.astype(np.float16)
92
  embeddings_b64 = base64.urlsafe_b64encode(embeddings).decode()
93
  return embeddings_b64
94
 
95
  def base64_to_embedding(embeddings_b64):
96
  import base64
97
  embeddings = base64.urlsafe_b64decode(embeddings_b64)
98
- embeddings = np.frombuffer(embeddings, dtype=np.float16)
99
  # embeddings = torch.tensor(embeddings)
100
  return embeddings
101
 
@@ -177,6 +148,9 @@ def update_average_embeddings(embedding_base64s_state, embedding_powers):
177
  # TODO toggle this to support average or sum
178
  final_embedding = final_embedding / num_embeddings
179
 
 
 
 
180
  embeddings_b64 = embedding_to_base64(final_embedding)
181
  return embeddings_b64
182
 
@@ -229,35 +203,12 @@ def on_example_image_click_set_image(input_image, image_url):
229
 
230
  # device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
231
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
232
- torch_size = torch.float16 if device == ('cuda') else torch.float32
233
- # torch_size = torch.float32
234
- # pipe = StableDiffusionPipeline.from_pretrained(
235
- # model_id,
236
- # custom_pipeline="pipeline.py",
237
- # torch_dtype=torch_size,
238
- # # , revision="fp16",
239
- # requires_safety_checker = False, safety_checker=None,
240
- # text_encoder = CLIPTextModel,
241
- # tokenizer = CLIPTokenizer,
242
- # )
243
- # pipe = pipe.to(device)
244
-
245
- from transformers import AutoProcessor, AutoModel
246
- # processor = AutoProcessor.from_pretrained(clip_model_id)
247
- # model = AutoModel.from_pretrained(clip_model_id)
248
- # model = model.to(device)
249
 
250
  from clip_retrieval.load_clip import load_clip, get_tokenizer
251
  # model, preprocess = load_clip(clip_model, use_jit=True, device=device)
252
  model, preprocess = load_clip(clip_model, use_jit=True, device=device)
253
  tokenizer = get_tokenizer(clip_model)
254
 
255
- test_url = "https://placekitten.com/400/600"
256
- test_caption = "an image of a cat"
257
- test_image_1 = "tests/test_clip_inference/test_images/123_456.jpg"
258
- test_image_2 = "tests/test_clip_inference/test_images/416_264.jpg"
259
-
260
- # clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
261
  clip_retrieval_client = ClipClient(
262
  url=clip_retrieval_service_url,
263
  indice_name=clip_model_id,
 
33
 
34
 
35
  def image_to_embedding(input_im):
36
+ input_im = Image.fromarray(input_im)
37
+ prepro = preprocess(input_im).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  with torch.no_grad():
39
+ image_embeddings = model.encode_image(prepro)
 
 
40
  # image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
41
+ image_embeddings_np = image_embeddings.cpu().to(torch.float32).detach().numpy()
42
  return image_embeddings_np
43
 
44
  def prompt_to_embedding(prompt):
45
+ text = tokenizer([prompt]).to(device)
 
 
 
 
 
46
  with torch.no_grad():
47
+ prompt_embededdings = model.encode_text(text)
48
  # prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
49
+ prompt_embededdings_np = prompt_embededdings.cpu().to(torch.float32).detach().numpy()
50
+ return prompt_embededdings_np
51
 
52
  def embedding_to_image(embeddings):
53
  size = math.ceil(math.sqrt(embeddings.shape[0]))
 
58
 
59
  def embedding_to_base64(embeddings):
60
  import base64
61
+ # ensure float32
62
+ embeddings = embeddings.astype(np.float32)
63
  embeddings_b64 = base64.urlsafe_b64encode(embeddings).decode()
64
  return embeddings_b64
65
 
66
  def base64_to_embedding(embeddings_b64):
67
  import base64
68
  embeddings = base64.urlsafe_b64decode(embeddings_b64)
69
+ embeddings = np.frombuffer(embeddings, dtype=np.float32)
70
  # embeddings = torch.tensor(embeddings)
71
  return embeddings
72
 
 
148
  # TODO toggle this to support average or sum
149
  final_embedding = final_embedding / num_embeddings
150
 
151
+ # normalize embeddings in numpy
152
+ final_embedding /= np.linalg.norm(final_embedding)
153
+
154
  embeddings_b64 = embedding_to_base64(final_embedding)
155
  return embeddings_b64
156
 
 
203
 
204
  # device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
205
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  from clip_retrieval.load_clip import load_clip, get_tokenizer
208
  # model, preprocess = load_clip(clip_model, use_jit=True, device=device)
209
  model, preprocess = load_clip(clip_model, use_jit=True, device=device)
210
  tokenizer = get_tokenizer(clip_model)
211
 
 
 
 
 
 
 
212
  clip_retrieval_client = ClipClient(
213
  url=clip_retrieval_service_url,
214
  indice_name=clip_model_id,