shresthasingh committed on
Commit
c7ea556
1 Parent(s): d97eaee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -0
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import requests
4
+ import json
5
+ import gradio as gr
6
+ import PyPDF2
7
+ import chromadb
8
+ import csv
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain_huggingface import HuggingFaceEmbeddings
11
+
12
# Constants
# The Together API key is read from the environment so that no credential is
# stored in the repository. Set TOGETHER_API_KEY before launching the app;
# with the variable unset the key is empty and LLM requests will be rejected
# by the API.
API_KEY = os.environ.get("TOGETHER_API_KEY", "")
BASE_URL = "https://api.together.xyz/v1/chat/completions"
# Passed as chunk_size to RecursiveCharacterTextSplitter, which measures
# length with its length function (character count by default, not words).
CHUNK_SIZE = 6000
TEMP_SUMMARY_FILE = "temp_summaries.txt"
COLLECTIONS_FILE = "collections.csv"
18
+
19
+ # Function to convert PDF to text
20
def pdf_to_text(file_path):
    """Extract the text of every page of a PDF and return it as one string.

    Args:
        file_path: Path of the PDF file to read.

    Returns:
        The concatenated text of all pages, in page order.
    """
    with open(file_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # Join once instead of growing a string page by page; `or ""` keeps a
        # page that yields no text from breaking the concatenation.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)
27
+
28
+ # Function to summarize text using LLM
29
def summarize_text(text):
    """Ask the LLM for a summary of ``text``.

    The text is embedded in a legal-summarization instruction and sent through
    call_llm; whatever call_llm returns is returned unchanged.
    """
    return call_llm(f"""
You are an expert in legal language and document summarization. Your task is to provide a concise and accurate summary of the given document.
Keep the summary concise, ideally in 2000 words, while covering all essential points. Here is the document to summarize:

{text}
""")
38
+
39
+ # Function to handle file upload, summarization, and saving to ChromaDB
40
def handle_file_upload(files, collection_name):
    """Summarize uploaded PDFs and index the summaries in a new ChromaDB collection.

    Each file is copied into ``uploaded_pdfs/``, converted to text, split into
    chunks and summarized chunk by chunk. The combined summaries are split
    again, embedded and stored in a freshly created collection, and the
    collection is recorded in the collections CSV.

    Args:
        files: Uploaded files; each entry either exposes its path as ``.name``
            or is itself a path string. May be None/empty when nothing was
            uploaded.
        collection_name: Name of the collection to create; must not exist yet.

    Returns:
        A status message for the UI.

    Raises:
        Any exception raised while reading, summarizing or indexing. The newly
        created collection is deleted first so the name can be reused.
    """
    if not collection_name:
        return "Please provide a collection name."
    if not files:
        # Nothing uploaded: iterating None would otherwise raise TypeError.
        return "Please upload at least one PDF file."

    os.makedirs('uploaded_pdfs', exist_ok=True)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=100)
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")

    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.create_collection(name=collection_name)
    except Exception as e:
        # The exception type for a duplicate/invalid name differs between
        # chromadb releases, so this UI boundary catches broadly.
        return f"Error creating collection: {str(e)}. Please try a different collection name."

    file_names = []
    summary_parts = []
    try:
        for file in files:
            source_path = getattr(file, "name", file)
            file_name = os.path.basename(source_path)
            file_names.append(file_name)
            file_path = os.path.join('uploaded_pdfs', file_name)
            shutil.copy(source_path, file_path)

            text = pdf_to_text(file_path)
            for i, chunk in enumerate(text_splitter.split_text(text)):
                summary = summarize_text(chunk)
                summary_parts.append(f"Summary of {file_name} (Part {i+1}):\n{summary}\n\n")

        # Summaries are kept in memory rather than in a fixed-name temp file,
        # which concurrent uploads would overwrite and a failure would leak.
        summaries = "".join(summary_parts)
        for i, chunk in enumerate(text_splitter.split_text(summaries)):
            vector = embeddings.embed_query(chunk)
            collection.add(
                embeddings=[vector],
                documents=[chunk],
                ids=[f"summary_{i}"]
            )
    except Exception:
        # Do not leave a half-filled collection behind that blocks the name.
        client.delete_collection(name=collection_name)
        raise

    # Update collections.csv
    update_collections_csv(collection_name, file_names)

    return "Files uploaded, summarized, and processed successfully."
88
+
89
+ # Function to update collections.csv
90
def update_collections_csv(collection_name, file_names):
    """Append one row (collection name, comma-separated file names) to the collections CSV."""
    row = [collection_name, ", ".join(file_names)]
    with open(COLLECTIONS_FILE, 'a', newline='') as handle:
        csv.writer(handle).writerow(row)
95
+
96
+ # Function to read collections.csv
97
def read_collections():
    """Return a printable listing of all recorded collections and their files.

    Returns:
        One "Collection: ... / Files: ..." paragraph per CSV row, or
        "No collections found." when the CSV file is missing or holds no
        usable rows.
    """
    try:
        # newline='' is the mode the csv module expects for its file objects.
        with open(COLLECTIONS_FILE, 'r', newline='') as csvfile:
            entries = [
                f"Collection: {row[0]}\nFiles: {row[1]}\n\n"
                for row in csv.reader(csvfile)
                # Blank or truncated rows would raise IndexError on row[1].
                if len(row) >= 2
            ]
    except FileNotFoundError:
        return "No collections found."

    if not entries:
        return "No collections found."
    return "".join(entries)
106
+
107
+ # Function to search vector database
108
def search_vector_database(query, collection_name):
    """Return the stored summary chunks most similar to ``query``.

    Args:
        query: Free-text search query.
        collection_name: Name of an existing ChromaDB collection.

    Returns:
        Up to two matching documents joined by blank lines, or a message
        describing why the search could not be run.
    """
    if not collection_name:
        return "Please provide a collection name."
    if not query or not query.strip():
        return "Please provide a search query."

    # Resolve the collection before loading the embedding model, so a wrong
    # name fails fast.
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.get_collection(name=collection_name)
    except Exception as e:
        # The exception type for a missing collection differs between
        # chromadb releases, so this UI boundary catches broadly.
        return f"Error accessing collection: {str(e)}. Make sure the collection name is correct."

    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    query_vector = embeddings.embed_query(query)
    results = collection.query(query_embeddings=[query_vector], n_results=2, include=["documents"])

    # One inner list per query embedding; only one query is sent.
    documents = results["documents"] or [[]]
    return "\n\n".join(documents[0])
123
+
124
+ # Function to call LLM
125
def call_llm(prompt, timeout=300):
    """Send ``prompt`` as a single user message to the chat-completions endpoint.

    Args:
        prompt: Text of the user message.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot block the caller forever.

    Returns:
        The content of the first choice's message in the JSON response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If no response arrives within ``timeout``.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "top_p": 0.7,
        "top_k": 50,
        "repetition_penalty": 1,
        "stop": ["\"\""],
        "stream": False
    }

    response = requests.post(BASE_URL, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']
145
+
146
+ # Function to answer questions using Rachel.AI
147
def answer_question(question, collection_name):
    """Answer ``question`` with the LLM, using matching summaries as context.

    Args:
        question: The user's question.
        collection_name: Name of the ChromaDB collection to search for context.

    Returns:
        The LLM's answer, or a message explaining why no answer was produced
        (missing input or inaccessible collection).
    """
    if not collection_name:
        return "Please provide a collection name."
    if not question or not question.strip():
        return "Please provide a question."

    context = search_vector_database(question, collection_name)
    # search_vector_database reports lookup failures as a plain string; pass
    # that on to the user instead of feeding it to the LLM as legal context.
    if context.startswith("Error accessing collection:"):
        return context

    prompt = f"""
You are a paralegal AI assistant. Your role is to assist with legal inquiries by providing clear and concise answers based on the provided question and legal context. Always maintain a highly professional tone, ensuring that your responses are well-reasoned and legally accurate.
Question: {question}
Legal Context: {context}
Please provide a detailed response considering the above information.
"""

    return call_llm(prompt)
158
+
159
+ # Gradio interface
160
def gradio_interface():
    """Build the rachel.ai Gradio UI, wire its buttons to the handlers, and launch it.

    The UI has three tabs: document upload/search, question answering, and a
    listing of recorded collections. The server is launched on 0.0.0.0:7860.
    """
    with gr.Blocks(theme='gl198976/The-Rounded') as interface:
        gr.Markdown("# rachel.ai backend")

        gr.Markdown("""
### Warning
If you encounter an error when uploading files, try changing the collection name and upload again.
Each collection name must be unique.
""")

        # Tab 1: upload PDFs into a new collection (left) and search a collection (right).
        with gr.Tab("Document Upload and Search"):
            with gr.Row():
                with gr.Column():
                    collection_name_input = gr.Textbox(label="Collection Name", placeholder="Enter a unique name for this collection")
                    file_upload = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                    upload_btn = gr.Button("Upload, Summarize, and Process Files")
                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
                with gr.Column():
                    search_query_input = gr.Textbox(label="Search Query")
                    search_collection_name = gr.Textbox(label="Collection Name for Search", placeholder="Enter the collection name to search")
                    search_output = gr.Textbox(label="Search Results", lines=10)
                    search_btn = gr.Button("Search")

            # Static usage notes shown to the user; the binding itself is not used again.
            api_details = gr.Markdown("""
### API Endpoint Details
- **URL:** http://0.0.0.0:7860/search_vector_database
- **Method:** POST
- **Example Usage:**

```python
from gradio_client import Client

client = Client("http://0.0.0.0:7860/")
result = client.predict(
"search query", # str in 'Search Query' Textbox component
"name of collection given in ui", # str in 'Collection Name' Textbox component
api_name="/search_vector_database"
)
print(result)
```
""")

        # Tab 2: ask a question that is answered from a collection's summaries.
        with gr.Tab("Rachel.AI"):
            question_input = gr.Textbox(label="Ask a question")
            rachel_collection_name = gr.Textbox(label="Collection Name", placeholder="Enter the collection name to search")
            answer_output = gr.Textbox(label="Answer", lines=10)
            ask_btn = gr.Button("Ask Rachel.AI")

            # Static usage notes shown to the user; the binding itself is not used again.
            rachel_api_details = gr.Markdown("""
### API Endpoint Details for Rachel.AI
- **URL:** http://0.0.0.0:7860/answer_question
- **Method:** POST
- **Example Usage:**

```python
from gradio_client import Client

client = Client("http://0.0.0.0:7860/")
result = client.predict(
"question", # str in 'Ask a question' Textbox component
"collection_name", # str in 'Collection Name' Textbox component
api_name="/answer_question"
)
print(result)
```
""")

        # Tab 3: show the contents of the collections CSV on demand.
        with gr.Tab("Collections"):
            collections_output = gr.Textbox(label="Collections and Files", lines=20)
            refresh_btn = gr.Button("Refresh Collections")

        # Connect each button to its handler and the component that shows the result.
        # NOTE(review): the API names quoted in the Markdown above presumably come
        # from these handler function names; confirm against the Gradio version in use.
        upload_btn.click(handle_file_upload, inputs=[file_upload, collection_name_input], outputs=[upload_status])
        search_btn.click(search_vector_database, inputs=[search_query_input, search_collection_name], outputs=[search_output])
        ask_btn.click(answer_question, inputs=[question_input, rachel_collection_name], outputs=[answer_output])
        refresh_btn.click(read_collections, inputs=[], outputs=[collections_output])

    # Listen on all interfaces, port 7860.
    interface.launch(server_name="0.0.0.0", server_port=7860)
237
+
238
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    gradio_interface()