File size: 6,247 Bytes
b0b3c9d
 
84f0d2b
 
b0b3c9d
 
 
 
 
 
 
84f0d2b
b0b3c9d
 
 
 
84f0d2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b3c9d
84f0d2b
 
b0b3c9d
84f0d2b
 
 
b0b3c9d
84f0d2b
 
 
 
 
 
 
 
 
b0b3c9d
 
 
 
952c726
b0b3c9d
 
 
 
1c30404
b0b3c9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f11963a
 
 
 
b0b3c9d
952c726
b0b3c9d
 
 
 
 
 
bf2f661
952c726
 
 
 
b0b3c9d
 
84f0d2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b3c9d
952c726
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import streamlit as st
import os
import pandas as pd
import matplotlib.pyplot as plt
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import ConversationalRetrievalChain


def add_vertical_space(spaces=1):
    """Emit ``spaces`` markdown horizontal rules (``---``) in the sidebar.

    Args:
        spaces: How many rules to add. Zero or a negative number adds none.
    """
    repetitions = range(spaces)
    for _unused in repetitions:
        st.sidebar.markdown("---")

def plot_histogram(df):
    """Render a histogram of a user-selected numeric column.

    Shows a select box listing the numeric columns of ``df`` and draws a
    20-bin histogram of the chosen column. When ``df`` has no numeric
    columns a short message is written instead.

    Args:
        df: DataFrame holding the data to plot.
    """
    # "number" matches every numeric dtype (int32, float32, ...), not only
    # the 64-bit ones.
    numeric_columns = df.select_dtypes(include='number').columns.tolist()
    if not numeric_columns:
        st.write("No numeric columns available for plotting.")
        return

    selected_column = st.selectbox('Select column for histogram', numeric_columns, key='hist_col')

    # Missing values carry no information for a histogram; drop them so they
    # cannot upset the bin-range computation.
    values = df[selected_column].dropna()

    # Draw on an explicit figure instead of pyplot's global state, and close
    # it afterwards: the script is re-executed on every widget interaction,
    # so unclosed figures would pile up in memory.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.hist(values, bins=20, alpha=0.75)
        ax.set_title(f'Distribution of {selected_column}')
        st.pyplot(fig)
    finally:
        plt.close(fig)

def plot_scatter(df):
    """Render a scatter plot of two user-selected numeric columns.

    Two select boxes let the user pick the x and y variables from the
    numeric columns of ``df``. When fewer than two numeric columns exist a
    short message is written instead.

    Args:
        df: DataFrame holding the data to plot.
    """
    # "number" matches every numeric dtype (int32, float32, ...), not only
    # the 64-bit ones.
    numeric_columns = df.select_dtypes(include='number').columns.tolist()
    if len(numeric_columns) < 2:
        st.write("Need at least two numeric columns to create a scatter plot.")
        return

    col1 = st.selectbox('Select the first variable', numeric_columns, key='first_col')
    col2 = st.selectbox('Select the second variable', numeric_columns, key='second_col')

    # Draw on an explicit figure instead of pyplot's global state, and close
    # it afterwards: the script is re-executed on every widget interaction,
    # so unclosed figures would pile up in memory.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.scatter(df[col1], df[col2], alpha=0.5)
        ax.set_title(f'Scatter Plot of {col1} vs {col2}')
        ax.set_xlabel(col1)
        ax.set_ylabel(col2)
        st.pyplot(fig)
    finally:
        plt.close(fig)
        
def plot_line(df):
    """Render a line plot of a user-selected numeric column against the index.

    Shows a select box listing the numeric columns of ``df`` and plots the
    chosen column with circle markers. When ``df`` has no numeric columns a
    short message is written instead.

    Args:
        df: DataFrame holding the data to plot.
    """
    # "number" matches every numeric dtype (int32, float32, ...), not only
    # the 64-bit ones.
    numeric_columns = df.select_dtypes(include='number').columns.tolist()
    if not numeric_columns:
        st.write("No numeric columns available for a line plot.")
        return

    selected_column = st.selectbox('Select column for line plot', numeric_columns, key='line_col')

    # Draw on an explicit figure instead of pyplot's global state, and close
    # it afterwards: the script is re-executed on every widget interaction,
    # so unclosed figures would pile up in memory.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(df[selected_column], marker='o', linestyle='-')
        ax.set_title(f'Line Plot of {selected_column}')
        ax.set_xlabel('Index')
        ax.set_ylabel(selected_column)
        st.pyplot(fig)
    finally:
        plt.close(fig)

def plot_bar(df):
    """Render a bar chart of a numeric column summed per category.

    The user picks one ``object``-dtype column as the category and one
    numeric column as the value. The values are summed per category and the
    totals are drawn in descending order. When either kind of column is
    missing a short message is written instead.

    Args:
        df: DataFrame holding the data to plot.
    """
    categorical_columns = df.select_dtypes(include=['object']).columns.tolist()
    # "number" matches every numeric dtype (int32, float32, ...), not only
    # the 64-bit ones.
    numeric_columns = df.select_dtypes(include='number').columns.tolist()
    if not (categorical_columns and numeric_columns):
        st.write("No suitable columns available for plotting a bar chart.")
        return

    category_col = st.selectbox('Select the category column', categorical_columns, key='cat_col')
    numeric_col = st.selectbox('Select the numeric column', numeric_columns, key='num_col')
    data_to_plot = df.groupby(category_col)[numeric_col].sum().sort_values(ascending=False)

    # groupby drops missing keys, so a category column without any usable
    # value leaves nothing to draw.
    if data_to_plot.empty:
        st.write("No suitable columns available for plotting a bar chart.")
        return

    # Draw on an explicit figure instead of pyplot's global state, and close
    # it afterwards: the script is re-executed on every widget interaction,
    # so unclosed figures would pile up in memory.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        data_to_plot.plot(kind='bar', ax=ax)
        ax.set_title(f'Bar Chart of {numeric_col} by {category_col}')
        ax.set_xlabel(category_col)
        ax.set_ylabel(f'Sum of {numeric_col}')
        st.pyplot(fig)
    finally:
        plt.close(fig)

def main():
    """Configure the page and run the app mode chosen in the sidebar.

    The sidebar offers two modes: chatting about an uploaded CSV and
    plotting an uploaded CSV. After the selected mode has rendered, a short
    note about the model is appended to the sidebar.
    """
    st.set_page_config(page_title="Falcon 7B CSV Chatbot", layout="wide")
    st.title("Falcon 7B CSV Chatbot")

    # One table holds both the labels shown to the user and the handlers, so
    # the two cannot drift apart. The chat label names Falcon-7B because
    # that is the model the chat mode loads and the sidebar note advertises
    # (it previously read "Llama-2").
    modes = {
        "Chat with Falcon-7B": run_llama_chatbot,
        "Data Visualization": data_visualization,
    }

    st.sidebar.title("Navigation")
    app_mode = st.sidebar.selectbox("Choose the app mode", list(modes))

    handler = modes.get(app_mode)
    if handler is not None:
        handler()

    st.sidebar.markdown('''The Falcon 7B CSV Chatbot uses the **Falcon-7B-GGML** model.''')
    
def _build_qa_chain(csv_path, index_path):
    """Index the CSV at ``csv_path`` and wrap it in a retrieval chain.

    The CSV is loaded, split into chunks, embedded and stored in a FAISS
    index that is also saved under ``index_path``.

    Args:
        csv_path: Path of the CSV file to load.
        index_path: Directory the FAISS index is saved to.

    Returns:
        A ``(chain, chunk_count)`` tuple. ``chain`` is ``None`` when the CSV
        produced no text chunks, in which case nothing is indexed and the
        model is not loaded.
    """
    loader = CSVLoader(file_path=csv_path, encoding="utf-8", csv_args={'delimiter': ','})
    data = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
    text_chunks = text_splitter.split_documents(data)
    if not text_chunks:
        return None, 0

    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    docsearch = FAISS.from_documents(text_chunks, embeddings)
    docsearch.save_local(index_path)

    llm = CTransformers(model="models/falcon-7b-instruct.ggccv1.q4_0.bin",
                        model_type="falcon",
                        max_new_tokens=512,
                        temperature=0.1)

    chain = ConversationalRetrievalChain.from_llm(llm, retriever=docsearch.as_retriever())
    return chain, len(text_chunks)


def run_llama_chatbot():
    """Run the chat mode: upload a CSV in the sidebar and ask questions about it.

    The uploaded file is written to a temporary location, indexed, and then
    removed again. The resulting retrieval chain is kept in
    ``st.session_state`` and reused until a different file is uploaded, so
    typing a question does not rebuild the index or reload the model.
    """
    import hashlib

    DB_FAISS_PATH = "vectorstore/db_faiss"
    TEMP_DIR = "temp"
    CHAIN_STATE_KEY = "_csv_qa_chain"

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(TEMP_DIR, exist_ok=True)

    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'], help="Upload a CSV file")

    add_vertical_space(1)
    st.sidebar.markdown('Made by Sunirmala Mohanta')

    if uploaded_file is None:
        return

    payload = uploaded_file.getvalue()
    # The file name comes from the client; keep only its last component so
    # the temporary file cannot be written outside TEMP_DIR.
    safe_name = os.path.basename(uploaded_file.name)
    fingerprint = (safe_name, hashlib.sha256(payload).hexdigest())

    st.write(f"Uploaded file: {uploaded_file.name}")

    cached = st.session_state[CHAIN_STATE_KEY] if CHAIN_STATE_KEY in st.session_state else None
    if cached is None or cached[0] != fingerprint:
        st.write("Processing CSV file...")
        file_path = os.path.join(TEMP_DIR, safe_name)
        try:
            with open(file_path, "wb") as f:
                f.write(payload)
            qa, chunk_count = _build_qa_chain(file_path, DB_FAISS_PATH)
        finally:
            # Remove the temporary copy even when loading or indexing fails.
            if os.path.exists(file_path):
                os.remove(file_path)
        cached = (fingerprint, qa, chunk_count)
        st.session_state[CHAIN_STATE_KEY] = cached

    _, qa, chunk_count = cached
    st.write(f"Total text chunks: {chunk_count}")

    if qa is None:
        st.warning("The uploaded CSV file contains no rows to index.")
        return

    st.write("### Enter your query:")
    query = st.text_input("Input Prompt:")
    if query:
        with st.spinner("Processing your question..."):
            # NOTE(review): the history is always empty, so every question is
            # answered without conversational context. Kept as in the
            # original; persisting it would trigger the chain's question
            # rephrasing step and change latency and answers.
            chat_history = []
            result = qa({"question": query, "chat_history": chat_history})
            st.write("---")
            st.write("### Response:")
            st.write(f"> {result['answer']}")


def data_visualization():
    """Run the visualization mode: upload a CSV and plot it.

    After a CSV is uploaded in the sidebar, its first rows are shown and a
    radio button selects which kind of plot to render. A file that cannot be
    parsed is reported with an error message instead of an exception.
    """
    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'])
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        # Empty, malformed or non-UTF-8 uploads are user errors, not crashes.
        st.error(f"Could not read the uploaded CSV file: {exc}")
        return

    st.write("Uploaded file preview:")
    st.write(df.head())

    # Labels shown to the user, mapped to the function that draws the plot.
    plotters = {
        'Histogram': plot_histogram,
        'Scatter Plot': plot_scatter,
        'Line Plot': plot_line,
        'Bar Chart': plot_bar,
    }
    plot_type = st.sidebar.radio("Choose a type of plot:", tuple(plotters))

    plotter = plotters.get(plot_type)
    if plotter is not None:
        plotter(df)

# Entry point: start the app only when this file is run as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()