david-oplatka committed on
Commit
b6fadc7
1 Parent(s): 4afa25e

Add Assistant Files

Browse files
Files changed (8) hide show
  1. .gitignore +143 -0
  2. Dockerfile +25 -0
  3. Vectara-logo.png +0 -0
  4. agent.py +89 -0
  5. app.py +204 -0
  6. create_table.sql +15 -0
  7. requirements.txt +10 -0
  8. utils.py +74 -0
.gitignore ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ crawlers/__pycache__/
6
+ core/__pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ pip-wheel-metadata/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99
+ __pypackages__/
100
+
101
+ # Celery stuff
102
+ celerybeat-schedule
103
+ celerybeat.pid
104
+
105
+ # SageMath parsed files
106
+ *.sage.py
107
+
108
+ # Environments
109
+ .env
110
+ .env*
111
+ .venv
112
+ env/
113
+ venv/
114
+ ENV/
115
+ env.bak/
116
+ venv.bak/
117
+
118
+ # secrets file in TOML format
119
+ secrets.toml
120
+
121
+ # Spyder project settings
122
+ .spyderproject
123
+ .spyproject
124
+
125
+ # Rope project settings
126
+ .ropeproject
127
+
128
+ # mkdocs documentation
129
+ /site
130
+
131
+ # mypy
132
+ .mypy_cache/
133
+ .dmypy.json
134
+ dmypy.json
135
+
136
+ # Pyre type checker
137
+ .pyre/
138
+
139
+ # project file
140
+ project.yaml
141
+
142
+ .idea/
143
+ cfpb_database.db
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /app
4
+
5
+ COPY ./requirements.txt /app/requirements.txt
6
+
7
+ RUN pip3 install --no-cache-dir -r /app/requirements.txt
8
+
9
+ # User
10
+ RUN useradd -m -u 1000 user
11
+ USER user
12
+ ENV HOME /home/user
13
+ ENV PATH $HOME/.local/bin:$PATH
14
+
15
+ WORKDIR $HOME
16
+ RUN mkdir app
17
+ WORKDIR $HOME/app
18
+ COPY . $HOME/app
19
+
20
+ EXPOSE 8501
21
+ CMD streamlit run app.py \
22
+ --server.headless true \
23
+ --server.enableCORS false \
24
+ --server.enableXsrfProtection false \
25
+ --server.fileWatcherType none
Vectara-logo.png ADDED
agent.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+ from pydantic import Field, BaseModel
4
+ from omegaconf import OmegaConf
5
+
6
+ from llama_index.core.utilities.sql_wrapper import SQLDatabase
7
+ from sqlalchemy import create_engine, text
8
+
9
+ from dotenv import load_dotenv
10
+ load_dotenv(override=True)
11
+
12
+ from vectara_agentic.agent import Agent
13
+ from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
14
+
15
+ def create_assistant_tools(cfg):
16
+
17
+ class QueryCFPBComplaints(BaseModel):
18
+ query: str = Field(description="The user query.")
19
+
20
+ vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_keys,
21
+ vectara_customer_id=cfg.customer_id,
22
+ vectara_corpus_id=cfg.corpus_ids)
23
+
24
+ summarizer = 'vectara-experimental-summary-ext-2023-12-11-med-omni'
25
+
26
+ ask_complaints = vec_factory.create_rag_tool(
27
+ tool_name = "ask_complaints",
28
+ tool_description = """
29
+ Given a user query,
30
+ returns a response to a user question about customer complaints about bank services.
31
+ """,
32
+ tool_args_schema = QueryCFPBComplaints,
33
+ reranker = "multilingual_reranker_v1", rerank_k = 100,
34
+ n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
35
+ summary_num_results = 5,
36
+ vectara_summarizer = summarizer,
37
+ include_citations = False,
38
+ )
39
+
40
+ tools_factory = ToolsFactory()
41
+
42
+ db_tools = tools_factory.database_tools(
43
+ tool_name_prefix = "cfpb",
44
+ content_description = 'Customer complaints about five banks (Bank of America, Wells Fargo, Capital One, Chase, and CITI Bank)',
45
+ sql_database = SQLDatabase(create_engine('sqlite:///cfpb_database.db')),
46
+ )
47
+
48
+ return (tools_factory.standard_tools() +
49
+ tools_factory.guardrail_tools() +
50
+ db_tools +
51
+ [ask_complaints]
52
+ )
53
+
54
+ def initialize_agent(_cfg, update_func=None):
55
+ cfpb_complaints_bot_instructions = """
56
+ - You are a helpful research assistant, with expertise in complaints from the Consumer Financial Protection Bureau, in conversation with a user.
57
+ - Before answering any user query, use cfpb_describe_tables to understand the schema of each table, and use cfpb_load_sample_data
58
+ to get sample data from each table in the database, so that you can understand NULL and unique values for each column.
59
+ - For a query with multiple sub-questions, break down the query into the sub-questions,
60
+ and make separate calls to the ask_complaints tool to answer each sub-question,
61
+ then combine the answers to provide a complete response.
62
+ - Use the database tools (cfpb_load_data, cfpb_describe_tables and cfpb_list_tables) to answer analytical queries.
63
+ - IMPORTANT: When using database_tools, always call the cfpb_load_sample_data tool with the table you want to query
64
+ to understand the table structure, column naming, and values in the table. Never call the cfpb_load_data tool for a query until you have called cfpb_load_sample_data.
65
+ - When providing links, try to put the name of the website or source of information for the displayed text. Don't just say 'Source'.
66
+ - Never discuss politics, and always respond politely.
67
+ """
68
+
69
+ agent = Agent(
70
+ tools=create_assistant_tools(_cfg),
71
+ topic="Customer complaints from the Consumer Financial Protection Bureau (CFPB)",
72
+ custom_instructions=cfpb_complaints_bot_instructions,
73
+ update_func=update_func
74
+ )
75
+ agent.report()
76
+ return agent
77
+
78
+
79
+ def get_agent_config() -> OmegaConf:
80
+ cfg = OmegaConf.create({
81
+ 'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
82
+ 'corpus_ids': str(os.environ['VECTARA_CORPUS_IDS']),
83
+ 'api_keys': str(os.environ['VECTARA_API_KEYS']),
84
+ 'examples': os.environ.get('QUERY_EXAMPLES', None),
85
+ 'demo_name': "cfpb-assistant",
86
+ 'demo_welcome': "Welcome to the CFPB Customer Complaints demo.",
87
+ 'demo_description': "This assistant can help you gain insights into customer complaints to banks recorded by the Consumer Financial Protection Bureau.",
88
+ })
89
+ return cfg
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import sys
3
+ import os
4
+ import uuid
5
+
6
+ import streamlit as st
7
+ from streamlit_pills import pills
8
+ from streamlit_feedback import streamlit_feedback
9
+
10
+ import nest_asyncio
11
+ import asyncio
12
+
13
+ from utils import thumbs_feedback, escape_dollars_outside_latex, send_amplitude_data
14
+
15
+ import sqlite3
16
+ from datasets import load_dataset
17
+
18
+ from vectara_agentic.agent import AgentStatusType
19
+ from agent import initialize_agent, get_agent_config
20
+
21
+
22
+ initial_prompt = "How can I help you today?"
23
+
24
+ # Setup for HTTP API Calls to Amplitude Analytics
25
+ if 'device_id' not in st.session_state:
26
+ st.session_state.device_id = str(uuid.uuid4())
27
+
28
+
29
+ if "feedback_key" not in st.session_state:
30
+ st.session_state.feedback_key = 0
31
+
32
+ def toggle_logs():
33
+ st.session_state.show_logs = not st.session_state.show_logs
34
+
35
+ def show_example_questions():
36
+ if len(st.session_state.example_messages) > 0 and st.session_state.first_turn:
37
+ selected_example = pills("Queries to Try:", st.session_state.example_messages, index=None)
38
+ if selected_example:
39
+ st.session_state.ex_prompt = selected_example
40
+ st.session_state.first_turn = False
41
+ return True
42
+ return False
43
+
44
+ def update_func(status_type: AgentStatusType, msg: str):
45
+ if status_type != AgentStatusType.AGENT_UPDATE:
46
+ output = f"{status_type.value} - {msg}"
47
+ st.session_state.log_messages.append(output)
48
+
49
+ async def launch_bot():
50
+ def reset():
51
+ st.session_state.messages = [{"role": "assistant", "content": initial_prompt, "avatar": "🦖"}]
52
+ st.session_state.thinking_message = "Agent at work..."
53
+ st.session_state.log_messages = []
54
+ st.session_state.prompt = None
55
+ st.session_state.ex_prompt = None
56
+ st.session_state.first_turn = True
57
+ st.session_state.show_logs = False
58
+ if 'agent' not in st.session_state:
59
+ st.session_state.agent = initialize_agent(cfg, update_func=update_func)
60
+
61
+ if 'cfg' not in st.session_state:
62
+ cfg = get_agent_config()
63
+ st.session_state.cfg = cfg
64
+ st.session_state.ex_prompt = None
65
+ example_messages = [example.strip() for example in cfg.examples.split(";")] if cfg.examples else []
66
+ st.session_state.example_messages = [em for em in example_messages if len(em)>0]
67
+ reset()
68
+
69
+ cfg = st.session_state.cfg
70
+
71
+ # left side content
72
+ with st.sidebar:
73
+ image = Image.open('Vectara-logo.png')
74
+ st.image(image, width=175)
75
+ st.markdown(f"## {cfg['demo_welcome']}")
76
+ st.markdown(f"{cfg['demo_description']}")
77
+
78
+ st.markdown("\n\n")
79
+ bc1, _ = st.columns([1, 1])
80
+ with bc1:
81
+ if st.button('Start Over'):
82
+ reset()
83
+ st.rerun()
84
+
85
+ st.divider()
86
+ st.markdown(
87
+ "## How does this work?\n"
88
+ "This app was built with [Vectara](https://vectara.com).\n\n"
89
+ "It demonstrates the use of Agentic RAG functionality with Vectara"
90
+ )
91
+
92
+ if "messages" not in st.session_state.keys():
93
+ reset()
94
+
95
+ # Display chat messages
96
+ for message in st.session_state.messages:
97
+ with st.chat_message(message["role"], avatar=message["avatar"]):
98
+ st.write(message["content"])
99
+
100
+ example_container = st.empty()
101
+ with example_container:
102
+ if show_example_questions():
103
+ example_container.empty()
104
+ st.session_state.first_turn = False
105
+ st.rerun()
106
+
107
+ # User-provided prompt
108
+ if st.session_state.ex_prompt:
109
+ prompt = st.session_state.ex_prompt
110
+ else:
111
+ prompt = st.chat_input()
112
+ if prompt:
113
+ st.session_state.messages.append({"role": "user", "content": prompt, "avatar": '🧑‍💻'})
114
+ st.session_state.prompt = prompt # Save the prompt in session state
115
+ st.session_state.log_messages = []
116
+ st.session_state.show_logs = False
117
+ with st.chat_message("user", avatar='🧑‍💻'):
118
+ print(f"Starting new question: {prompt}\n")
119
+ st.write(prompt)
120
+ st.session_state.ex_prompt = None
121
+
122
+ # Generate a new response if last message is not from assistant
123
+ if st.session_state.prompt:
124
+ with st.chat_message("assistant", avatar='🤖'):
125
+ with st.spinner(st.session_state.thinking_message):
126
+ res = st.session_state.agent.chat(st.session_state.prompt)
127
+ res = escape_dollars_outside_latex(res)
128
+ message = {"role": "assistant", "content": res, "avatar": '🤖'}
129
+ st.session_state.messages.append(message)
130
+ st.markdown(res)
131
+
132
+ send_amplitude_data(
133
+ user_query=st.session_state.messages[-2]["content"],
134
+ bot_response=st.session_state.messages[-1]["content"],
135
+ demo_name=cfg['demo_name']
136
+ )
137
+
138
+ st.session_state.ex_prompt = None
139
+ st.session_state.prompt = None
140
+ st.session_state.first_turn = False
141
+ st.rerun()
142
+
143
+ # Record user feedback
144
+ if (st.session_state.messages[-1]["role"] == "assistant") & (st.session_state.messages[-1]["content"] != initial_prompt):
145
+ streamlit_feedback(
146
+ feedback_type="thumbs", on_submit = thumbs_feedback, key = st.session_state.feedback_key,
147
+ kwargs = {"user_query": st.session_state.messages[-2]["content"],
148
+ "bot_response": st.session_state.messages[-1]["content"],
149
+ "demo_name": cfg["demo_name"]}
150
+ )
151
+
152
+ log_placeholder = st.empty()
153
+ with log_placeholder.container():
154
+ if st.session_state.show_logs:
155
+ st.button("Hide Logs", on_click=toggle_logs)
156
+ for msg in st.session_state.log_messages:
157
+ st.text(msg)
158
+ else:
159
+ if len(st.session_state.log_messages) > 0:
160
+ st.button("Show Logs", on_click=toggle_logs)
161
+
162
+ sys.stdout.flush()
163
+
164
+ def setup_db():
165
+ db_path = 'cfpb_database.db'
166
+ conn = sqlite3.connect(db_path)
167
+ cursor = conn.cursor()
168
+
169
+ with st.spinner("Loading data... Please wait..."):
170
+ def table_populated() -> bool:
171
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='cfpb_complaints'")
172
+ result = cursor.fetchone()
173
+ if not result:
174
+ return False
175
+ return True
176
+
177
+ if table_populated():
178
+ print("Database table already populated, skipping setup")
179
+ conn.close()
180
+ return
181
+ else:
182
+ print("Populating database table")
183
+
184
+ # Execute the SQL commands to create the database table
185
+ with open('create_table.sql', 'r') as sql_file:
186
+ sql_script = sql_file.read()
187
+ cursor.executescript(sql_script)
188
+
189
+ hf_token = os.getenv('HF_TOKEN')
190
+
191
+ # Load data into cfpb_complaints table
192
+ df = load_dataset("vectara/cfpb-complaints", data_files="cfpb_complaints.csv", token=hf_token)['train'].to_pandas()
193
+ df.to_sql('cfpb_complaints', conn, if_exists='replace', index=False)
194
+
195
+ # Commit changes and close connection
196
+ conn.commit()
197
+ conn.close()
198
+
199
+ if __name__ == "__main__":
200
+ st.set_page_config(page_title="CFPB Complaints Assistant", layout="wide")
201
+ setup_db()
202
+
203
+ nest_asyncio.apply()
204
+ asyncio.run(launch_bot())
create_table.sql ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CREATE TABLE cfpb_complaints (
2
+ complaint_id INTEGER PRIMARY KEY,
3
+ company VARCHAR(37),
4
+ state VARCHAR(2),
5
+ zip_code INTEGER,
6
+ product VARCHAR(76),
7
+ sub_product VARCHAR(48),
8
+ issue VARCHAR(80),
9
+ sub_issue VARCHAR(145),
10
+ date_submitted TEXT,
11
+ date_received TEXT,
12
+ report_method VARCHAR(12),
13
+ complaint_status VARCHAR(31),
14
+ timely_response INTEGER
15
+ );
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ omegaconf==2.3.0
2
+ python-dotenv==1.0.1
3
+ streamlit==1.32.2
4
+ streamlit_pills==0.3.0
5
+ streamlit-feedback==0.1.3
6
+ langdetect==1.0.9
7
+ langcodes==3.4.0
8
+ datasets==2.19.2
9
+ uuid==1.30
10
+ vectara-agentic==0.1.15
utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import json
4
+ import re
5
+
6
+ import streamlit as st
7
+
8
+ from langdetect import detect_langs
9
+ from langcodes import Language
10
+
11
+ headers = {
12
+ 'Content-Type': 'application/json',
13
+ 'Accept': '*/*'
14
+ }
15
+
16
+ def identify_language(response):
17
+ lang_code = detect_langs(response)[0].lang
18
+ return Language.make(language=lang_code).display_name()
19
+
20
+ def thumbs_feedback(feedback, **kwargs):
21
+ """
22
+ Sends feedback to Amplitude Analytics
23
+ """
24
+ send_amplitude_data(
25
+ user_query=kwargs.get("user_query", "No user input"),
26
+ bot_response=kwargs.get("bot_response", "No bot response"),
27
+ demo_name=kwargs.get("demo_name", "Unknown"),
28
+ feedback=feedback['score'],
29
+ )
30
+ st.session_state.feedback_key += 1
31
+
32
+ def send_amplitude_data(user_query, bot_response, demo_name, feedback=None):
33
+ # Send query and response to Amplitude Analytics
34
+ data = {
35
+ "api_key": os.getenv('AMPLITUDE_TOKEN'),
36
+ "events": [{
37
+ "device_id": st.session_state.device_id,
38
+ "event_type": "submitted_query",
39
+ "event_properties": {
40
+ "Space Name": demo_name,
41
+ "Demo Type": "Agent",
42
+ "query": user_query,
43
+ "response": bot_response,
44
+ "Response Language": identify_language(bot_response)
45
+ }
46
+ }]
47
+ }
48
+ if feedback:
49
+ data["events"][0]["event_properties"]["feedback"] = feedback
50
+
51
+ response = requests.post('https://api2.amplitude.com/2/httpapi', headers=headers, data=json.dumps(data))
52
+ if response.status_code != 200:
53
+ print(f"Amplitude request failed with status code {response.status_code}. Response Text: {response.text}")
54
+
55
+ def escape_dollars_outside_latex(text):
56
+ # Define a regex pattern to find LaTeX equations (double $$ only)
57
+ pattern = r'\$\$.*?\$\$'
58
+ latex_matches = re.findall(pattern, text, re.DOTALL)
59
+
60
+ # Placeholder to temporarily store LaTeX equations
61
+ placeholders = {}
62
+ for i, match in enumerate(latex_matches):
63
+ placeholder = f'__LATEX_PLACEHOLDER_{i}__'
64
+ placeholders[placeholder] = match
65
+ text = text.replace(match, placeholder)
66
+
67
+ # Escape dollar signs in the rest of the text
68
+ text = text.replace('$', '\\$')
69
+
70
+ # Replace placeholders with the original LaTeX equations
71
+ for placeholder, original in placeholders.items():
72
+ text = text.replace(placeholder, original)
73
+ return text
74
+