# Sunirmala's picture
# Update app.py
# 1c30404 verified
import streamlit as st
import os
import pandas as pd
import matplotlib.pyplot as plt
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import ConversationalRetrievalChain
def add_vertical_space(spaces=1):
    """Emit ``spaces`` separators into the Streamlit sidebar.

    Note that each separator is a Markdown thematic break (``---``), so the
    sidebar receives a visible horizontal rule per unit rather than blank
    padding, despite the function's name.

    Parameters:
        spaces: how many separators to render; zero or negative renders none.
    """
    sidebar = st.sidebar
    for _separator_index in range(spaces):
        sidebar.markdown("---")
def plot_histogram(df):
    """Render a histogram of one user-selected numeric column of ``df``.

    Only float64/int64 columns are offered in the picker. Missing values are
    dropped before binning, and the figure is closed after it is handed to
    Streamlit so repeated reruns do not accumulate open matplotlib figures.

    Parameters:
        df: the uploaded data as a pandas DataFrame.
    """
    numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    if not numeric_columns:
        st.write("No numeric columns available for plotting.")
        return

    selected_column = st.selectbox('Select column for histogram', numeric_columns, key='hist_col')
    # Blank CSV cells load as NaN; with NaN present the automatic bin range
    # computed for the histogram is not finite, so drop them first.
    values = df[selected_column].dropna()
    if values.empty:
        st.write(f"Column '{selected_column}' has no non-missing values to plot.")
        return

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.hist(values, bins=20, alpha=0.75)
        ax.set_title(f'Distribution of {selected_column}')
        ax.set_xlabel(selected_column)
        ax.set_ylabel('Count')
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until explicitly closed; without
        # this, each Streamlit rerun would leak another figure.
        plt.close(fig)
def plot_scatter(df):
    """Render a scatter plot of two user-selected numeric columns of ``df``.

    Only float64/int64 columns are offered. At least two such columns are
    required; otherwise an explanatory message is shown instead.

    Parameters:
        df: the uploaded data as a pandas DataFrame.
    """
    numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    if len(numeric_columns) < 2:
        st.write("Need at least two numeric columns to create a scatter plot.")
        return

    col1 = st.selectbox('Select the first variable', numeric_columns, key='first_col')
    # Default the second picker to a different column so the initial plot is
    # not a column against itself.
    col2 = st.selectbox('Select the second variable', numeric_columns, index=1, key='second_col')

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.scatter(df[col1], df[col2], alpha=0.5)
        ax.set_title(f'Scatter Plot of {col1} vs {col2}')
        ax.set_xlabel(col1)
        ax.set_ylabel(col2)
        st.pyplot(fig)
    finally:
        # Release the figure so Streamlit reruns do not accumulate open figures.
        plt.close(fig)
def plot_line(df):
    """Render a line plot of one user-selected numeric column against the row index.

    Only float64/int64 columns are offered in the picker.

    Parameters:
        df: the uploaded data as a pandas DataFrame.
    """
    numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    if not numeric_columns:
        st.write("No numeric columns available for a line plot.")
        return

    selected_column = st.selectbox('Select column for line plot', numeric_columns, key='line_col')
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(df[selected_column], marker='o', linestyle='-')
        ax.set_title(f'Line Plot of {selected_column}')
        ax.set_xlabel('Index')
        ax.set_ylabel(selected_column)
        st.pyplot(fig)
    finally:
        # Release the figure so Streamlit reruns do not accumulate open figures.
        plt.close(fig)
def plot_bar(df):
    """Render a bar chart of a numeric column summed per category.

    The user picks one object-dtype (text) column as the category and one
    float64/int64 column as the value; bars are sorted from largest to
    smallest sum.

    Parameters:
        df: the uploaded data as a pandas DataFrame.
    """
    categorical_columns = df.select_dtypes(include=['object']).columns.tolist()
    numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    if not (categorical_columns and numeric_columns):
        st.write("No suitable columns available for plotting a bar chart.")
        return

    category_col = st.selectbox('Select the category column', categorical_columns, key='cat_col')
    numeric_col = st.selectbox('Select the numeric column', numeric_columns, key='num_col')
    data_to_plot = df.groupby(category_col)[numeric_col].sum().sort_values(ascending=False)
    if data_to_plot.empty:
        # groupby drops NaN keys, so an all-missing category column yields no
        # groups, and pandas refuses to bar-plot an empty Series.
        st.write("No suitable columns available for plotting a bar chart.")
        return

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        # Draw on our own axes rather than pyplot's implicit "current" figure.
        data_to_plot.plot(kind='bar', ax=ax)
        ax.set_title(f'Bar Chart of {numeric_col} by {category_col}')
        ax.set_xlabel(category_col)
        ax.set_ylabel(f'Sum of {numeric_col}')
        st.pyplot(fig)
    finally:
        # Release the figure so Streamlit reruns do not accumulate open figures.
        plt.close(fig)
def main():
    """Configure the page and dispatch to the mode chosen in the sidebar.

    Two modes are offered: a question-answering chat over an uploaded CSV
    (run_llama_chatbot, which loads a Falcon-7B model) and an interactive
    plotting view (data_visualization).
    """
    # set_page_config must be the first Streamlit call of the script run.
    st.set_page_config(page_title="Falcon 7B CSV Chatbot", layout="wide")
    st.title("Falcon 7B CSV Chatbot")
    st.sidebar.title("Navigation")

    # The chat mode runs the Falcon-7B model (not Llama-2), so label it to
    # match the page title and the sidebar note below.
    modes = {
        "Chat with Falcon-7B": run_llama_chatbot,
        "Data Visualization": data_visualization,
    }
    app_mode = st.sidebar.selectbox("Choose the app mode", list(modes))
    modes[app_mode]()

    st.sidebar.markdown('''The Falcon 7B CSV Chatbot uses the **Falcon-7B-GGML** model.''')
@st.cache_resource
def _load_falcon_llm():
    """Load the local Falcon-7B GGML model once and reuse it across reruns.

    NOTE(review): the cached instance is shared by all sessions on this
    server; concurrent generation through one CTransformers object has not
    been verified as safe.
    """
    return CTransformers(model="models/falcon-7b-instruct.ggccv1.q4_0.bin",
                         model_type="falcon",
                         max_new_tokens=512,
                         temperature=0.1)


@st.cache_resource
def _load_embeddings():
    """Load the sentence-transformers embedding model once and reuse it."""
    return HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')


@st.cache_resource
def _build_vector_store(csv_bytes, temp_dir, db_path):
    """Chunk the CSV rows and index them in a FAISS store.

    Cached on the file content, so Streamlit reruns (every widget change,
    including typing the query) reuse the index instead of re-embedding.

    Parameters:
        csv_bytes: raw bytes of the uploaded CSV.
        temp_dir: existing directory for the short-lived on-disk copy that
            CSVLoader reads from.
        db_path: where the freshly built index is saved via ``save_local``.

    Returns:
        ``(docsearch, n_chunks)``; ``docsearch`` is None when the CSV yields
        no chunks, because FAISS cannot be built from an empty document list.
    """
    import tempfile  # local import keeps the file-level import block unchanged

    # A unique generated name avoids collisions between sessions uploading
    # same-named files and keeps the client-supplied filename out of the path.
    fd, file_path = tempfile.mkstemp(suffix=".csv", dir=temp_dir)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(csv_bytes)
        # "utf-8-sig" decodes plain UTF-8 identically and also strips a
        # leading BOM, which would otherwise end up in the first header name.
        loader = CSVLoader(file_path=file_path, encoding="utf-8-sig", csv_args={'delimiter': ','})
        data = loader.load()
    finally:
        # Remove the temp copy even when loading fails.
        os.remove(file_path)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
    text_chunks = text_splitter.split_documents(data)
    if not text_chunks:
        return None, 0

    docsearch = FAISS.from_documents(text_chunks, _load_embeddings())
    docsearch.save_local(db_path)
    return docsearch, len(text_chunks)


def run_llama_chatbot():
    """Answer free-text questions about an uploaded CSV with a retrieval chain.

    The CSV rows are chunked, embedded and indexed in FAISS; each question is
    answered by a ConversationalRetrievalChain over that index using the
    local Falcon-7B model. Expensive resources are cached across reruns.
    """
    DB_FAISS_PATH = "vectorstore/db_faiss"
    TEMP_DIR = "temp"
    os.makedirs(TEMP_DIR, exist_ok=True)

    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'], help="Upload a CSV file")
    add_vertical_space(1)
    st.sidebar.markdown('Made by Sunirmala Mohanta')
    if uploaded_file is None:
        return

    st.write(f"Uploaded file: {uploaded_file.name}")
    st.write("Processing CSV file...")
    docsearch, n_chunks = _build_vector_store(uploaded_file.getvalue(), TEMP_DIR, DB_FAISS_PATH)
    st.write(f"Total text chunks: {n_chunks}")
    if docsearch is None:
        st.warning("The uploaded CSV contains no rows to index.")
        return

    st.write("### Enter your query:")
    query = st.text_input("Input Prompt:")
    if query:
        with st.spinner("Processing your question..."):
            qa = ConversationalRetrievalChain.from_llm(_load_falcon_llm(),
                                                       retriever=docsearch.as_retriever())
            # No history is carried between questions: each query is answered standalone.
            chat_history = []
            result = qa({"question": query, "chat_history": chat_history})
        st.write("---")
        st.write("### Response:")
        st.write(f"> {result['answer']}")
def data_visualization():
    """Let the user upload a CSV, preview it, and draw one of four plot types.

    Unreadable uploads (empty, malformed, or not UTF-8) are reported with an
    error message instead of crashing the page.
    """
    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'])
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the uploaded CSV file: {err}")
        return

    st.write("Uploaded file preview:")
    st.write(df.head())

    # Radio label -> renderer; dict order defines the order of the options.
    plotters = {
        'Histogram': plot_histogram,
        'Scatter Plot': plot_scatter,
        'Line Plot': plot_line,
        'Bar Chart': plot_bar,
    }
    plot_type = st.sidebar.radio("Choose a type of plot:", tuple(plotters))
    plotters[plot_type](df)
if __name__ == "__main__":
    # `streamlit run` executes this script as __main__ on every rerun.
    main()