# search_demo / app.py
# bibliotecadebabel
# added cohere back
# c36cba9
# raw
# history blame
# 5.49 kB
import torch
import src.constants.config as configurations
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from src.constants.credentials import cohere_trial_key, mixedbread_key
import streamlit as st
from src.reader import Reader
from src.utils_search import UtilsSearch
from copy import deepcopy
import numpy as np
import cohere
from mixedbread_ai.client import MixedbreadAI
from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset
# Rebind the imported config module name to the single service configuration
# this app runs with; every later use of `configurations` sees that object.
configurations = configurations.service_mxbai_msc_direct_config
semantic_column_names = configurations["semantic_column_names"]
cross_encoder_name = configurations["cross_encoder_name"]

# Cohere client, used by the search handler to rerank when "Use Cohere" is on.
api_key = cohere_trial_key
co = cohere.Client(api_key)

# Mixedbread client, passed to retrieval and used for the non-Cohere reranker.
model = MixedbreadAI(api_key=mixedbread_key)
@st.cache_data
def init():
    """Load the dataset and build the search index for the app.

    Returns:
        A ``(df, index, search_utils)`` tuple: the dataframe produced by
        ``Reader.read()``, the index that ``UtilsSearch.dataframe_to_index``
        builds from that dataframe, and the ``UtilsSearch`` helper itself.
    """
    # NOTE(review): st.cache_data serializes the return value; if the index or
    # helper turn out not to be picklable, st.cache_resource is the documented
    # alternative -- confirm against the concrete types.
    helper = UtilsSearch(configurations)
    frame = Reader(config=configurations["reader_config"]).read()
    return frame, helper.dataframe_to_index(frame), helper
def get_possible_values_for_column(column_name, search_utils, df):
    """Return the selectable values for ``column_name``, memoised per session.

    The values come from ``search_utils.top_10_common_values(df, column_name)``
    and are kept in ``st.session_state`` under the column name itself, so the
    lookup runs only the first time a column is requested in a session.
    """
    session = st.session_state
    if column_name in session:
        return session[column_name]
    session[column_name] = search_utils.top_10_common_values(df, column_name)
    return session[column_name]
# Run the initialisation only when this session has no stored result yet;
# later reruns of the script reuse what is already in session state.
if 'init_results' not in st.session_state:
    st.session_state['init_results'] = init()

# Unpack the stored objects for the rest of the script.
df, index, search_utils = st.session_state['init_results']
# Streamlit app layout
st.title('Search Demo')

# Input fields: the free-text query and the switch selecting the reranker.
query = st.text_input('Enter your search query here')
use_cohere = st.checkbox('Use Cohere', value=True)  # Default to checked

# Work on a deep copy so the widget-driven values below never modify the
# configuration object itself.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])
dynamic_programmatic_search_config = {
    "scalar_columns": [],
    "discrete_columns": [],
}

# Scalar columns: one minimum and one maximum number input per column, both
# bounded by the configured range and pre-filled with its two extremes.
for column in programmatic_search_config['scalar_columns']:
    col_name = column["column_name"]
    min_val = float(column["min_value"])
    max_val = float(column["max_value"])
    user_min = st.number_input(
        f'Minimum {col_name.capitalize()}',
        min_value=min_val, max_value=max_val, value=min_val,
    )
    user_max = st.number_input(
        f'Maximum {col_name.capitalize()}',
        min_value=min_val, max_value=max_val, value=max_val,
    )
    dynamic_programmatic_search_config['scalar_columns'].append(
        {"column_name": col_name, "min_value": user_min, "max_value": user_max}
    )

# Discrete columns: one multiselect per column.
for column in programmatic_search_config['discrete_columns']:
    col_name = column["column_name"]
    default_values = column["default_values"]
    possible_values = get_possible_values_for_column(col_name, search_utils, df)

    # st.multiselect raises when a default is not among its options, and the
    # fetched values are not guaranteed to contain every configured default,
    # so append any default that is missing (order of fetched values is kept).
    # Assumes possible_values is an iterable of plain values -- TODO confirm.
    options = list(possible_values)
    for value in default_values or []:
        if value not in options:
            options.append(value)

    selected_values = st.multiselect(
        f'Select {col_name.capitalize()}',
        options=options,
        default=default_values,
    )
    dynamic_programmatic_search_config['discrete_columns'].append(
        {"column_name": col_name, "default_values": selected_values}
    )

# Replace the configured filters with the values chosen in the widgets.
for filter_key in ('scalar_columns', 'discrete_columns'):
    programmatic_search_config[filter_key] = dynamic_programmatic_search_config[filter_key]
def _rerank_with_mixedbread(candidates, query_text):
    """Return up to 10 rows of ``candidates`` in Mixedbread reranker order."""
    records = candidates.to_dict(orient='records')
    dataset_str = SchemaStringDataset(records, configurations)
    # Only the first 256 items of each "inputs" value are sent (characters,
    # assuming "inputs" is a string -- TODO confirm).
    documents = [batch["inputs"][:256] for batch in dataset_str]
    res = model.reranking(
        model=cross_encoder_name,
        query=query_text,
        input=documents,
        top_k=10,
        return_input=False,
    )
    ids = [item.index for item in res.data]
    # Positional selection plus an explicit copy, so the caller can add
    # columns to the result without touching `candidates`.
    return candidates.iloc[ids].copy()


def _rerank_with_cohere(candidates, query_text):
    """Return up to 10 rows of ``candidates`` in Cohere reranker order."""
    # Missing values become empty strings before the rows are stringified;
    # the returned rows carry the same replacement.
    candidates = candidates.fillna(value="")
    docs = [
        {name: str(row[name]) for name in semantic_column_names}
        for row in candidates.to_dict('records')
    ]
    rank_fields = list(docs[0].keys())
    results = co.rerank(
        query=query_text,
        documents=docs,
        top_n=10,
        model='rerank-english-v3.0',
        rank_fields=rank_fields,
    )
    top_ids = [hit.index for hit in results.results]
    return candidates.iloc[top_ids].copy()


# Search button
if st.button('Search'):
    if not query.strip():
        # Covers an empty box as well as a whitespace-only entry, which would
        # otherwise be sent to the retrieval API as a real query.
        st.write("Please enter a query to search.")
    else:
        df_retrieved = search_utils.retrieve(query, df, model, index, top_k=1000, api=True)
        df_filtered = search_utils.filter_dataframe(df_retrieved, programmatic_search_config)
        # NOTE(review): an ascending sort followed by [:100] keeps the 100
        # smallest 'similarities' values; that is the right subset only if the
        # column holds a distance -- confirm against UtilsSearch.retrieve.
        df_filtered = df_filtered.sort_values(by='similarities', ascending=True)
        # reset_index makes row labels equal to positions, which is what the
        # rerankers' returned indices refer to.
        df_filtered = df_filtered[:100].reset_index(drop=True)

        if df_filtered.empty:
            st.write('No results found')
        else:
            rerank = _rerank_with_cohere if use_cohere else _rerank_with_mixedbread
            results_df = rerank(df_filtered, query)
            # 1-based rank following the reranker's ordering.
            results_df['rank'] = np.arange(len(results_df)) + 1
            results_df = search_utils.drop_columns(results_df, programmatic_search_config)
            st.write(results_df)