rjadr committed on
Commit
a651226
·
verified ·
1 Parent(s): 42da8d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -38
app.py CHANGED
@@ -2,7 +2,7 @@ import pandas as pd
2
  import streamlit as st
3
  import datasets
4
  import plotly.express as px
5
- from sentence_transformers import SentenceTransformer
6
  from PIL import Image
7
  import os
8
  from pandas.api.types import (
@@ -29,28 +29,30 @@ from datetime import datetime
29
 
30
  #st.set_page_config(layout="wide")
31
 
32
- model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"
33
-
34
@st.cache_data(show_spinner=True)
def download_models():
    """
    Download and unpack the multilingual CLIP text model
    Fetches clip-ViT-B-32-multilingual-v1.zip with wget into the current
    working directory and unzips it into `model_dir`.
    Raises:
    OSError: If the "models" folder cannot be created
    subprocess.CalledProcessError: If wget or unzip exits with a non-zero status
    """
    # Create the parent folder in-process: an already existing folder is fine,
    # and a real failure is raised instead of being silently ignored by an
    # unchecked `mkdir` subprocess.
    os.makedirs("models", exist_ok=True)
    # NOTE(review): --no-check-certificate turns off TLS verification for this
    # download - confirm it is still needed for this host.
    subprocess.run(["wget", "--no-check-certificate", "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/clip-ViT-B-32-multilingual-v1.zip"], check=True)
    subprocess.run(["unzip", "-q", "clip-ViT-B-32-multilingual-v1.zip", "-d", model_dir], check=True)
40
 
41
  token_ = st.secrets["token"]
42
 
 
 
 
 
 
 
 
 
 
43
  @st.cache_data(show_spinner=True)
44
  def load_dataset():
45
- dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', use_auth_token=token_)
46
- dataset.add_faiss_index(column="txt_embs")
47
  dataset.add_faiss_index(column="img_embs")
48
  dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time','Like and View Counts Disabled','Link','Download URL','Views'])
49
  return dataset
50
 
51
  @st.cache_data(show_spinner=False)
52
  def load_dataframe(_dataset):
53
- dataframe = _dataset.remove_columns(['txt_embs', 'img_embs']).to_pandas()
54
  # Extract hashtags with regex and convert to set
55
  dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
56
  dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)
@@ -60,16 +62,6 @@ def load_dataframe(_dataset):
60
  dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
61
  return dataframe
62
 
63
@st.cache_resource(show_spinner=True)
def load_img_model():
    """
    Load the image encoder
    Returns:
    SentenceTransformer built from the original 'clip-ViT-B-32' checkpoint
    """
    img_encoder = SentenceTransformer('clip-ViT-B-32')
    return img_encoder
67
-
68
@st.cache_resource(show_spinner=True)
def load_txt_model():
    """
    Load the multilingual text encoder from the local model folder
    Returns:
    SentenceTransformer built from the directory named by `model_dir`
    """
    # Our text embedding model is aligned to the img_model and maps 50+
    # languages to the same vector space.
    # Use the module-level `model_dir` instead of repeating the path literal,
    # so the folder that is checked/unzipped and the folder that is loaded
    # cannot drift apart.
    return SentenceTransformer(model_dir)
73
 
74
  def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
75
  """
@@ -146,27 +138,62 @@ def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
146
  return df
147
 
148
  @st.cache_data
149
def get_image_embs(image):
    """
    Get image embeddings
    Parameters:
    image: Path or file object that PIL's Image.open can read
    Returns:
    img_emb (np.array): Image embeddings
    """
    # Encode while the file is open and release the handle right afterwards,
    # instead of leaving it to the garbage collector.
    with Image.open(image) as pil_image:
        img_emb = image_model.encode(pil_image)
    return img_emb
159
 
160
  @st.cache_data(show_spinner=False)
161
def get_text_embs(text):
    """
    Encode a piece of text with the text model
    Parameters:
    text (str): Text to encode
    Returns:
    text_emb (np.array): Text embeddings
    """
    return text_model.encode(text)
171
 
172
  @st.cache_data
@@ -186,7 +213,7 @@ def postprocess_results(scores, samples):
186
  samples_df["score"] = samples_df["score"].astype(int)
187
  samples_df.reset_index(inplace=True, drop=True)
188
  samples_df = samples_df[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in samples_df.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
189
- return samples_df.drop(columns=['txt_embs', 'img_embs'])
190
 
191
  @st.cache_data
192
  def text_to_text(text, k=5):
@@ -198,8 +225,8 @@ def text_to_text(text, k=5):
198
  Returns:
199
  results (list): List of tuples of PIL images and labels/scores
200
  """
201
- text_emb = get_text_embs(text)
202
- scores, samples = dataset.get_nearest_examples('txt_embs', text_emb, k=k)
203
  return postprocess_results(scores, samples)
204
 
205
  @st.cache_data
@@ -212,8 +239,8 @@ def image_to_text(image, k=5):
212
  Returns:
213
  results (list): List of tuples of PIL images and labels/scores
214
  """
215
- img_emb = get_image_embs(image.name)
216
- scores, samples = dataset.get_nearest_examples('txt_embs', img_emb, k=k)
217
  return postprocess_results(scores, samples)
218
 
219
  @st.cache_data
@@ -226,7 +253,7 @@ def text_to_image(text, k=5):
226
  Returns:
227
  results (list): List of tuples of PIL images and labels/scores
228
  """
229
- text_emb = get_text_embs(text)
230
  scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
231
  return postprocess_results(scores, samples)
232
 
@@ -240,7 +267,7 @@ def image_to_image(image, k=5):
240
  Returns:
241
  results (list): List of tuples of PIL images and labels/scores
242
  """
243
- img_emb = get_image_embs(image.name)
244
  scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
245
  return postprocess_results(scores, samples)
246
 
@@ -546,13 +573,12 @@ if not check_password():
546
  st.stop()
547
 
548
  # Check if the directory exists
549
- if not os.path.exists(model_dir):
550
- download_models()
551
 
552
  dataset = load_dataset()
553
  df = load_dataframe(dataset)
554
- image_model = load_img_model()
555
- text_model = load_txt_model()
 
556
 
557
  menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
558
 
@@ -868,7 +894,7 @@ elif selected_menu_option == "Clustering":
868
 
869
  st.markdown("### Clustering Results")
870
  if type_embeddings == "Text":
871
- embeddings = dataset['txt_embs']
872
  elif type_embeddings == "Image":
873
  embeddings = dataset['img_embs']
874
 
 
2
  import streamlit as st
3
  import datasets
4
  import plotly.express as px
5
+ from transformers import AutoProcessor, AutoModel
6
  from PIL import Image
7
  import os
8
  from pandas.api.types import (
 
29
 
30
  #st.set_page_config(layout="wide")
31
 
32
+ model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
 
 
 
 
 
 
 
33
 
34
  token_ = st.secrets["token"]
35
 
36
@st.cache_resource(show_spinner=True)
def load_model(model_name):
    """
    Fetch the pretrained processor/model pair for a checkpoint
    Parameters:
    model_name (str): Checkpoint identifier passed to `from_pretrained`
    Returns:
    (processor, model): The AutoProcessor and AutoModel loaded from `model_name`
    """
    # Tuple elements are evaluated left to right: processor first, then model
    return (
        AutoProcessor.from_pretrained(model_name),
        AutoModel.from_pretrained(model_name),
    )
44
+
45
  @st.cache_data(show_spinner=True)
46
  def load_dataset():
47
+ dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', token=token_)
48
+ dataset.add_faiss_index(column="text_embs")
49
  dataset.add_faiss_index(column="img_embs")
50
  dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time','Like and View Counts Disabled','Link','Download URL','Views'])
51
  return dataset
52
 
53
  @st.cache_data(show_spinner=False)
54
  def load_dataframe(_dataset):
55
+ dataframe = _dataset.remove_columns(['text_embs', 'img_embs']).to_pandas()
56
  # Extract hashtags with regex and convert to set
57
  dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
58
  dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)
 
62
  dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
63
  return dataframe
64
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
67
  """
 
138
  return df
139
 
140
  @st.cache_data
141
def get_image_embs(_processor, _model, uploaded_file):
    """
    Get image embeddings
    Parameters:
    _processor (transformers.AutoProcessor): Processor for the model
    _model (transformers.AutoModel): Model to use for embeddings
    uploaded_file: Path or file object that PIL's Image.open can read
    Returns:
    img_emb (np.array): Image embedding scaled to unit L2 norm
    """
    # Imported locally so this function does not rely on a module-level
    # torch import being present.
    import torch

    # Pre-process while the file is open, then release the handle
    with Image.open(uploaded_file) as image:
        inputs = _processor(images=image, return_tensors="pt")

    # Forward pass without gradient calculation: inference only, so no
    # autograd graph should be built.
    with torch.no_grad():
        outputs = _model.get_image_features(**inputs)

    # Normalize the image embeddings
    img_embs = outputs / outputs.norm(dim=-1, keepdim=True)

    # squeeze(0) drops the leading dimension when it has size 1; hand back
    # a NumPy array.
    img_emb = img_embs.squeeze(0).detach().cpu().numpy()

    return img_emb
167
 
168
  @st.cache_data(show_spinner=False)
169
def get_text_embs(_processor, _model, text):
    """
    Get text embeddings
    Parameters:
    _processor (transformers.AutoProcessor): Processor for the model
    _model (transformers.AutoModel): Model to use for embeddings
    text (str): Text to encode
    Returns:
    txt_emb (np.array): Text embedding scaled to unit L2 norm
    """
    # Imported locally so this function does not rely on a module-level
    # torch import being present.
    import torch

    # Process the text with truncation
    inputs = _processor(
        text=text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77,  # CLIP's maximum sequence length
    )

    # Forward pass without gradient calculation: inference only, so no
    # autograd graph should be built.
    with torch.no_grad():
        outputs = _model.get_text_features(**inputs)

    # Normalize the text embeddings
    text_embs = outputs / outputs.norm(dim=-1, keepdim=True)

    # squeeze(0) drops the leading dimension when it has size 1; hand back
    # a NumPy array.
    txt_emb = text_embs.squeeze(0).detach().cpu().numpy()

    return txt_emb
198
 
199
  @st.cache_data
 
213
  samples_df["score"] = samples_df["score"].astype(int)
214
  samples_df.reset_index(inplace=True, drop=True)
215
  samples_df = samples_df[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in samples_df.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
216
+ return samples_df.drop(columns=['text_embs', 'img_embs'])
217
 
218
  @st.cache_data
219
  def text_to_text(text, k=5):
 
225
  Returns:
226
  results (list): List of tuples of PIL images and labels/scores
227
  """
228
+ text_emb = get_text_embs(processor, model, text)
229
+ scores, samples = dataset.get_nearest_examples('text_embs', text_emb, k=k)
230
  return postprocess_results(scores, samples)
231
 
232
  @st.cache_data
 
239
  Returns:
240
  results (list): List of tuples of PIL images and labels/scores
241
  """
242
+ img_emb = get_image_embs(processor, model, image.name)
243
+ scores, samples = dataset.get_nearest_examples('text_embs', img_emb, k=k)
244
  return postprocess_results(scores, samples)
245
 
246
  @st.cache_data
 
253
  Returns:
254
  results (list): List of tuples of PIL images and labels/scores
255
  """
256
+ text_emb = get_text_embs(processor, model, text)
257
  scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
258
  return postprocess_results(scores, samples)
259
 
 
267
  Returns:
268
  results (list): List of tuples of PIL images and labels/scores
269
  """
270
+ img_emb = get_image_embs(processor, model, image.name)
271
  scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
272
  return postprocess_results(scores, samples)
273
 
 
573
  st.stop()
574
 
575
  # Check if the directory exists
 
 
576
 
577
  dataset = load_dataset()
578
  df = load_dataframe(dataset)
579
+ processor, model = load_model(model_name)
580
+ #image_model = load_img_model()
581
+ #text_model = load_txt_model()
582
 
583
  menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
584
 
 
894
 
895
  st.markdown("### Clustering Results")
896
  if type_embeddings == "Text":
897
+ embeddings = dataset['text_embs']
898
  elif type_embeddings == "Image":
899
  embeddings = dataset['img_embs']
900