AkashMnd commited on
Commit
119e4cc
1 Parent(s): 09ec914

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +341 -0
app.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ import gradio as gr
5
+ import threading
6
+ import time
7
+ import PyPDF2
8
+ import chromadb
9
+ import shutil
10
+ from pydantic import BaseModel, Field
11
+ from typing import Dict
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain_huggingface import HuggingFaceEmbeddings
14
+
15
+
16
+ API_KEY = os.getenv("mistral")
17
+ BASE_URL = "https://api.together.xyz"
18
+
19
+ # Store user inputs
20
+ user_inputs = {
21
+ "organization": "",
22
+ "rules_l1": "",
23
+ "rules_l2": "",
24
+ "rules_l3": "",
25
+ }
26
+
27
+ # Function to classify query
28
+ def classify_query(query: str) -> Dict:
29
+ if not all(user_inputs.values()):
30
+ raise ValueError("Please fill all input fields first.")
31
+
32
+ messages = [
33
+ {"role": "system", "content": f"""You are a Customer Query Classification Agent for {user_inputs["organization"]}.
34
+ What is considered Level 1 Query (Requires no account info just provided documents by the admin is enough to answer):
35
+ {user_inputs["rules_l1"]}
36
+ What is considered Level 2 Query (Requires account info and provided documents by the admin is enough to answer):
37
+ {user_inputs["rules_l2"]}
38
+ What is considered as Level 3 Query (Immediate Escalation to Human Customer Service Agents):
39
+ {user_inputs["rules_l3"]}
40
+ Classify the following customer query and provide the output in JSON format:
41
+ ```json
42
+ {{
43
+ "title": "title of the query in under 10 words",
44
+ "level": "1 or 2 or 3"
45
+ }}
46
+ ```"""},
47
+
48
+ {"role": "user", "content": query}
49
+ ]
50
+
51
+ headers = {
52
+ "Content-Type": "application/json",
53
+ "Authorization": f"Bearer {API_KEY}"
54
+ }
55
+
56
+ data = {
57
+ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
58
+ "messages": messages,
59
+ "temperature": 0.7,
60
+ "response_format": {
61
+ "type": "json_object",
62
+ "schema": {
63
+ "type": "object",
64
+ "properties": {
65
+ "title": {"type": "string"},
66
+ "level": {"type": "integer"}
67
+ },
68
+ "required": ["title", "level"]
69
+ }
70
+ }
71
+ }
72
+
73
+ response = requests.post(f"{BASE_URL}/chat/completions", headers=headers, json=data)
74
+ response.raise_for_status()
75
+ classification_result = response.json().get('choices')[0].get('message').get('content')
76
+ return classification_result
77
+
78
+ # Function to convert PDF to text
79
+ def pdf_to_text(file_path):
80
+ pdf_file = open(file_path, 'rb')
81
+ pdf_reader = PyPDF2.PdfReader(pdf_file)
82
+ text = ""
83
+ for page_num in range(len(pdf_reader.pages)):
84
+ text += pdf_reader.pages[page_num].extract_text()
85
+ pdf_file.close()
86
+ return text
87
+
88
+ # Function to handle file upload and save embeddings to ChromaDB
89
+ def handle_file_upload(files, collection_name):
90
+ if not collection_name:
91
+ return "Please provide a collection name."
92
+
93
+ os.makedirs('chabot_pdfs', exist_ok=True)
94
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
95
+ embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
96
+
97
+ # Initialize Chroma DB client
98
+ client = chromadb.PersistentClient(path="./db")
99
+ try:
100
+ collection = client.create_collection(name=collection_name)
101
+ except ValueError as e:
102
+ return f"Error creating collection: {str(e)}. Please try a different collection name."
103
+
104
+ for file in files:
105
+ file_name = os.path.basename(file.name)
106
+ file_path = os.path.join('chabot_pdfs', file_name)
107
+ shutil.copy(file.name, file_path) # Copy the file instead of saving
108
+ text = pdf_to_text(file_path)
109
+ chunks = text_splitter.split_text(text)
110
+
111
+ documents_list = []
112
+ embeddings_list = []
113
+ ids_list = []
114
+
115
+ for i, chunk in enumerate(chunks):
116
+ vector = embeddings.embed_query(chunk)
117
+ documents_list.append(chunk)
118
+ embeddings_list.append(vector)
119
+ ids_list.append(f"{file_name}_{i}")
120
+
121
+ collection.add(
122
+ embeddings=embeddings_list,
123
+ documents=documents_list,
124
+ ids=ids_list
125
+ )
126
+
127
+ return "Files uploaded and processed successfully."
128
+
129
+ # Function to search vector database
130
+ def search_vector_database(query, collection_name):
131
+ if not collection_name:
132
+ return "Please provide a collection name."
133
+
134
+ embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
135
+ client = chromadb.PersistentClient(path="./db")
136
+ try:
137
+ collection = client.get_collection(name=collection_name)
138
+ except ValueError as e:
139
+ return f"Error accessing collection: {str(e)}. Make sure the collection name is correct."
140
+
141
+ query_vector = embeddings.embed_query(query)
142
+ results = collection.query(query_embeddings=[query_vector], n_results=2, include=["documents"])
143
+
144
+ return "\n\n".join("\n".join(result) for result in results["documents"])
145
+
146
+ # New function to handle login
147
+ def handle_login(username, password):
148
+ # This is a simple example. In a real application, you'd want to use secure authentication methods.
149
+ if username == "admin" and password == "password":
150
+ return """
151
+ "NeoBank": {
152
+ "user_id": "NB782940",
153
+ "user_name": "john_doe123",
154
+ "full_name": "John Doe",
155
+ "email": "john.doe@example.com",
156
+ "balance": 2875.43,
157
+ "transactions": [
158
+ {"date": "2024-06-20", "description": "Coffee Shop", "amount": -4.50},
159
+ {"date": "2024-06-19", "description": "Grocery Store", "amount": -85.22},
160
+ {"date": "2024-06-18", "description": "Salary Deposit", "amount": 2500.00}
161
+ ]
162
+ },
163
+ "CryptoInvest": {
164
+ "user_id": "CI549217",
165
+ "user_name": "crypto_enthusiast",
166
+ "full_name": "Alice Johnson",
167
+ "email": "alice.johnson@example.com",
168
+ "portfolio": {
169
+ "BTC": {"amount": 0.025, "value": 7500.00},
170
+ "ETH": {"amount": 1.2, "value": 2100.00},
171
+ "SOL": {"amount": 5.8, "value": 450.50}
172
+ },
173
+ "transactions": [
174
+ {"date": "2024-06-22", "description": "Bought ETH", "amount": -500.00},
175
+ {"date": "2024-06-20", "description": "Sold BTC", "amount": 1200.00}
176
+ ]
177
+ },
178
+ "RoboAdvisor": {
179
+ "user_id": "RA385712",
180
+ "user_name": "jane_smith",
181
+ "full_name": "Jane Smith",
182
+ "email": "jane.smith@example.com",
183
+ "risk_tolerance": "moderate",
184
+ "portfolio_value": 15800.75,
185
+ "allocations": {
186
+ "stocks": 0.60,
187
+ "bonds": 0.30,
188
+ "real_estate": 0.10
189
+ },
190
+ "recent_activity": [
191
+ {"date": "2024-06-21", "description": "Dividends received", "amount": 32.50},
192
+ {"date": "2024-06-15", "description": "Portfolio rebalanced" }
193
+ ]
194
+ },
195
+ "PeerLend": {
196
+ "user_id": "PL916350",
197
+ "user_name": "bob_williams",
198
+ "full_name": "Bob Williams",
199
+ "email": "bob.williams@example.com",
200
+ "account_type": "borrower",
201
+ "loan_amount": 5000.00,
202
+ "interest_rate": 7.8,
203
+ "monthly_payment": 150.30,
204
+ "payment_history": [
205
+ {"date": "2024-06-22", "status": "paid"},
206
+ {"date": "2024-05-22", "status": "paid"},
207
+ {"date": "2024-04-22", "status": "paid"}
208
+ ]
209
+ },
210
+ "InsureTech": {
211
+ "user_id": "IT264805",
212
+ "user_name": "eva_brown4",
213
+ "full_name": "Eva Brown",
214
+ "email": "eva.brown@example.com",
215
+ "policy_type": "auto",
216
+ "coverage_details": {
217
+ "liability": "50/100/50",
218
+ "collision": "500 deductible",
219
+ "comprehensive": "100 deductible"
220
+ },
221
+ "premium": 85.50,
222
+ "next_payment": "2024-07-10",
223
+ "claims": []
224
+ }
225
+ """
226
+ else:
227
+ return "Invalid username or password"
228
+
229
+ # Gradio interface
230
+ def gradio_interface():
231
+ with gr.Blocks(theme='gl198976/The-Rounded') as interface:
232
+ gr.Markdown("# Admin Dashboard🧖🏻‍♀️")
233
+
234
+ with gr.Tab("Query Classifier Agent"):
235
+ with gr.Row():
236
+ with gr.Column():
237
+ organization_input = gr.Textbox(label="Organization Name")
238
+ rules_l1_input = gr.Textbox(label="Rules for Level 1 Query", lines=5)
239
+ rules_l2_input = gr.Textbox(label="Rules for Level 2 Query", lines=5)
240
+ rules_l3_input = gr.Textbox(label="Rules for Level 3 Query", lines=5)
241
+ submit_btn = gr.Button("Submit Rules")
242
+ with gr.Column():
243
+ query_input = gr.Textbox(label="Customer Query")
244
+ classification_output = gr.Textbox(label="Classification Result")
245
+ classify_btn = gr.Button("Classify Query")
246
+ api_details = gr.Markdown("""
247
+ ### API Endpoint Details
248
+ - **URL:** `http://0.0.0.0:7860/classify`
249
+ - **Method:** POST
250
+ - **Request Body:** JSON with a single key `query`
251
+ - **Example Usage:**
252
+ ```python
253
+ from gradio_client import Client
254
+
255
+ client = Client("http://0.0.0.0:7860/")
256
+ result = client.predict(
257
+ "Hello!!", # str in 'Customer Query' Textbox component
258
+ api_name="/classify_and_display"
259
+ )
260
+ print(result)
261
+ ```
262
+ """)
263
+
264
+ submit_btn.click(lambda org, r1, r2, r3: (
265
+ setattr(user_inputs, "organization", org),
266
+ setattr(user_inputs, "rules_l1", r1),
267
+ setattr(user_inputs, "rules_l2", r2),
268
+ setattr(user_inputs, "rules_l3", r3)
269
+ ), inputs=[organization_input, rules_l1_input, rules_l2_input, rules_l3_input])
270
+
271
+ classify_btn.click(classify_query, inputs=[query_input], outputs=[classification_output])
272
+
273
+ with gr.Tab("Organization Documentation Agent"):
274
+ gr.Markdown("""
275
+ ### Warning
276
+ If you encounter an error when uploading files, try changing the collection name and upload again.
277
+ Each collection name must be unique.
278
+ """)
279
+ with gr.Row():
280
+ with gr.Column():
281
+ collection_name_input = gr.Textbox(label="Collection Name", placeholder="Enter a unique name for this collection")
282
+ file_upload = gr.Files(file_types=[".pdf"], label="Upload PDFs")
283
+ upload_btn = gr.Button("Upload and Process Files")
284
+ upload_status = gr.Textbox(label="Upload Status", interactive=False)
285
+ with gr.Column():
286
+ search_query_input = gr.Textbox(label="Search Query")
287
+ search_output = gr.Textbox(label="Search Results", lines=10)
288
+ search_btn = gr.Button("Search")
289
+ api_details = gr.Markdown("""
290
+ ### API Endpoint Details
291
+ - **URL:** `http://0.0.0.0:7860/search_vector_database`
292
+ - **Method:** POST
293
+ - **Example Usage:**
294
+ ```python
295
+ from gradio_client import Client
296
+
297
+ client = Client("http://0.0.0.0:7860/")
298
+ result = client.predict(
299
+ "search query", # str in 'Search Query' Textbox component
300
+ "name of collection given in ui", # str in 'Collection Name' Textbox component
301
+ api_name="/search_vector_database"
302
+ )
303
+ print(result)
304
+ ```
305
+ """)
306
+
307
+ upload_btn.click(handle_file_upload, inputs=[file_upload, collection_name_input], outputs=[upload_status])
308
+ search_btn.click(search_vector_database, inputs=[search_query_input, collection_name_input], outputs=[search_output])
309
+
310
+ with gr.Tab("Account Information"):
311
+ with gr.Row():
312
+ with gr.Column():
313
+ username_input = gr.Textbox(label="Username")
314
+ password_input = gr.Textbox(label="Password", type="password")
315
+ login_btn = gr.Button("Login")
316
+ with gr.Column():
317
+ account_info_output = gr.Textbox(label="Account Info", lines=20)
318
+ api_details = gr.Markdown("""
319
+ ### API Endpoint Details
320
+ - **URL:** `http://0.0.0.0:7860/handle_login`
321
+ - **Method:** POST
322
+ - **Example Usage:**
323
+ ```python
324
+ from gradio_client import Client
325
+
326
+ client = Client("http://0.0.0.0:7860/")
327
+ result = client.predict(
328
+ "admin", # str in 'Username' Textbox component
329
+ "password", # str in 'Password' Textbox component
330
+ api_name="/handle_login"
331
+ )
332
+ print(result)
333
+ ```
334
+ """)
335
+
336
+ login_btn.click(handle_login, inputs=[username_input, password_input], outputs=[account_info_output])
337
+
338
+ interface.launch(server_name="0.0.0.0", server_port=7860)
339
+
340
+ if __name__ == "__main__":
341
+ gradio_interface()