sohojoe commited on
Commit
0b8f387
·
1 Parent(s): 91e4bde

add local test for comparison

Browse files
Files changed (1) hide show
  1. local_test.py +84 -0
local_test.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import numpy as np
3
+ import torch
4
+
5
+ num_steps = 1000
6
+ test_image_url = "https://static.wixstatic.com/media/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/w_454,h_333,fp_0.50_0.50,q_90/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
7
+ clip_model="ViT-L/14"
8
+ clip_model_id ="laion5B-L-14"
9
+
10
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
+ print ("using device", device)
12
+
13
+ from clip_retrieval.load_clip import load_clip, get_tokenizer
14
+ # from clip_retrieval.clip_client import ClipClient, Modality
15
+ model, preprocess = load_clip(clip_model, use_jit=True, device=device)
16
+ tokenizer = get_tokenizer(clip_model)
17
+
18
+ def test_text(prompt):
19
+ text = tokenizer([prompt]).to(device)
20
+ with torch.no_grad():
21
+ prompt_embededdings = model.encode_text(text)
22
+ prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
23
+ return(prompt_embededdings)
24
+
25
+ def test_image(input_im):
26
+ input_im = Image.fromarray(input_im)
27
+ prepro = preprocess(input_im).unsqueeze(0).to(device)
28
+ with torch.no_grad():
29
+ image_embeddings = model.encode_image(prepro)
30
+ image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
31
+ return(image_embeddings)
32
+
33
+ def test_preprocessed_image(prepro):
34
+ with torch.no_grad():
35
+ image_embeddings = model.encode_image(prepro)
36
+ image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
37
+ return(image_embeddings)
38
+
39
+
40
+ # performance test for text
41
+ start = time.time()
42
+ for i in range(num_steps):
43
+ test_text("todo")
44
+ end = time.time()
45
+ average_time_seconds = (end - start) / num_steps
46
+ average_time_seconds = average_time_seconds if average_time_seconds > 0 else 0.0000001
47
+ print("Average time for text: ", average_time_seconds, "s")
48
+ print("Average time for text: ", average_time_seconds * 1000, "ms")
49
+ print("Number of predictions per second for text: ", 1 / average_time_seconds)
50
+
51
+ # download image from url
52
+ import requests
53
+ from PIL import Image
54
+ from io import BytesIO
55
+ response = requests.get(test_image_url)
56
+ input_image = Image.open(BytesIO(response.content))
57
+ input_image = input_image.convert('RGB')
58
+ # convert image to numpy array
59
+ input_image = np.array(input_image)
60
+
61
+ # performance test for image
62
+ start = time.time()
63
+ for i in range(num_steps):
64
+ test_image(input_image)
65
+ end = time.time()
66
+ average_time_seconds = (end - start) / num_steps
67
+ print("Average time for image: ", average_time_seconds, "s")
68
+ print("Average time for image: ", average_time_seconds * 1000, "ms")
69
+ print("Number of predictions per second for image: ", 1 / average_time_seconds)
70
+
71
+ # performance test for preprocessed image
72
+ input_im = Image.fromarray(input_image)
73
+ prepro = preprocess(input_im).unsqueeze(0).to(device)
74
+
75
+ start = time.time()
76
+ for i in range(num_steps):
77
+ test_preprocessed_image(prepro)
78
+ end = time.time()
79
+ average_time_seconds = (end - start) / num_steps
80
+ print("Average time for preprocessed image: ", average_time_seconds, "s")
81
+ print("Average time for preprocessed image: ", average_time_seconds * 1000, "ms")
82
+ print("Number of predictions per second for preprocessed image: ", 1 / average_time_seconds)
83
+
84
+