Spaces · Runtime error

JavierFnts committed · Commit 0454d20 · 1 parent: 30c8dd0

prompt ranking working

Files changed:
- streamlit_app.py → app.py (+95 −92)
- clip_model.py (+13 −16)
- requirements.txt (+3 −2)
streamlit_app.py → app.py
RENAMED
@@ -4,9 +4,6 @@ import requests
 import streamlit as st
 from clip_model import ClipModel
 
-from session_state import SessionState, get_state
-from images_mocker import ImagesMocker
-
 from PIL import Image
 
 IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
@@ -32,24 +29,34 @@ def load_image_from_url(url: str) -> Image.Image:
 def load_model() -> ClipModel:
     return ClipModel()
 
+def init_state():
+    if "images" not in st.session_state:
+        st.session_state.images = None
+    if "prompts" not in st.session_state:
+        st.session_state.prompts = None
+    if "predictions" not in st.session_state:
+        st.session_state.predictions = None
+    if "default_text_input" not in st.session_state:
+        st.session_state.default_text_input = None
+
 
-def limit_number_images(state: SessionState):
+def limit_number_images():
     """When moving between tasks sometimes the state of images can have too many samples"""
-    if state.images is not None and len(state.images) > 1:
-        state.images = [state.images[0]]
+    if st.session_state.images is not None and len(st.session_state.images) > 1:
+        st.session_state.images = [st.session_state.images[0]]
 
 
-def limit_number_prompts(state: SessionState):
+def limit_number_prompts():
     """When moving between tasks sometimes the state of prompts can have too many samples"""
-    if state.prompts is not None and len(state.prompts) > 1:
-        state.prompts = [state.prompts[0]]
+    if st.session_state.prompts is not None and len(st.session_state.prompts) > 1:
+        st.session_state.prompts = [st.session_state.prompts[0]]
 
 
-def is_valid_prediction_state(state: SessionState) -> bool:
-    if state.images is None or len(state.images) < 1:
+def is_valid_prediction_state() -> bool:
+    if st.session_state.images is None or len(st.session_state.images) < 1:
         st.error("Choose at least one image before predicting")
         return False
-    if state.prompts is None or len(state.prompts) < 1:
+    if st.session_state.prompts is None or len(st.session_state.prompts) < 1:
         st.error("Write at least one prompt before predicting")
         return False
     return True
@@ -97,16 +104,16 @@ class Sections:
         st.markdown("### Try OpenAI's CLIP model in your browser")
         st.markdown(" ")
         st.markdown(" ")
-        with st.beta_expander("What is CLIP?"):
+        with st.expander("What is CLIP?"):
             st.markdown("CLIP is a machine learning model that computes similarity between text "
                         "(also called prompts) and images. It has been trained on a dataset with millions of diverse"
                         " image-prompt pairs, which allows it to generalize to unseen examples."
                         " <br /> Check out [OpenAI's blogpost](https://openai.com/blog/clip/) for more details",
                         unsafe_allow_html=True)
-            col1, col2 = st.beta_columns(2)
+            col1, col2 = st.columns(2)
             col1.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-a.svg")
             col2.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-b.svg")
-        with st.beta_expander("What can CLIP do?"):
+        with st.expander("What can CLIP do?"):
            st.markdown("#### Prompt ranking")
            st.markdown("Given different prompts and an image CLIP will rank the different prompts based on how well they describe the image")
            st.markdown("#### Image ranking")
@@ -118,7 +125,7 @@ class Sections:
         st.markdown(" ")
 
     @staticmethod
-    def image_uploader(state: SessionState, accept_multiple_files: bool):
+    def image_uploader(accept_multiple_files: bool):
         uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
                                            accept_multiple_files=accept_multiple_files)
         if (not accept_multiple_files and uploaded_images is not None) or (accept_multiple_files and len(uploaded_images) >= 1):
@@ -129,125 +136,123 @@ class Sections:
                 pil_image = Image.open(uploaded_image)
                 pil_image = preprocess_image(pil_image)
                 images.append(pil_image)
-            state.images = images
+            st.session_state.images = images
 
 
     @staticmethod
-    def image_picker(state: SessionState, default_text_input: str):
-        col1, col2, col3 = st.beta_columns(3)
+    def image_picker(default_text_input: str):
+        col1, col2, col3 = st.columns(3)
         with col1:
             default_image_1 = load_image_from_url("https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg")
             st.image(default_image_1, use_column_width=True)
             if st.button("Select image 1"):
-                state.images = [default_image_1]
-                state.default_text_input = default_text_input
+                st.session_state.images = [default_image_1]
+                st.session_state.default_text_input = default_text_input
         with col2:
             default_image_2 = load_image_from_url("https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg")
            st.image(default_image_2, use_column_width=True)
            if st.button("Select image 2"):
-                state.images = [default_image_2]
-                state.default_text_input = default_text_input
+                st.session_state.images = [default_image_2]
+                st.session_state.default_text_input = default_text_input
         with col3:
             default_image_3 = load_image_from_url("https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg")
             st.image(default_image_3, use_column_width=True)
             if st.button("Select image 3"):
-                state.images = [default_image_3]
-                state.default_text_input = default_text_input
+                st.session_state.images = [default_image_3]
+                st.session_state.default_text_input = default_text_input
 
     @staticmethod
-    def dataset_picker(state: SessionState):
-        columns = st.beta_columns(5)
-        state.dataset = load_default_dataset()
+    def dataset_picker():
+        columns = st.columns(5)
+        st.session_state.dataset = load_default_dataset()
         image_idx = 0
         for col in columns:
-            col.image(state.dataset[image_idx])
+            col.image(st.session_state.dataset[image_idx])
             image_idx += 1
-            col.image(state.dataset[image_idx])
+            col.image(st.session_state.dataset[image_idx])
             image_idx += 1
         if st.button("Select random dataset"):
-            state.images = state.dataset
-            state.default_text_input = "A sign that says 'SLOW DOWN'"
+            st.session_state.images = st.session_state.dataset
+            st.session_state.default_text_input = "A sign that says 'SLOW DOWN'"
 
     @staticmethod
-    def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
+    def prompts_input(input_label: str, prompt_prefix: str = ''):
         raw_text_input = st.text_input(input_label,
-                                       value=state.default_text_input if state.default_text_input is not None else "")
-        state.is_default_text_input = raw_text_input == state.default_text_input
+                                       value=st.session_state.default_text_input if st.session_state.default_text_input is not None else "")
+        st.session_state.is_default_text_input = raw_text_input == st.session_state.default_text_input
         if raw_text_input:
-            state.prompts = [prompt_prefix + class_name for class_name in raw_text_input.split(";") if len(class_name) > 1]
+            st.session_state.prompts = [prompt_prefix + class_name for class_name in raw_text_input.split(";") if len(class_name) > 1]
 
     @staticmethod
-    def single_image_input_preview(state: SessionState):
+    def single_image_input_preview():
         st.markdown("### Preview")
-        col1, col2 = st.beta_columns([1, 2])
+        col1, col2 = st.columns([1, 2])
         with col1:
             st.markdown("Image to classify")
-            if state.images is not None:
-                st.image(state.images[0], use_column_width=True)
+            if st.session_state.images is not None:
+                st.image(st.session_state.images[0], use_column_width=True)
             else:
                 st.warning("Select an image")
 
         with col2:
             st.markdown("Labels to choose from")
-            if state.prompts is not None:
-                for prompt in state.prompts:
+            if st.session_state.prompts is not None:
+                for prompt in st.session_state.prompts:
                     st.markdown(f"* {prompt}")
-                if len(state.prompts) < 2:
+                if len(st.session_state.prompts) < 2:
                     st.warning("At least two prompts/classes are needed")
             else:
                 st.warning("Enter the prompts/classes to classify from")
 
     @staticmethod
-    def multiple_images_input_preview(state: SessionState):
+    def multiple_images_input_preview():
         st.markdown("### Preview")
         st.markdown("Images to classify")
-        col1, col2, col3 = st.beta_columns(3)
-        if state.images is not None:
-            for idx, image in enumerate(state.images):
-                if idx < len(state.images) / 2:
-                    col1.image(state.images[idx], use_column_width=True)
+        col1, col2, col3 = st.columns(3)
+        if st.session_state.images is not None:
+            for idx, image in enumerate(st.session_state.images):
+                if idx < len(st.session_state.images) / 2:
+                    col1.image(st.session_state.images[idx], use_column_width=True)
                 else:
-                    col2.image(state.images[idx], use_column_width=True)
-            if len(state.images) < 2:
+                    col2.image(st.session_state.images[idx], use_column_width=True)
+            if len(st.session_state.images) < 2:
                 col2.warning("At least 2 images required")
         else:
             col1.warning("Select an image")
 
         with col3:
             st.markdown("Query prompt")
-            if state.prompts is not None:
-                for prompt in state.prompts:
+            if st.session_state.prompts is not None:
+                for prompt in st.session_state.prompts:
                     st.write(prompt)
             else:
                 st.warning("Enter the prompt to classify")
 
     @staticmethod
-    def classification_output(state: SessionState, model: ClipModel):
+    def classification_output(model: ClipModel):
         # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
-        if st.button("Predict") and is_valid_prediction_state(state):
+        if st.button("Predict") and is_valid_prediction_state():  # PREDICT
             with st.spinner("Predicting..."):
 
                 st.markdown("### Results")
-                if len(state.images) == 1:
-                    scores = model.compute_prompts_probabilities(state.images[0], state.prompts)
-                    …
-                    scored_prompts = [(prompt, score) for prompt, score in zip(state.prompts, scores)]
-                    st.json(scores)
+                if len(st.session_state.images) == 1:
+                    scores = model.compute_prompts_probabilities(st.session_state.images[0], st.session_state.prompts)
+                    scored_prompts = [(prompt, score) for prompt, score in zip(st.session_state.prompts, scores)]
                     sorted_scored_prompts = sorted(scored_prompts, key=lambda x: x[1], reverse=True)
                     for prompt, probability in sorted_scored_prompts:
                         percentage_prob = int(probability * 100)
                         st.markdown(
-                            f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)…
-                elif len(state.prompts) == 1:
-                    st.markdown(f"### {state.prompts[0]}")
+                            f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) {prompt}")
+                elif len(st.session_state.prompts) == 1:
+                    st.markdown(f"### {st.session_state.prompts[0]}")
 
-                    scores = model.compute_images_probabilities(state.images, state.prompts[0])
-                    scored_images = [(image, score) for image, score in zip(state.images, scores)]
+                    scores = model.compute_images_probabilities(st.session_state.images, st.session_state.prompts[0])
                     st.json(scores)
-                    …
+                    scored_images = [(image, score) for image, score in zip(st.session_state.images, scores)]
+                    sorted_scored_images = sorted(scored_images, key=lambda x: x[1], reverse=True)
 
                     for image, probability in sorted_scored_images[:5]:
-                        col1, col2 = st.beta_columns([1, 3])
+                        col1, col2 = st.columns([1, 3])
                         col1.image(image, use_column_width=True)
                         percentage_prob = int(probability * 100)
                         col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
@@ -269,47 +274,45 @@ class Sections:
 
 
 Sections.header()
-col1, col2 = st.beta_columns([1, 2])
+col1, col2 = st.columns([1, 2])
 col1.markdown(" "); col1.markdown(" ")
 col1.markdown("#### Task selection")
 task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
 st.markdown("<br>", unsafe_allow_html=True)
-…
+init_state()
 model = load_model()
-session_state = get_state()
 if task_name == "Image classification":
-    Sections.image_uploader(session_state, accept_multiple_files=False)
-    if session_state.images is None:
+    Sections.image_uploader(accept_multiple_files=False)
+    if st.session_state.images is None:
         st.markdown("or choose one from")
-        Sections.image_picker(session_state, default_text_input="banana; boat; bird")
+        Sections.image_picker(default_text_input="banana; boat; bird")
     input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
-    Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
-    limit_number_images(session_state)
-    Sections.single_image_input_preview(session_state)
-    Sections.classification_output(session_state, model)
+    Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
+    limit_number_images()
+    Sections.single_image_input_preview()
+    Sections.classification_output(model)
 elif task_name == "Prompt ranking":
-    Sections.image_uploader(session_state, accept_multiple_files=False)
-    if session_state.images is None:
+    Sections.image_uploader(accept_multiple_files=False)
+    if st.session_state.images is None:
         st.markdown("or choose one from")
-        Sections.image_picker(session_state, default_text_input="A calm afternoon in the Mediterranean; "
+        Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
                                                 "A beautiful creature;"
                                                 " Something that grows in tropical regions")
     input_label = "Enter the prompts to choose from separated by a semi-colon. " \
                   "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
-    Sections.prompts_input(session_state, input_label)
-    limit_number_images(session_state)
-    Sections.single_image_input_preview(session_state)
-    Sections.classification_output(session_state, model)
+    Sections.prompts_input(input_label)
+    limit_number_images()
+    Sections.single_image_input_preview()
+    Sections.classification_output(model)
 elif task_name == "Image ranking":
-    Sections.image_uploader(session_state, accept_multiple_files=True)
-    if session_state.images is None or len(session_state.images) < 2:
+    Sections.image_uploader(accept_multiple_files=True)
+    if st.session_state.images is None or len(st.session_state.images) < 2:
         st.markdown("or use this random dataset")
-        Sections.dataset_picker(session_state)
-    Sections.prompts_input(session_state, "Enter the prompt to query the images by")
-    limit_number_prompts(session_state)
-    Sections.multiple_images_input_preview(session_state)
-    Sections.classification_output(session_state, model)
+        Sections.dataset_picker()
+    Sections.prompts_input("Enter the prompt to query the images by")
+    limit_number_prompts()
+    Sections.multiple_images_input_preview()
+    Sections.classification_output(model)
 
 st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
             "", unsafe_allow_html=True)
-session_state.sync()
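The substantive change in app.py is the migration from the community SessionState/get_state() workaround (and the images_mocker helper) to Streamlit's built-in st.session_state, available since Streamlit 0.84: state keys are initialized once in init_state(), then read and written directly, with no explicit sync() call. A minimal sketch of the pattern, using a hypothetical counter key:

import streamlit as st

# Runs only on the first execution of the session; st.session_state
# persists values across the reruns Streamlit triggers on interaction.
if "counter" not in st.session_state:
    st.session_state.counter = 0

# Each click reruns the script, but the stored value survives the rerun
# (no session_state.sync() needed, unlike the old workaround).
if st.button("Increment"):
    st.session_state.counter += 1

st.write(f"Button pressed {st.session_state.counter} times")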
clip_model.py
CHANGED
@@ -9,7 +9,7 @@ class ClipModel:
         ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32',
          'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
         """
-        self.model, self.img_preprocess = clip.load(model_name)
+        self._model, self._img_preprocess = clip.load(model_name)
 
     def predict(self, images: list[Image], prompts: list[str]) -> dict:
         if len(images) == 1:
@@ -19,43 +19,40 @@ class ClipModel:
         else:
             raise ValueError('Either images or prompts must be a single element')
 
-    def compute_prompts_probabilities(self, image: Image, prompts: list[str]) -> …
-        preprocessed_image = self.img_preprocess(image).unsqueeze(0)
+    def compute_prompts_probabilities(self, image: Image, prompts: list[str]) -> list[float]:
+        preprocessed_image = self._img_preprocess(image).unsqueeze(0)
         tokenized_prompts = clip.tokenize(prompts)
         with torch.inference_mode():
-            image_features = self.model.encode_image(preprocessed_image)
-            text_features = self.model.encode_text(tokenized_prompts)
+            image_features = self._model.encode_image(preprocessed_image)
+            text_features = self._model.encode_text(tokenized_prompts)
 
         # normalized features
         image_features = image_features / image_features.norm(dim=1, keepdim=True)
         text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
         # cosine similarity as logits
-        logit_scale = self.model.logit_scale.exp()
+        logit_scale = self._model.logit_scale.exp()
         logits_per_image = logit_scale * image_features @ text_features.t()
 
         probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
 
-        …
-        return scored_prompts
+        return probs
 
-    def compute_images_probabilities(self, images: list[Image], prompt: str) -> …
-        …
-        preprocessed_images = [self.img_preprocess(image).unsqueeze(0) for image in images]
+    def compute_images_probabilities(self, images: list[Image], prompt: str) -> list[float]:
+        preprocessed_images = [self._img_preprocess(image).unsqueeze(0) for image in images]
         tokenized_prompts = clip.tokenize(prompt)
         with torch.inference_mode():
-            image_features = self.model.encode_image(torch.cat(preprocessed_images))
-            text_features = self.model.encode_text(tokenized_prompts)
+            image_features = self._model.encode_image(torch.cat(preprocessed_images))
+            text_features = self._model.encode_text(tokenized_prompts)
 
         # normalized features
         image_features = image_features / image_features.norm(dim=1, keepdim=True)
         text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
         # cosine similarity as logits
-        logit_scale = self.model.logit_scale.exp()
+        logit_scale = self._model.logit_scale.exp()
         logits_per_image = logit_scale * image_features @ text_features.t()
 
         probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
 
-        …
-        return scored_prompts
+        return probs
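With this refactor both compute_* methods return a plain list of softmax probabilities (list[float]) instead of building scored tuples inside the model; the old image-ranking path even returned the wrong variable, scored_prompts. Pairing and sorting now happen in app.py. A usage sketch under that API (the image path and prompt strings are hypothetical):

from PIL import Image
from clip_model import ClipModel

model = ClipModel()  # app.py's load_model() constructs it with its default model name
image = Image.open("example.jpg")  # hypothetical local file
prompts = ["a banana", "a boat", "a bird"]

# Scores are softmax-normalized over the prompts, so they sum to 1.
probs = model.compute_prompts_probabilities(image, prompts)
for prompt, prob in sorted(zip(prompts, probs), key=lambda pair: pair[1], reverse=True):
    print(f"{prob:.1%}  {prompt}")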
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
-streamlit~=…
+streamlit~=1.11.1
 git+https://github.com/openai/CLIP@b46f5ac
 Pillow==8.1.0
 mock==4.0.3
+protobuf==3.20.0 # It raises errors otherwise