"""Streamlit explorer for the #ditaduranuncamais Instagram dataset.

Pages: raw data exploration, CLIP-based semantic search (text/image queries
against text/image embeddings), hashtag co-occurrence networks with community
detection, embedding clustering, and simple time-series statistics.
"""

import hashlib
import hmac
import os
import re
import subprocess
from datetime import datetime
from itertools import combinations
from tempfile import NamedTemporaryFile

import colorcet as cc
import datasets
import hdbscan
import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import umap
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure
from matplotlib.colors import rgb2hex
from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)
from PIL import Image
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA
from transformers import AutoProcessor, AutoModel

model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
token_ = st.secrets["token"]

# Columns shown first in every result table (in this order).
_LEADING_COLUMNS = ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']


@st.cache_resource(show_spinner=True)
def load_model(model_name):
    """
    Load the model and processor

    Parameters:
    model_name (str): Hugging Face model identifier

    Returns:
    tuple: (processor, model)
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return processor, model


@st.cache_data(show_spinner=True)
def load_dataset():
    """
    Load the dataset, build FAISS indexes on both embedding columns and drop
    the columns the app does not use.

    Returns:
    dataset (datasets.Dataset): Indexed dataset
    """
    dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', token=token_)
    dataset.add_faiss_index(column="text_embs")
    dataset.add_faiss_index(column="img_embs")
    dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time', 'Like and View Counts Disabled', 'Link', 'Download URL', 'Views'])
    return dataset


@st.cache_data(show_spinner=False)
def load_dataframe(_dataset):
    """
    Convert the dataset (minus embeddings) to a dataframe and add the derived
    'Hashtags' and 'description_clean' columns.

    Parameters:
    _dataset (datasets.Dataset): Dataset to convert (underscore: not hashed by Streamlit)

    Returns:
    dataframe (pd.DataFrame): Dataframe with the leading columns reordered
    """
    dataframe = _dataset.remove_columns(['text_embs', 'img_embs']).to_pandas()

    # Hashtags are extracted from the description and the text found in the
    # image. Missing values become empty strings so they contribute nothing.
    combined_text = (
        dataframe['Description'].fillna('').astype(str)
        + ' '
        + dataframe['Image Text'].fillna('').astype(str)
    )
    dataframe['Hashtags'] = combined_text.str.lower().str.findall(r'#(\w+)').apply(set)

    # Create a cleaned description column up-front
    dataframe['description_clean'] = dataframe['Description'].apply(clean_and_truncate_text)

    # Reorder columns to keep the new column next to the original
    leading = ['Post Created', 'image', 'Description', 'description_clean', 'Image Text', 'Account', 'User Name']
    dataframe = dataframe[leading + [col for col in dataframe.columns if col not in leading]]
    return dataframe


def _is_low_cardinality(series: pd.Series, threshold: int = 10) -> bool:
    """Return True when the column is categorical or has fewer than `threshold` distinct values."""
    if isinstance(series.dtype, pd.CategoricalDtype):
        return True
    try:
        return series.nunique() < threshold
    except TypeError:
        # Cells are unhashable (e.g. the sets in 'Hashtags'); treat as free text.
        return False


def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Args:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    modify = st.checkbox("Add filters")

    if not modify:
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Deliberate best effort: most object columns are not dates.
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")
            # Treat columns with < 10 unique values as categorical
            if _is_low_cardinality(df[column]):
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                # Only filter once both ends of the range have been picked.
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    values = df[column]
                    try:
                        # Cast to str so non-string cells can be searched, and
                        # exclude missing cells so they never match.
                        mask = values.astype(str).str.contains(user_text_input) & values.notna()
                    except re.error as err:
                        right.error(f"Invalid regular expression: {err}")
                    else:
                        df = df[mask]

    return df


@st.cache_data
def get_image_embs(_processor, _model, uploaded_file):
    """
    Get image embeddings

    Parameters:
    processor (transformers.AutoProcessor): Processor for the model
    model (transformers.AutoModel): Model to use for embeddings
    uploaded_file (str): Path of the image file to embed

    Returns:
    img_emb (np.array): L2-normalised image embedding
    """
    # Load the image from local path
    image = Image.open(uploaded_file)

    # Process the image
    inputs = _processor(images=image, return_tensors="pt")

    # Forward pass
    outputs = _model.get_image_features(**inputs)

    # Normalize the image embeddings
    img_embs = outputs / outputs.norm(dim=-1, keepdim=True)

    # Drop the batch dimension, detach from the autograd graph, convert to NumPy
    img_emb = img_embs.squeeze(0).detach().cpu().numpy()

    return img_emb


@st.cache_data(show_spinner=False)
def get_text_embs(_processor, _model, text):
    """
    Get text embeddings

    Parameters:
    processor (transformers.AutoProcessor): Processor for the model
    model (transformers.AutoModel): Model to use for embeddings
    text (str): Text to encode

    Returns:
    text_emb (np.array): L2-normalised text embedding
    """
    # Process the text with truncation
    inputs = _processor(
        text=text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77  # CLIP's maximum sequence length
    )

    # Forward pass
    outputs = _model.get_text_features(**inputs)

    # Normalize the text embeddings
    text_embs = outputs / outputs.norm(dim=-1, keepdim=True)

    # Drop the batch dimension, detach from the autograd graph, convert to NumPy
    txt_emb = text_embs.squeeze(0).detach().cpu().numpy()

    return txt_emb


@st.cache_data
def postprocess_results(scores, samples):
    """
    Postprocess results into a dataframe with a 0-100 relevance score

    Parameters:
    scores (np.array): Scores returned by the nearest-neighbour search (lower is closer)
    samples (dict): Retrieved samples, column name -> list of values

    Returns:
    samples_df (pd.DataFrame): Samples without embedding columns, plus an
        integer 'score' column where 100 marks the closest result
    """
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["score"] = scores

    lowest = samples_df["score"].min()
    spread = samples_df["score"].max() - lowest
    if spread > 0:
        normalised = 1 - (samples_df["score"] - lowest) / spread
    else:
        # A single result, or all results equally distant: min-max scaling is
        # undefined (0/0), so every row gets the top score.
        normalised = pd.Series(1.0, index=samples_df.index)
    samples_df["score"] = (normalised * 100).astype(int)

    samples_df.reset_index(inplace=True, drop=True)
    samples_df = samples_df[_LEADING_COLUMNS + [col for col in samples_df.columns if col not in _LEADING_COLUMNS]]
    return samples_df.drop(columns=['text_embs', 'img_embs'])


@st.cache_data
def text_to_text(text, k=5):
    """
    Text to text

    Parameters:
    text (str): Input text
    k (int): Number of top results to return

    Returns:
    results (pd.DataFrame): Nearest posts by text embedding, with scores
    """
    text_emb = get_text_embs(processor, model, text)
    scores, samples = dataset.get_nearest_examples('text_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


@st.cache_data
def image_to_text(image, k=5):
    """
    Image to text

    Parameters:
    image (file object): Temporary file holding the image; its `.name` path is read
    k (int): Number of top results to return

    Returns:
    results (pd.DataFrame): Nearest posts by text embedding, with scores
    """
    img_emb = get_image_embs(processor, model, image.name)
    scores, samples = dataset.get_nearest_examples('text_embs', img_emb, k=k)
    return postprocess_results(scores, samples)


@st.cache_data
def text_to_image(text, k=5):
    """
    Text to image

    Parameters:
    text (str): Input text
    k (int): Number of top results to return

    Returns:
    results (pd.DataFrame): Nearest posts by image embedding, with scores
    """
    text_emb = get_text_embs(processor, model, text)
    scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


@st.cache_data
def image_to_image(image, k=5):
    """
    Image to image

    Parameters:
    image (file object): Temporary file holding the image; its `.name` path is read
    k (int): Number of top results to return

    Returns:
    results (pd.DataFrame): Nearest posts by image embedding, with scores
    """
    img_emb = get_image_embs(processor, model, image.name)
    scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
    return postprocess_results(scores, samples)


def disparity_filter(g: nx.Graph, weight: str = 'weight', alpha: float = 0.05) -> nx.Graph:
    """
    Computes the backbone of the input graph using the disparity filter algorithm.

    The algorithm is proposed in:
    M. A. Serrano, M. Boguna, and A. Vespignani,
    "Extracting the Multiscale Backbone of Complex Weighted Networks",
    PNAS, 106(16), pp 6483--6488 (2009).
    DOI: 10.1073/pnas.0808904106

    Implementation taken from https://groups.google.com/g/networkx-discuss/c/bCuHZ3qQ2po/m/QvUUJqOYDbIJ

    Parameters
    ----------
    g : NetworkX graph
        The input graph.
    weight : str, optional (default='weight')
        The name of the edge attribute to use as weight.
    alpha : float, optional (default=0.05)
        The statistical significance level for the disparity filter (p-value).

    Returns
    -------
    backbone_graph : NetworkX graph
        The backbone graph.
    """
    backbone_graph = nx.Graph()

    for node in g:
        neighbors = g[node]
        k_n = len(neighbors)

        # Nodes with a single connection are never tested.
        if k_n <= 1:
            continue

        sum_w = sum(neighbors[neighbor][weight] for neighbor in neighbors)
        if not sum_w:
            # All incident weights are zero: proportions are undefined, skip.
            continue

        for neighbor in neighbors:
            edge_weight = neighbors[neighbor][weight]
            # Share of the node's total weight carried by this edge
            pij = float(edge_weight) / sum_w

            # Disparity filter test: keep the edge when it is significant.
            if (1 - pij) ** (k_n - 1) < alpha:
                backbone_graph.add_edge(node, neighbor, weight=edge_weight)

    return backbone_graph


# The three graph helpers below are intentionally not wrapped in st.cache_data:
# their nx.Graph argument is not hashable by Streamlit's cache.
def assign_community_colors(G: nx.Graph, attr: str = 'community') -> dict:
    """
    Assigns a color to each community in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    dict
        A dictionary mapping each community to a hex color. Colors repeat
        once the palette is exhausted.
    """
    glasbey_colors = cc.glasbey_hv
    communities_ = set(nx.get_node_attributes(G, attr).values())
    return {community: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, community in enumerate(communities_)}


def generate_hover_text(G: nx.Graph, attr: str = 'community') -> list:
    """
    Generates hover text for each node in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    list
        A list of strings containing the hover text for each node.
    """
    # Plotly hover labels use <br> as the line separator.
    return [
        f"Node: {str(node)}<br>Community: {G.nodes[node][attr] + 1}<br># of connections: {len(adjacencies)}"
        for node, adjacencies in G.adjacency()
    ]


def calculate_node_sizes(G: nx.Graph) -> list:
    """
    Calculates the size of each node in the input graph based on its degree.

    Parameters
    ----------
    G : nx.Graph
        The input graph.

    Returns
    -------
    list
        A list of node sizes between 10 and 30, in `G.nodes()` order. Empty
        for an empty graph; all 10.0 when no node has any edge.
    """
    degrees = dict(G.degree())
    max_degree = max(degrees.values(), default=0)
    if max_degree == 0:
        return [10.0 for _ in G.nodes()]
    return [10 + 20 * (degrees[node] / max_degree) for node in G.nodes()]


def _graph_signature(G: nx.Graph) -> str:
    """Return a digest of the graph's nodes, communities and weighted edges for use as a cache key."""
    nodes = sorted((str(node), str(community)) for node, community in G.nodes(data='community'))
    edges = sorted(
        tuple(sorted((str(u), str(v)))) + (str(w),)
        for u, v, w in G.edges(data='weight')
    )
    return hashlib.sha256(repr((nodes, edges)).encode("utf-8")).hexdigest()


@st.cache_data(show_spinner=True)
def _plot_graph_cached(_G: nx.Graph, graph_signature: str, layout_name: str, community_names_lookup: dict):
    """Build the 3D figure; cached on `graph_signature`, layout and names (`_G` itself is not hashed)."""
    # --- Select the layout algorithm ---
    if layout_name == "kamada_kawai":
        # Aesthetically pleasing, can be slow on large graphs.
        pos = nx.kamada_kawai_layout(_G, dim=3)
    elif layout_name == "circular":
        # Fast, simple circle. It's 2D, so we add a Z-coordinate.
        pos_2d = nx.circular_layout(_G)
        pos = {node: (*coords, 0) for node, coords in pos_2d.items()}
    elif layout_name == "spectral":
        # Good for showing clusters. Also 2D, so we add a Z-coordinate.
        pos_2d = nx.spectral_layout(_G)
        pos = {node: (*coords, 0) for node, coords in pos_2d.items()}
    else:  # Default to "spring"
        # The standard physics-based layout.
        pos = nx.spring_layout(_G, dim=3, k=0.15, iterations=50, seed=779)

    # --- Colors: cycle through the palette so any number of communities works ---
    palette = cc.glasbey_hv
    communities = sorted(set(nx.get_node_attributes(_G, 'community').values()))
    community_colors = {comm: palette[i % len(palette)] for i, comm in enumerate(communities)}

    # --- Edges: one trace, segments separated by None ---
    edge_x, edge_y, edge_z = [], [], []
    for source_node, target_node in _G.edges():
        x0, y0, z0 = pos[source_node]
        x1, y1, z1 = pos[target_node]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    data = [edge_trace]

    # --- Nodes: group once by community, then one trace per community ---
    nodes_by_community = {comm: [] for comm in communities}
    for node, comm in _G.nodes(data='community'):
        if comm in nodes_by_community:
            nodes_by_community[comm].append(node)

    for comm_idx in communities:
        comm_key = f'Community {comm_idx + 1}'
        comm_name = community_names_lookup.get(comm_key, comm_key)

        node_x, node_y, node_z, node_text = [], [], [], []
        for node in nodes_by_community[comm_idx]:
            x, y, z = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_z.append(z)
            node_text.append(f"Hashtag: #{node}<br>Community: {comm_name}")

        node_trace = go.Scatter3d(
            x=node_x, y=node_y, z=node_z,
            mode='markers',
            name=comm_name,
            marker=dict(
                symbol='circle',
                size=7,
                color=rgb2hex(community_colors[comm_idx]),
                line=dict(color='rgb(50,50,50)', width=0.5)
            ),
            text=node_text,
            hoverinfo='text'
        )
        data.append(node_trace)

    # --- Layout ---
    hidden_axis = dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='')
    layout = go.Layout(
        title="3D Hashtag Network Graph",
        showlegend=True,
        legend=dict(title="Communities", x=1.05, y=0.5),
        width=1000,
        height=800,
        margin=dict(l=0, r=0, b=0, t=40),
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis)
        )
    )

    fig = go.Figure(data=data, layout=layout)
    return fig


def plot_graph(_G: nx.Graph, layout_name: str = "spring", community_names_lookup: dict = None):
    """
    Plots a network graph with communities and a legend, using a choice of pure-Python layouts.

    The figure is cached per graph content, layout and community names, so a
    regenerated network is never served a figure built for an older graph.

    Parameters
    ----------
    _G : nx.Graph
        The input graph with a 'community' attribute on each node.
    layout_name : str, optional
        The name of the NetworkX layout algorithm to use
        ('spring', 'kamada_kawai', 'circular' or 'spectral'; anything else
        falls back to 'spring').
    community_names_lookup : dict, optional
        A dictionary mapping community key (e.g., 'Community 1') to a display name.
        Missing keys fall back to the key itself.

    Returns
    -------
    plotly.graph_objects.Figure
        The 3D network figure.
    """
    names = {} if community_names_lookup is None else community_names_lookup
    return _plot_graph_cached(_G, _graph_signature(_G), layout_name, names)


def clean_and_truncate_text(text, max_length=30):
    """
    Removes hashtags and truncates text to a specified length.

    Args:
        text (str): The input string to clean.
        max_length (int): The maximum length of the output string
            (before the trailing '...').

    Returns:
        str: The cleaned and truncated string; "" for non-string input.
    """
    if not isinstance(text, str):
        return ""  # Return empty string for non-string inputs

    # Use regex to remove hashtags (words starting with #)
    no_hashtags = re.sub(r'#\w+\s*', '', text).strip()

    # Truncate the string if it's too long
    if len(no_hashtags) > max_length:
        return no_hashtags[:max_length] + '...'
    return no_hashtags


@st.cache_data(show_spinner=True)
def cluster_embeddings(embeddings,
                       clustering_algo='KMeans',
                       dim_reduction='PCA',
                       # KMeans & MiniBatchKMeans params
                       n_clusters=5,
                       batch_size=256,
                       max_iter=100,
                       # HDBSCAN params
                       min_cluster_size=5,
                       min_samples=5,
                       # Reducer params
                       n_components=2,
                       n_neighbors=15,
                       min_dist=0.0,
                       random_state=42):
    """Performs dimensionality reduction and clustering on a set of embeddings.

    This function chains two steps: first, it reduces the dimensionality of the
    input embeddings using either PCA or UMAP. Second, it applies a clustering
    algorithm (KMeans, MiniBatchKMeans, or HDBSCAN) to the reduced-dimensional
    data to assign a cluster label to each embedding.

    Args:
        embeddings (list or np.ndarray): A list or array of high-dimensional
            embedding vectors. Each element should be a 1D NumPy array.
        clustering_algo (str, optional): The clustering algorithm to use.
            Options are 'KMeans', 'MiniBatchKMeans', or 'HDBSCAN'.
            Defaults to 'KMeans'.
        dim_reduction (str, optional): The dimensionality reduction method to use.
            Options are 'PCA' or 'UMAP'. Defaults to 'PCA'.
        n_clusters (int, optional): The number of clusters to form. Used by
            KMeans and MiniBatchKMeans. Defaults to 5.
        batch_size (int, optional): The size of mini-batches for MiniBatchKMeans.
            Defaults to 256.
        max_iter (int, optional): The maximum number of iterations for
            MiniBatchKMeans. Defaults to 100.
        min_cluster_size (int, optional): The minimum number of samples in a
            group for it to be considered a cluster. Used by HDBSCAN.
            Defaults to 5.
        min_samples (int, optional): The number of samples in a neighborhood for
            a point to be considered a core point. Used by HDBSCAN.
            Defaults to 5.
        n_components (int, optional): The number of dimensions to reduce to.
            Used by PCA and UMAP. Defaults to 2.
        n_neighbors (int, optional): The number of neighbors to consider for
            manifold approximation. Used by UMAP. Defaults to 15.
        min_dist (float, optional): The effective minimum distance between
            embedded points. Used by UMAP. Defaults to 0.0.
        random_state (int, optional): The seed used by the random number
            generator for reproducibility. Defaults to 42.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: An array of cluster labels assigned to each embedding.
            - np.ndarray: The reduced-dimensional representation of the embeddings.

    Raises:
        ValueError: If an invalid `clustering_algo` or `dim_reduction` method
            is specified.
    """
    # Stack embeddings into a single NumPy array
    data_array = np.stack(embeddings)

    # --- 1. Dimensionality Reduction ---
    if dim_reduction == 'PCA':
        reducer = PCA(n_components=n_components, random_state=random_state)
    elif dim_reduction == 'UMAP':
        reducer = umap.UMAP(n_neighbors=n_neighbors,
                            min_dist=min_dist,
                            n_components=n_components,
                            random_state=random_state)
    else:
        raise ValueError('Invalid dimensionality reduction method')

    reduced_embeddings = reducer.fit_transform(data_array)

    # --- 2. Clustering ---
    if clustering_algo == 'MiniBatchKMeans':
        # Use the specific parameters for MiniBatchKMeans
        clusterer = MiniBatchKMeans(
            n_clusters=n_clusters,
            random_state=random_state,
            batch_size=batch_size,
            max_iter=max_iter,
            n_init='auto'  # Recommended setting
        )
    elif clustering_algo == 'KMeans':
        clusterer = KMeans(n_clusters=n_clusters, random_state=random_state, n_init='auto')
    elif clustering_algo == 'HDBSCAN':
        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
    else:
        raise ValueError('Invalid clustering algorithm')

    labels = clusterer.fit_predict(reduced_embeddings)

    return labels, reduced_embeddings


def _result_column_config() -> dict:
    """Return the column configuration shared by every post table (image preview + link)."""
    return {
        "image": st.column_config.ImageColumn(
            "Image", help="Instagram image"
        ),
        "URL": st.column_config.LinkColumn(
            "Link", help="Instagram link", width="small"
        )
    }


def _show_search_results(results: pd.DataFrame) -> None:
    """Render a semantic-search result table."""
    st.dataframe(
        data=results,
        column_config=_result_column_config(),
        hide_index=True,
    )


def _search_with_uploaded_image(search_fn, uploaded_image, k):
    """Write the upload to a temporary file, run `search_fn(temp_file, k)`, then remove the file.

    The file is flushed before the search so the image can be reopened by
    name, and is kept open during the call because Streamlit hashes the
    temporary-file argument of the cached search function.
    """
    temp_file = NamedTemporaryFile(delete=False)
    try:
        temp_file.write(uploaded_image.getvalue())
        temp_file.flush()
        return search_fn(temp_file, k)
    finally:
        temp_file.close()
        os.remove(temp_file.name)


st.title("#ditaduranuncamais Data Explorer")


def check_password():
    """Returns `True` if user is authenticated, `False` otherwise."""
    # If the user is already authenticated, just return True.
    # This is the most important part: we don't render the password form again.
    if st.session_state.get("password_correct", False):
        return True

    # This part of the code will only run if the user is not yet authenticated.
    def password_entered():
        """Checks whether the password entered is correct."""
        entered = st.session_state.get("password")
        expected = st.secrets.get("password")
        # Constant-time comparison; a missing secret never authenticates.
        is_match = (
            entered is not None
            and expected is not None
            and hmac.compare_digest(str(entered).encode("utf-8"), str(expected).encode("utf-8"))
        )
        if is_match:
            st.session_state["password_correct"] = True
            # Don't store the password in session state.
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    # Show the password input form.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )

    # Show an error message if the last attempt was incorrect.
    # The 'in' check prevents the error from showing on the first load.
    if "password_correct" in st.session_state and not st.session_state.password_correct:
        st.error("😕 Password incorrect")

    # Return False to stop the main app from running.
    return False


if not check_password():
    st.stop()

dataset = load_dataset()
df = load_dataframe(dataset)
processor, model = load_model(model_name)

menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
st.sidebar.markdown('# Menu')
selected_menu_option = st.sidebar.radio("Select a page", menu_options)

if selected_menu_option == "Data exploration":
    st.dataframe(
        data=filter_dataframe(df),
        column_config=_result_column_config(),
        hide_index=True,
    )

elif selected_menu_option == "Semantic search":
    tabs = ["Text to Text", "Text to Image", "Image to Image", "Image to Text"]
    selected_tab = st.sidebar.radio("Select a search type", tabs)

    if selected_tab == "Text to Text":
        st.markdown('## Text to text search')
        text_to_text_input = st.text_input("Enter text")
        text_to_text_k_top = st.slider("Number of results", 1, 500, 20)

        if st.button("Search"):
            if not text_to_text_input:
                st.warning("Please enter text")
            else:
                _show_search_results(text_to_text(text_to_text_input, text_to_text_k_top))

    elif selected_tab == "Text to Image":
        st.markdown('## Text to image search')
        text_to_image_input = st.text_input("Enter text")
        text_to_image_k_top = st.slider("Number of results", 1, 500, 20)

        if st.button("Search"):
            if not text_to_image_input:
                st.warning("Please enter some text")
            else:
                _show_search_results(text_to_image(text_to_image_input, text_to_image_k_top))

    elif selected_tab == "Image to Image":
        st.markdown('## Image to image search')
        image_to_image_k_top = st.slider("Number of results", 1, 500, 20)
        image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

        if st.button("Search"):
            if not image_to_image_input:
                st.warning("Please upload an image")
            else:
                _show_search_results(
                    _search_with_uploaded_image(image_to_image, image_to_image_input, image_to_image_k_top)
                )

    elif selected_tab == "Image to Text":
        st.markdown('## Image to text search')
        image_to_text_k_top = st.slider("Number of results", 1, 500, 20)
        image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

        if st.button("Search"):
            if not image_to_text_input:
                st.warning("Please upload an image")
            else:
                _show_search_results(
                    _search_with_uploaded_image(image_to_text, image_to_text_input, image_to_text_k_top)
                )

elif selected_menu_option == "Hashtags":
    st.markdown("### Hashtag Co-occurrence Analysis")
    st.markdown("This section creates a network of hashtags based on how often they are used together. Use the sidebar to configure the analysis, then click the button to generate the network and identify communities.")

    # --- Sidebar Configuration ---
    # `dfx` is the working copy whose hashtags the user may prune.
    if 'dfx' not in st.session_state:
        st.session_state.dfx = df.copy()

    all_hashtags = sorted({tag for tags in st.session_state.dfx['Hashtags'] for tag in tags})

    st.sidebar.markdown('## Hashtag Network Options')
    hashtags_to_remove = st.sidebar.multiselect("Hashtags to remove", all_hashtags)

    col1, col2 = st.sidebar.columns(2)
    if col1.button("Remove hashtags"):
        st.session_state.dfx['Hashtags'] = st.session_state.dfx['Hashtags'].apply(lambda x: [item for item in x if item not in hashtags_to_remove])
        # Any previously computed network is now out of date.
        if 'hashtag_results' in st.session_state:
            del st.session_state.hashtag_results
        st.rerun()

    if col2.button("Reset Hashtags"):
        st.session_state.dfx = df.copy()
        if 'hashtag_results' in st.session_state:
            del st.session_state.hashtag_results
        st.rerun()

    weight_option = st.sidebar.radio(
        'Select weight definition',
        ('Number of users that use the hashtag pairs', 'Total number of occurrences')
    )

    # --- Main Page Content ---
    if st.button("Generate Hashtag Network", type="primary"):
        with st.spinner("Building graph, filtering edges, and detecting communities..."):
            # One row per (unordered hashtag pair, user) co-occurrence in a post
            hashtag_user_pairs = [
                (tuple(sorted(combination)), userid)
                for hashtags, userid in zip(st.session_state.dfx['Hashtags'], st.session_state.dfx['User Name'])
                for combination in combinations(hashtags, r=2)
            ]
            if not hashtag_user_pairs:
                st.warning("No co-occurring hashtags found, so there is no network to build.")
                st.stop()

            hashtag_user_df = pd.DataFrame(hashtag_user_pairs, columns=['hashtag_pair', 'User Name'])

            if weight_option == 'Number of users that use the hashtag pairs':
                edge_df = hashtag_user_df.groupby('hashtag_pair').agg({'User Name': 'nunique'}).reset_index()
            else:
                edge_df = hashtag_user_df.groupby('hashtag_pair').size().reset_index(name='User Name')

            edge_df = edge_df.rename(columns={'User Name': 'weight'})
            edge_df[['hashtag1', 'hashtag2']] = pd.DataFrame(edge_df['hashtag_pair'].tolist(), index=edge_df.index)
            edge_list = edge_df[['hashtag1', 'hashtag2', 'weight']]

            G = nx.from_pandas_edgelist(edge_list, 'hashtag1', 'hashtag2', 'weight')
            G_backbone = disparity_filter(G, weight='weight', alpha=0.05)
            if G_backbone.number_of_nodes() == 0:
                st.warning("No statistically significant hashtag connections remain after filtering.")
                st.stop()

            # Communities are numbered from largest to smallest.
            communities = list(nx.community.louvain_communities(G_backbone, weight='weight', seed=1234))
            communities.sort(key=len, reverse=True)

            for i, community in enumerate(communities):
                for node in community:
                    G_backbone.nodes[node]['community'] = i

            # One column per community, hashtags ordered by weighted degree in the full graph
            weighted_degree = dict(G.degree(weight='weight'))
            sorted_community_hashtags = pd.DataFrame([
                sorted(com, key=lambda h: weighted_degree[h], reverse=True)
                for com in communities
            ]).T
            sorted_community_hashtags.columns = [f'Community {i+1}' for i in range(len(sorted_community_hashtags.columns))]

            # Initialize the community names dataframe and store it in session state
            df_community_names = pd.DataFrame(
                sorted_community_hashtags.columns,
                columns=['community_names'],
                index=sorted_community_hashtags.columns
            )
            st.session_state.community_names_df = df_community_names

            st.session_state.hashtag_results = {
                "G_backbone": G_backbone,
                "communities": communities,
                "sorted_community_hashtags": sorted_community_hashtags,
                "edge_list": edge_list
            }
        st.rerun()

    # --- Display Results Section ---
    if 'hashtag_results' in st.session_state:
        results = st.session_state.hashtag_results
        G_backbone = results['G_backbone']
        communities = results['communities']
        sorted_community_hashtags = results['sorted_community_hashtags']
        edge_list = results['edge_list']

        st.success(f"Network generated! Found **{len(communities)}** communities from **{len(G_backbone.nodes)}** hashtags and **{len(G_backbone.edges)}** connections.")

        # Define the tabs with the editor in its own tab
        tab_graph, tab_editor, tab_timeline, tab_lists = st.tabs([
            "📊 Network Graph",
            "📝 Edit Community Names",
            "🕒 Community Timelines",
            "📋 Community & Edge Lists"
        ])

        with tab_graph:
            st.markdown("### Hashtag Network Graph")
            st.markdown("Nodes represent hashtags, colored by community. The legend uses the names from the 'Edit Community Names' tab.")

            # Display label -> layout identifier understood by plot_graph
            layout_options = {
                "Spring": "spring",
                "Kamada-Kawai": "kamada_kawai",
                "Circular": "circular",
                "Spectral": "spectral"
            }
            selected_layout_name = st.selectbox(
                "Graph Layout Algorithm",
                options=layout_options.keys()
            )
            layout_alg_str = layout_options[selected_layout_name]

            # Retrieve edited names from session state
            community_names_lookup = st.session_state.community_names_df['community_names'].to_dict()

            fig = plot_graph(
                _G=G_backbone,
                layout_name=layout_alg_str,
                community_names_lookup=community_names_lookup
            )
            st.plotly_chart(fig, use_container_width=True)

        with tab_editor:
            st.markdown("### Edit Community Names")
            st.markdown("Change the default community names in the table below. The new names will automatically update the graph legend and the timeline chart.")

            # The data editor returns the edited copy of the names table
            edited_df = st.data_editor(
                st.session_state.community_names_df,
                use_container_width=True,
                num_rows="dynamic"  # Allows for adding/removing if needed, though less likely here
            )
            # Persist any changes back to session state
            st.session_state.community_names_df = edited_df

            st.download_button(
                label="Download Community Names as CSV",
                data=edited_df.to_csv().encode("utf-8"),
                file_name="community_names.csv",
                mime="text/csv",
            )

        with tab_timeline:
            st.markdown("### Community Size Over Time")

            # Retrieve the latest names from session state for the multiselect options
            community_names_lookup = st.session_state.community_names_df['community_names'].to_dict()

            selected_communities = st.multiselect('Select Communities', community_names_lookup.values(), default=list(community_names_lookup.values()))
            resample_dict = {'Day': 'D', 'Week': 'W', 'Month': 'M', 'Quarter': 'Q', 'Year': 'Y'}
            resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()), index=4)

            # hashtag -> display name of the community it belongs to
            community_dict = {node: community_names_lookup.get(f'Community {i+1}') for i, comm_set in enumerate(communities) for node in comm_set}

            df_communities = st.session_state.dfx.copy()
            # A post counts once per community, however many of its hashtags fall in it.
            df_communities['Communities'] = df_communities['Hashtags'].apply(lambda tags: list(set(community_dict.get(tag) for tag in tags if tag in community_dict)))
            df_communities = df_communities.explode('Communities').dropna(subset=['Communities'])

            df_ts = df_communities.set_index('Post Created')
            df_community_sizes = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'Communities']).size().unstack(fill_value=0)

            existing_selected_cols = [col for col in selected_communities if col in df_community_sizes.columns]
            if existing_selected_cols:
                st.area_chart(df_community_sizes[existing_selected_cols])
            else:
                st.warning("No data available for the selected communities.")

        with tab_lists:
            st.markdown("### Hashtag Communities (by importance)")
            st.dataframe(sorted_community_hashtags)
            st.markdown("### Top Edge Pairs (by weight)")
            st.dataframe(edge_list.sort_values(by='weight', ascending=False).head(100))

elif selected_menu_option == "Clustering":
    st.markdown("## Clustering of Posts")
    st.markdown("This section allows you to group posts based on the similarity of their text or image content. Use the sidebar to configure the clustering process, then click 'Run Clustering' to see the results.")

    # --- Sidebar Configuration ---
    st.sidebar.markdown("# Clustering Options")
    st.sidebar.markdown("### Data & Algorithm")
    type_embeddings = st.sidebar.selectbox("Cluster based on:", ["Image", "Text"])
    clustering_algo = st.sidebar.selectbox("Clustering Algorithm:", ["MiniBatchKMeans", "HDBSCAN", "KMeans"])
    st.sidebar.info("**Tip:** `MiniBatchKMeans` is the fastest for a quick overview.")

    # Parameters that do not apply to the chosen algorithm are passed as None.
    st.sidebar.markdown("### Algorithm Settings")
    if clustering_algo in ["KMeans", "MiniBatchKMeans"]:
        n_clusters = st.sidebar.slider("Number of Clusters (k)", 2, 50, 5, key="n_clusters_slider")
        if clustering_algo == "MiniBatchKMeans":
            batch_size = st.sidebar.slider("Batch Size", 32, 1024, 256, 32, help="Number of samples to use in each mini-batch.")
            max_iter = st.sidebar.slider("Max Iterations", 50, 500, 100, 50, help="Maximum number of iterations.")
        else:
            batch_size, max_iter = None, None
        min_cluster_size, min_samples = None, None
    elif clustering_algo == "HDBSCAN":
        min_cluster_size = st.sidebar.slider("Minimum Cluster Size", 2, 200, 15, help="Smallest size for a group to be a cluster.")
        min_samples = st.sidebar.slider("Minimum Samples", 1, 50, 5, help="Larger values lead to more points being declared as noise.")
        n_clusters, batch_size, max_iter = None, None, None

    st.sidebar.markdown("### Dimensionality Reduction")
    dim_reduction = st.sidebar.selectbox("Reduction Method:", ["PCA", "UMAP"])
    st.sidebar.info("**Tip:** `PCA` is much faster than `UMAP`.")

    if dim_reduction == "UMAP":
        n_components = st.sidebar.slider("Number of Components", 2, 80, 50, help="Dimensions to reduce to before clustering.")
        n_neighbors = st.sidebar.slider("Number of Neighbors", 2, 50, 15, help="Controls UMAP's balance of local/global structure.")
        min_dist = st.sidebar.slider("Minimum Distance", 0.0, 1.0, 0.0, help="Controls how tightly UMAP packs points.")
    else:
        n_components = st.sidebar.slider("Number of Components", 2, 80, 2)
        n_neighbors, min_dist = None, None

    # --- Main Page Content ---

    # 1. The expensive computation only runs on an explicit button press
    if st.button("Run Clustering", type="primary"):
        with st.spinner("Clustering embeddings... This may take a moment."):
            if type_embeddings == "Text":
                embeddings = dataset['text_embs']
            else:  # Image
                embeddings = dataset['img_embs']

            labels, reduced_embeddings = cluster_embeddings(
                embeddings, clustering_algo=clustering_algo, dim_reduction=dim_reduction,
                n_clusters=n_clusters, min_cluster_size=min_cluster_size,
                n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist,
                min_samples=min_samples, batch_size=batch_size, max_iter=max_iter
            )

            # 2. Store the results in session state
            st.session_state['cluster_results'] = {
                "labels": labels,
                "reduced_embeddings": reduced_embeddings,
                "type_embeddings": type_embeddings,
                "clustering_algo": clustering_algo,
                "dim_reduction": dim_reduction
            }
        st.rerun()  # Rerun to display results immediately after calculation

    # 3. Only show results if they exist in session state
    if 'cluster_results' in st.session_state:
        results = st.session_state['cluster_results']
        labels = results['labels']
        reduced_embeddings = results['reduced_embeddings']

        # Label -1 is excluded from the cluster count (treated as noise below).
        num_found_clusters = len(set(labels) - {-1})
        st.success(f"Clustering complete! Found **{num_found_clusters}** clusters using **{results['clustering_algo']}** on **{results['type_embeddings']}** embeddings with **{results['dim_reduction']}** reduction.")

        df_clustered = df.copy()
        df_clustered['cluster'] = labels

        # 4. Use tabs to organize the output
        tab1, tab2, tab3 = st.tabs(["📊 Results Table", "📈 2D Visualization", "🕒 Time Series Analysis"])

        with tab1:
            st.markdown("### Clustered Data")
            st.dataframe(
                data=filter_dataframe(df_clustered),
                column_config=_result_column_config(),
                hide_index=True,
                use_container_width=True
            )
            st.download_button(
                "Download Clustered Data as CSV",
                df_clustered.to_csv(index=False).encode('utf-8'),
                f'clustered_data_{datetime.now().strftime("%Y%m%d-%H%M%S")}.csv',
                "text/csv",
                key='download-cluster-csv'
            )

        with tab2:
            st.markdown("### Cluster Visualization")

            # Clustering may have run in more than two dimensions; project to 2D for plotting.
            if reduced_embeddings.shape[1] > 2:
                with st.spinner("Reducing dimensions for 2D visualization..."):
                    vis_reducer = umap.UMAP(n_components=2, random_state=42)
                    vis_embeddings = vis_reducer.fit_transform(reduced_embeddings)
            else:
                vis_embeddings = reduced_embeddings

            df_plot_bokeh = pd.DataFrame(vis_embeddings, columns=('x', 'y'))
            df_plot_bokeh['description_clean'] = df_clustered['description_clean']
            df_plot_bokeh['image_url'] = df_clustered['image']
            df_plot_bokeh['cluster'] = labels

            unique_labels = sorted(set(labels))
            color_dict = {label: rgb2hex(cc.glasbey_hv[i % len(cc.glasbey_hv)]) for i, label in enumerate(unique_labels)}
            df_plot_bokeh['color'] = df_plot_bokeh['cluster'].map(color_dict)

            source = ColumnDataSource(data=df_plot_bokeh)

            # NOTE(review): the tooltip markup was lost from the source text and is
            # reconstructed here from the fields it referenced (@image_url,
            # @cluster, @description_clean) - confirm the intended design.
            TOOLTIPS = """
            <div>
                <div>
                    <img src="@image_url" height="150" alt="" style="float: left; margin: 0px 15px 15px 0px;" border="2"></img>
                </div>
                <div>
                    <span style="font-size: 12px; font-weight: bold;">Cluster: @cluster</span>
                </div>
                <div>
                    <span style="font-size: 12px;">@description_clean</span>
                </div>
            </div>
            """

            p = figure(width=800, height=800, tooltips=TOOLTIPS, title="2D Visualization of Post Clusters")
            p.circle('x', 'y', size=10, source=source, color='color', legend_field='cluster', line_color=None, alpha=0.8)
            p.legend.title = 'Cluster'
            p.legend.location = "top_left"

            st.bokeh_chart(p, use_container_width=True)

        with tab3:
            st.markdown("### Cluster Analysis Over Time")

            # Display label -> pandas resampling frequency
            resample_dict = {
                'Day': 'D',
                'Week': 'W',
                'Month': 'M',
                'Quarter': 'Q',
                'Year': 'Y'
            }

            variable = st.selectbox('Select Variable for Time Series:', ['Likes', 'Comments', 'Followers at Posting', 'Total Interactions'], key="cluster_ts_var")
            resample_time = st.selectbox('Resample Time By:', list(resample_dict.keys()), index=2, key="cluster_ts_resample")

            df_ts = df_clustered.copy()
            df_ts['Post Created'] = pd.to_datetime(df_ts['Post Created'])
            df_ts = df_ts.set_index('Post Created')
            df_ts = df_ts[df_ts['cluster'] != -1]  # Exclude noise points

            if not df_ts.empty:
                df_plot = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'cluster'])[variable].sum().unstack(fill_value=0)
                st.line_chart(df_plot)
            else:
                st.warning("No data available for plotting (all points may have been classified as noise).")

elif selected_menu_option == "Stats":
    st.markdown("### Time Series Analysis")

    # Dropdown to select variables
    variable = st.selectbox('Select Variable', ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments'])

    # Display label -> pandas resampling frequency
    resample_dict = {
        'Day': 'D',
        'Three Days': '3D',
        'Week': 'W',
        'Two Weeks': '2W',
        'Month': 'M',
        'Quarter': 'Q',
        'Year': 'Y'
    }

    # Dropdown to select time resampling
    resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))

    df_filtered = df.set_index('Post Created')

    # Slider for date range selection
    min_date = df_filtered.index.min().date()
    max_date = df_filtered.index.max().date()

    date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))

    # Filter dataframe based on selected date range (both ends inclusive)
    df_filtered = df_filtered[(df_filtered.index.date >= date_range[0]) & (df_filtered.index.date <= date_range[1])]

    # Create a separate series for resampling and plotting
    df_resampled = df_filtered[variable].resample(resample_dict[resample_time]).sum()

    st.line_chart(df_resampled)

    st.markdown("### Correlation Analysis")

    # Dropdowns to select variables for scatter plot
    options = ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments']
    scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', options)
    scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', options)

    # Plot scatter chart
    st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}")

    scatter_fig = px.scatter(df_filtered, x=scatter_variable_1, y=scatter_variable_2) #, trendline='ols', 
trendline_color_override='red') st.plotly_chart(scatter_fig) # calculate correlation for scatter_variable_1 with scatter_variable_2 corr = df_filtered[scatter_variable_1].corr(df_filtered[scatter_variable_2]) if corr > 0.7: st.write(f"The correlation coefficient is {corr}, indicating a strong positive relationship between {scatter_variable_1} and {scatter_variable_2}.") elif corr > 0.3: st.write(f"The correlation coefficient is {corr}, indicating a moderate positive relationship between {scatter_variable_1} and {scatter_variable_2}.") elif corr > -0.3: st.write(f"The correlation coefficient is {corr}, indicating a weak or no relationship between {scatter_variable_1} and {scatter_variable_2}.") elif corr > -0.7: st.write(f"The correlation coefficient is {corr}, indicating a moderate negative relationship between {scatter_variable_1} and {scatter_variable_2}.") else: st.write(f"The correlation coefficient is {corr}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.")