rjadr committed on
Commit
4470f0e
·
1 Parent(s): b173e2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -53
app.py CHANGED
@@ -2,7 +2,8 @@ import pandas as pd
2
  import streamlit as st
3
  import datasets
4
  import plotly.express as px
5
- from sentence_transformers import SentenceTransformer, util
 
6
  import os
7
  from pandas.api.types import (
8
  is_categorical_dtype,
@@ -11,6 +12,7 @@ from pandas.api.types import (
11
  is_object_dtype,
12
  )
13
  import subprocess
 
14
 
15
  st.set_page_config(layout="wide")
16
 
@@ -30,12 +32,14 @@ def load_dataset():
30
  dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', use_auth_token=token)
31
  dataset.add_faiss_index(column="txt_embs")
32
  dataset.add_faiss_index(column="img_embs")
33
- dataset = dataset.remove_columns(['Post Created Time','Like and View Counts Disabled','Link','Download URL','Views'])
34
  return dataset
35
 
36
  @st.cache_data(show_spinner=False)
37
  def load_dataframe(_dataset):
38
  dataframe = _dataset.remove_columns(['txt_embs', 'img_embs']).to_pandas()
 
 
39
  return dataframe
40
 
41
  @st.cache_resource(show_spinner=True)
@@ -132,7 +136,7 @@ def get_image_embs(image):
132
  Returns:
133
  img_emb (np.array): Image embeddings
134
  """
135
- img_emb = image_model.encode(image)
136
  return img_emb
137
 
138
  @st.cache_data(show_spinner=False)
@@ -287,71 +291,88 @@ with tab2:
287
 
288
  if selected_tab == "Text to Text":
289
  text_to_text_input = st.text_input("Enter text")
290
- text_to_text_k_top = st.slider("Number of results", 1, 60, 8)
291
  if st.button("Search"):
292
- st.dataframe(
293
- data=text_to_text(text_to_text_input, text_to_text_k_top),
294
- column_config={
295
- "image": st.column_config.ImageColumn(
296
- "Image", help="Instagram image"
297
- ),
298
- "URL": st.column_config.LinkColumn(
299
- "Link", help="Instagram link", width="small"
300
- )
301
- },
302
- hide_index=True,
303
- )
304
-
305
- elif selected_tab == "Text to Image":
306
- text_to_image_input = st.text_input("Enter text")
307
- text_to_image_k_top = st.slider("Number of results", 1, 60, 8)
308
- if st.button("Search"):
309
- st.dataframe(
310
- data=text_to_image(text_to_image_input, text_to_image_k_top),
311
- column_config={
312
  "image": st.column_config.ImageColumn(
313
  "Image", help="Instagram image"
314
  ),
315
  "URL": st.column_config.LinkColumn(
316
  "Link", help="Instagram link", width="small"
317
  )
318
- },
319
- hide_index=True,
320
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  elif selected_tab == "Image to Image":
323
- image_to_image_k_top = st.slider("Number of results", 1, 60, 8)
324
  image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
 
325
  if st.button("Search"):
326
- st.dataframe(
327
- data=image_to_image(image_to_image_input, image_to_image_k_top),
328
- column_config={
329
- "image": st.column_config.ImageColumn(
330
- "Image", help="Instagram image"
331
- ),
332
- "URL": st.column_config.LinkColumn(
333
- "Link", help="Instagram link", width="small"
334
- )
335
- },
336
- hide_index=True,
337
- )
 
 
 
 
 
338
 
339
  elif selected_tab == "Image to Text":
340
- image_to_text_k_top = st.slider("Number of results", 1, 60, 8)
341
  image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
 
342
  if st.button("Search"):
343
- st.dataframe(
344
- data=image_to_text(image_to_text_input, image_to_text_k_top),
345
- column_config={
346
- "image": st.column_config.ImageColumn(
347
- "Image", help="Instagram image"
348
- ),
349
- "URL": st.column_config.LinkColumn(
350
- "Link", help="Instagram link", width="small"
351
- )
352
- },
353
- hide_index=True,
354
- )
 
 
 
 
355
 
356
  with tab3:
357
  st.markdown("### Time Series Analysis")
 
2
  import streamlit as st
3
  import datasets
4
  import plotly.express as px
5
+ from sentence_transformers import SentenceTransformer
6
+ from PIL import Image
7
  import os
8
  from pandas.api.types import (
9
  is_categorical_dtype,
 
12
  is_object_dtype,
13
  )
14
  import subprocess
15
+ from tempfile import NamedTemporaryFile
16
 
17
  st.set_page_config(layout="wide")
18
 
 
32
  dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', use_auth_token=token)
33
  dataset.add_faiss_index(column="txt_embs")
34
  dataset.add_faiss_index(column="img_embs")
35
+ dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time','Like and View Counts Disabled','Link','Download URL','Views'])
36
  return dataset
37
 
38
  @st.cache_data(show_spinner=False)
39
  def load_dataframe(_dataset):
40
  dataframe = _dataset.remove_columns(['txt_embs', 'img_embs']).to_pandas()
41
+ # dataframe['Post Created'] = dataframe['Post Created'].dt.tz_convert('UTC')
42
+ dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
43
  return dataframe
44
 
45
  @st.cache_resource(show_spinner=True)
 
136
  Returns:
137
  img_emb (np.array): Image embeddings
138
  """
139
+ img_emb = image_model.encode(Image.open(image))
140
  return img_emb
141
 
142
  @st.cache_data(show_spinner=False)
 
291
 
292
  if selected_tab == "Text to Text":
293
  text_to_text_input = st.text_input("Enter text")
294
+ text_to_text_k_top = st.slider("Number of results", 1, 500, 8)
295
  if st.button("Search"):
296
+ if not text_to_text_input:
297
+ st.warning("Please enter text")
298
+ else:
299
+ st.dataframe(
300
+ data=text_to_text(text_to_text_input, text_to_text_k_top),
301
+ column_config={
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  "image": st.column_config.ImageColumn(
303
  "Image", help="Instagram image"
304
  ),
305
  "URL": st.column_config.LinkColumn(
306
  "Link", help="Instagram link", width="small"
307
  )
308
+ },
309
+ hide_index=True,
310
+ )
311
+
312
+ elif selected_tab == "Text to Image":
313
+ text_to_image_input = st.text_input("Enter text")
314
+ text_to_image_k_top = st.slider("Number of results", 1, 500, 8)
315
+ if st.button("Search"):
316
+ if not text_to_image_input:
317
+ st.warning("Please enter some text")
318
+ else:
319
+ st.dataframe(
320
+ data=text_to_image(text_to_image_input, text_to_image_k_top),
321
+ column_config={
322
+ "image": st.column_config.ImageColumn(
323
+ "Image", help="Instagram image"
324
+ ),
325
+ "URL": st.column_config.LinkColumn(
326
+ "Link", help="Instagram link", width="small"
327
+ )
328
+ },
329
+ hide_index=True,
330
+ )
331
 
332
  elif selected_tab == "Image to Image":
333
+ image_to_image_k_top = st.slider("Number of results", 1, 500, 8)
334
  image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
335
+ temp_file = NamedTemporaryFile(delete=False)
336
  if st.button("Search"):
337
+ if not image_to_image_input:
338
+ st.warning("Please upload an image")
339
+ else:
340
+ temp_file.write(image_to_image_input.getvalue())
341
+
342
+ st.dataframe(
343
+ data=image_to_image(temp_file, image_to_image_k_top),
344
+ column_config={
345
+ "image": st.column_config.ImageColumn(
346
+ "Image", help="Instagram image"
347
+ ),
348
+ "URL": st.column_config.LinkColumn(
349
+ "Link", help="Instagram link", width="small"
350
+ )
351
+ },
352
+ hide_index=True,
353
+ )
354
 
355
  elif selected_tab == "Image to Text":
356
+ image_to_text_k_top = st.slider("Number of results", 1, 500, 8)
357
  image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
358
+ temp_file = NamedTemporaryFile(delete=False)
359
  if st.button("Search"):
360
+ if not image_to_text_input:
361
+ st.warning("Please upload an image")
362
+ else:
363
+ temp_file.write(image_to_text_input.getvalue())
364
+ st.dataframe(
365
+ data=image_to_text(temp_file, image_to_text_k_top),
366
+ column_config={
367
+ "image": st.column_config.ImageColumn(
368
+ "Image", help="Instagram image"
369
+ ),
370
+ "URL": st.column_config.LinkColumn(
371
+ "Link", help="Instagram link", width="small"
372
+ )
373
+ },
374
+ hide_index=True,
375
+ )
376
 
377
  with tab3:
378
  st.markdown("### Time Series Analysis")