Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import pandas as pd
|
|
2 |
import streamlit as st
|
3 |
import datasets
|
4 |
import plotly.express as px
|
5 |
-
from
|
6 |
from PIL import Image
|
7 |
import os
|
8 |
from pandas.api.types import (
|
@@ -29,28 +29,30 @@ from datetime import datetime
|
|
29 |
|
30 |
#st.set_page_config(layout="wide")
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
@st.cache_data(show_spinner=True)
|
35 |
-
def download_models():
|
36 |
-
# Directory doesn't exist, download and extract the model
|
37 |
-
subprocess.run(["mkdir", "models"])
|
38 |
-
subprocess.run(["wget", "--no-check-certificate", "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/clip-ViT-B-32-multilingual-v1.zip"], check=True)
|
39 |
-
subprocess.run(["unzip", "-q", "clip-ViT-B-32-multilingual-v1.zip", "-d", model_dir], check=True)
|
40 |
|
41 |
token_ = st.secrets["token"]
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
@st.cache_data(show_spinner=True)
|
44 |
def load_dataset():
|
45 |
-
dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train',
|
46 |
-
dataset.add_faiss_index(column="
|
47 |
dataset.add_faiss_index(column="img_embs")
|
48 |
dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time','Like and View Counts Disabled','Link','Download URL','Views'])
|
49 |
return dataset
|
50 |
|
51 |
@st.cache_data(show_spinner=False)
|
52 |
def load_dataframe(_dataset):
|
53 |
-
dataframe = _dataset.remove_columns(['
|
54 |
# Extract hashtags ith regex and convert to set
|
55 |
dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
|
56 |
dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)
|
@@ -60,16 +62,6 @@ def load_dataframe(_dataset):
|
|
60 |
dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
|
61 |
return dataframe
|
62 |
|
63 |
-
@st.cache_resource(show_spinner=True)
|
64 |
-
def load_img_model():
|
65 |
-
# We use the original clip-ViT-B-32 for encoding images
|
66 |
-
return SentenceTransformer('clip-ViT-B-32')
|
67 |
-
|
68 |
-
@st.cache_resource(show_spinner=True)
|
69 |
-
def load_txt_model():
|
70 |
-
# Our text embedding model is aligned to the img_model and maps 50+
|
71 |
-
# languages to the same vector space
|
72 |
-
return SentenceTransformer('./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1')
|
73 |
|
74 |
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
|
75 |
"""
|
@@ -146,27 +138,62 @@ def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
|
|
146 |
return df
|
147 |
|
148 |
@st.cache_data
|
149 |
-
def get_image_embs(
|
150 |
"""
|
151 |
Get image embeddings
|
152 |
Parameters:
|
|
|
|
|
153 |
uploaded_file (PIL.Image): Uploaded image file
|
154 |
Returns:
|
155 |
img_emb (np.array): Image embeddings
|
156 |
"""
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
return img_emb
|
159 |
|
160 |
@st.cache_data(show_spinner=False)
|
161 |
-
def get_text_embs(text):
|
162 |
"""
|
163 |
Get text embeddings
|
164 |
Parameters:
|
|
|
|
|
165 |
text (str): Text to encode
|
166 |
Returns:
|
167 |
text_emb (np.array): Text embeddings
|
168 |
"""
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
return txt_emb
|
171 |
|
172 |
@st.cache_data
|
@@ -186,7 +213,7 @@ def postprocess_results(scores, samples):
|
|
186 |
samples_df["score"] = samples_df["score"].astype(int)
|
187 |
samples_df.reset_index(inplace=True, drop=True)
|
188 |
samples_df = samples_df[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in samples_df.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
|
189 |
-
return samples_df.drop(columns=['
|
190 |
|
191 |
@st.cache_data
|
192 |
def text_to_text(text, k=5):
|
@@ -198,8 +225,8 @@ def text_to_text(text, k=5):
|
|
198 |
Returns:
|
199 |
results (list): List of tuples of PIL images and labels/scores
|
200 |
"""
|
201 |
-
text_emb = get_text_embs(text)
|
202 |
-
scores, samples = dataset.get_nearest_examples('
|
203 |
return postprocess_results(scores, samples)
|
204 |
|
205 |
@st.cache_data
|
@@ -212,8 +239,8 @@ def image_to_text(image, k=5):
|
|
212 |
Returns:
|
213 |
results (list): List of tuples of PIL images and labels/scores
|
214 |
"""
|
215 |
-
img_emb = get_image_embs(image.name)
|
216 |
-
scores, samples = dataset.get_nearest_examples('
|
217 |
return postprocess_results(scores, samples)
|
218 |
|
219 |
@st.cache_data
|
@@ -226,7 +253,7 @@ def text_to_image(text, k=5):
|
|
226 |
Returns:
|
227 |
results (list): List of tuples of PIL images and labels/scores
|
228 |
"""
|
229 |
-
text_emb = get_text_embs(text)
|
230 |
scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
|
231 |
return postprocess_results(scores, samples)
|
232 |
|
@@ -240,7 +267,7 @@ def image_to_image(image, k=5):
|
|
240 |
Returns:
|
241 |
results (list): List of tuples of PIL images and labels/scores
|
242 |
"""
|
243 |
-
img_emb = get_image_embs(image.name)
|
244 |
scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
|
245 |
return postprocess_results(scores, samples)
|
246 |
|
@@ -546,13 +573,12 @@ if not check_password():
|
|
546 |
st.stop()
|
547 |
|
548 |
# Check if the directory exists
|
549 |
-
if not os.path.exists(model_dir):
|
550 |
-
download_models()
|
551 |
|
552 |
dataset = load_dataset()
|
553 |
df = load_dataframe(dataset)
|
554 |
-
|
555 |
-
|
|
|
556 |
|
557 |
menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
|
558 |
|
@@ -868,7 +894,7 @@ elif selected_menu_option == "Clustering":
|
|
868 |
|
869 |
st.markdown("### Clustering Results")
|
870 |
if type_embeddings == "Text":
|
871 |
-
embeddings = dataset['
|
872 |
elif type_embeddings == "Image":
|
873 |
embeddings = dataset['img_embs']
|
874 |
|
|
|
2 |
import streamlit as st
|
3 |
import datasets
|
4 |
import plotly.express as px
|
5 |
+
from transformers import AutoProcessor, AutoModel
|
6 |
from PIL import Image
|
7 |
import os
|
8 |
from pandas.api.types import (
|
|
|
29 |
|
30 |
#st.set_page_config(layout="wide")
|
31 |
|
32 |
+
model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
token_ = st.secrets["token"]
|
35 |
|
36 |
+
@st.cache_resource(show_spinner=True)
|
37 |
+
def load_model(model_name):
|
38 |
+
"""
|
39 |
+
Load the model and processor
|
40 |
+
"""
|
41 |
+
processor = AutoProcessor.from_pretrained(model_name)
|
42 |
+
model = AutoModel.from_pretrained(model_name)
|
43 |
+
return processor, model
|
44 |
+
|
45 |
@st.cache_data(show_spinner=True)
|
46 |
def load_dataset():
|
47 |
+
dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', token=token_)
|
48 |
+
dataset.add_faiss_index(column="text_embs")
|
49 |
dataset.add_faiss_index(column="img_embs")
|
50 |
dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time','Like and View Counts Disabled','Link','Download URL','Views'])
|
51 |
return dataset
|
52 |
|
53 |
@st.cache_data(show_spinner=False)
|
54 |
def load_dataframe(_dataset):
|
55 |
+
dataframe = _dataset.remove_columns(['text_embs', 'img_embs']).to_pandas()
|
56 |
# Extract hashtags ith regex and convert to set
|
57 |
dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
|
58 |
dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)
|
|
|
62 |
dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
|
63 |
return dataframe
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
|
67 |
"""
|
|
|
138 |
return df
|
139 |
|
140 |
@st.cache_data
|
141 |
+
def get_image_embs(_processor, _model, uploaded_file):
|
142 |
"""
|
143 |
Get image embeddings
|
144 |
Parameters:
|
145 |
+
processor (transformers.AutoProcessor): Processor for the model
|
146 |
+
model (transformers.AutoModel): Model to use for embeddings
|
147 |
uploaded_file (PIL.Image): Uploaded image file
|
148 |
Returns:
|
149 |
img_emb (np.array): Image embeddings
|
150 |
"""
|
151 |
+
# Load the image from local path
|
152 |
+
image = Image.open(uploaded_file)
|
153 |
+
|
154 |
+
# Process the image
|
155 |
+
inputs = _processor(images=image, return_tensors="pt")
|
156 |
+
|
157 |
+
# Forward pass without gradient calculation
|
158 |
+
outputs = _model.get_image_features(**inputs)
|
159 |
+
|
160 |
+
# Normalize the image embeddings
|
161 |
+
img_embs = outputs / outputs.norm(dim=-1, keepdim=True)
|
162 |
+
|
163 |
+
# Convert to list and add to example
|
164 |
+
img_emb = img_embs.squeeze(0).detach().cpu().numpy()
|
165 |
+
|
166 |
return img_emb
|
167 |
|
168 |
@st.cache_data(show_spinner=False)
|
169 |
+
def get_text_embs(_processor, _model, text):
|
170 |
"""
|
171 |
Get text embeddings
|
172 |
Parameters:
|
173 |
+
processor (transformers.AutoProcessor): Processor for the model
|
174 |
+
model (transformers.AutoModel): Model to use for embeddings
|
175 |
text (str): Text to encode
|
176 |
Returns:
|
177 |
text_emb (np.array): Text embeddings
|
178 |
"""
|
179 |
+
# Process the text with truncation
|
180 |
+
inputs = _processor(
|
181 |
+
text=text,
|
182 |
+
return_tensors="pt",
|
183 |
+
padding="max_length",
|
184 |
+
truncation=True,
|
185 |
+
max_length=77 # CLIP's maximum sequence length
|
186 |
+
)
|
187 |
+
|
188 |
+
# Forward pass without gradient calculation
|
189 |
+
outputs = _model.get_text_features(**inputs)
|
190 |
+
|
191 |
+
# Normalize the text embeddings
|
192 |
+
text_embs = outputs / outputs.norm(dim=-1, keepdim=True)
|
193 |
+
|
194 |
+
# Convert to list and add to example
|
195 |
+
txt_emb = text_embs.squeeze(0).detach().cpu().numpy()
|
196 |
+
|
197 |
return txt_emb
|
198 |
|
199 |
@st.cache_data
|
|
|
213 |
samples_df["score"] = samples_df["score"].astype(int)
|
214 |
samples_df.reset_index(inplace=True, drop=True)
|
215 |
samples_df = samples_df[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in samples_df.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
|
216 |
+
return samples_df.drop(columns=['text_embs', 'img_embs'])
|
217 |
|
218 |
@st.cache_data
|
219 |
def text_to_text(text, k=5):
|
|
|
225 |
Returns:
|
226 |
results (list): List of tuples of PIL images and labels/scores
|
227 |
"""
|
228 |
+
text_emb = get_text_embs(processor, model, text)
|
229 |
+
scores, samples = dataset.get_nearest_examples('text_embs', text_emb, k=k)
|
230 |
return postprocess_results(scores, samples)
|
231 |
|
232 |
@st.cache_data
|
|
|
239 |
Returns:
|
240 |
results (list): List of tuples of PIL images and labels/scores
|
241 |
"""
|
242 |
+
img_emb = get_image_embs(processor, model, image.name)
|
243 |
+
scores, samples = dataset.get_nearest_examples('text_embs', img_emb, k=k)
|
244 |
return postprocess_results(scores, samples)
|
245 |
|
246 |
@st.cache_data
|
|
|
253 |
Returns:
|
254 |
results (list): List of tuples of PIL images and labels/scores
|
255 |
"""
|
256 |
+
text_emb = get_text_embs(processor, model, text)
|
257 |
scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
|
258 |
return postprocess_results(scores, samples)
|
259 |
|
|
|
267 |
Returns:
|
268 |
results (list): List of tuples of PIL images and labels/scores
|
269 |
"""
|
270 |
+
img_emb = get_image_embs(processor, model, image.name)
|
271 |
scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
|
272 |
return postprocess_results(scores, samples)
|
273 |
|
|
|
573 |
st.stop()
|
574 |
|
575 |
# Check if the directory exists
|
|
|
|
|
576 |
|
577 |
dataset = load_dataset()
|
578 |
df = load_dataframe(dataset)
|
579 |
+
processor, model = load_model(model_name)
|
580 |
+
#image_model = load_img_model()
|
581 |
+
#text_model = load_txt_model()
|
582 |
|
583 |
menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
|
584 |
|
|
|
894 |
|
895 |
st.markdown("### Clustering Results")
|
896 |
if type_embeddings == "Text":
|
897 |
+
embeddings = dataset['text_embs']
|
898 |
elif type_embeddings == "Image":
|
899 |
embeddings = dataset['img_embs']
|
900 |
|