import os
import shutil
import requests
import json
import gradio as gr
import PyPDF2
import chromadb
import csv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
# Constants
API_KEY = os.getenv("togetherai")  # Together AI API key, read from the "togetherai" environment variable
BASE_URL = "https://api.together.xyz/v1/chat/completions"
CHUNK_SIZE = 6000  # Maximum characters per chunk (RecursiveCharacterTextSplitter counts characters, not words)
TEMP_SUMMARY_FILE = "temp_summaries.txt"
COLLECTIONS_FILE = "collections.csv"
# Function to convert a PDF to plain text
def pdf_to_text(file_path):
    with open(file_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        text = ""
        for page in pdf_reader.pages:
            text += page.extract_text() or ""  # guard against pages with no extractable text
    return text
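# Quick sanity-check sketch (not part of the app flow): the file path below is
# hypothetical and only illustrates the expected call.
#     print(pdf_to_text("uploaded_pdfs/sample.pdf")[:500])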
# Function to summarize text using the LLM
def summarize_text(text):
    user_prompt = f"""
You are an expert in legal language and document summarization. Your task is to provide a concise and accurate summary of the given document.
Keep the summary concise, ideally within 2000 words, while covering all essential points. Here is the document to summarize:
{text}
"""
    return call_llm(user_prompt)
# Function to handle file upload, summarization, and saving to ChromaDB
def handle_file_upload(files, collection_name):
    if not collection_name:
        return "Please provide a collection name."
    os.makedirs('uploaded_pdfs', exist_ok=True)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=100)
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.create_collection(name=collection_name)
    except Exception as e:  # e.g. the collection name already exists (exception type varies by chromadb version)
        return f"Error creating collection: {str(e)}. Please try a different collection name."
    file_names = []
    with open(TEMP_SUMMARY_FILE, 'w', encoding='utf-8') as temp_file:
        for file in files:
            file_name = os.path.basename(file.name)
            file_names.append(file_name)
            file_path = os.path.join('uploaded_pdfs', file_name)
            shutil.copy(file.name, file_path)
            text = pdf_to_text(file_path)
            chunks = text_splitter.split_text(text)
            for i, chunk in enumerate(chunks):
                summary = summarize_text(chunk)
                temp_file.write(f"Summary of {file_name} (Part {i+1}):\n{summary}\n\n")
    # Process the temporary file and add the summaries to ChromaDB
    with open(TEMP_SUMMARY_FILE, 'r', encoding='utf-8') as temp_file:
        summaries = temp_file.read()
    summary_chunks = text_splitter.split_text(summaries)
    for i, chunk in enumerate(summary_chunks):
        vector = embeddings.embed_query(chunk)
        collection.add(
            embeddings=[vector],
            documents=[chunk],
            ids=[f"summary_{i}"]
        )
    os.remove(TEMP_SUMMARY_FILE)
    # Update collections.csv
    update_collections_csv(collection_name, file_names)
    return "Files uploaded, summarized, and processed successfully."
# Function to update collections.csv
def update_collections_csv(collection_name, file_names):
    file_names_str = ", ".join(file_names)
    with open(COLLECTIONS_FILE, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([collection_name, file_names_str])
# Function to read collections.csv
def read_collections():
    if not os.path.exists(COLLECTIONS_FILE):
        return "No collections found."
    with open(COLLECTIONS_FILE, 'r') as csvfile:
        reader = csv.reader(csvfile)
        collections = [f"Collection: {row[0]}\nFiles: {row[1]}\n\n" for row in reader]
    return "".join(collections)
# Function to search the vector database
def search_vector_database(query, collection_name):
    if not collection_name:
        return "Please provide a collection name."
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.get_collection(name=collection_name)
    except Exception as e:  # the collection may not exist (exception type varies by chromadb version)
        return f"Error accessing collection: {str(e)}. Make sure the collection name is correct."
    query_vector = embeddings.embed_query(query)
    results = collection.query(query_embeddings=[query_vector], n_results=2, include=["documents"])
    return "\n\n".join(results["documents"][0])
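# Local testing sketch (not part of the app flow): assumes a collection named
# "demo" was already created via the upload tab; the collection name and query
# below are illustrative only.
#     print(search_vector_database("termination clause notice period", "demo"))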
# Function to call the LLM
def call_llm(prompt):
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "top_p": 0.7,
        "top_k": 50,
        "repetition_penalty": 1,
        "stop": ["\"\""],
        "stream": False
    }
    response = requests.post(BASE_URL, headers=headers, data=json.dumps(data))
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']
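# Optional sketch (not wired into the app): a minimal retry wrapper around
# call_llm for transient Together AI failures such as rate limits or 5xx
# responses. The retry count and backoff values are assumptions, not part of
# the original app.
def call_llm_with_retry(prompt, retries=3, backoff_seconds=2):
    import time  # local import so the module-level imports stay unchanged
    for attempt in range(retries):
        try:
            return call_llm(prompt)
        except requests.HTTPError:
            if attempt == retries - 1:
                raise
            time.sleep(backoff_seconds * (attempt + 1))  # simple linear backoff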
# Function to answer questions using Rachel.AI
def answer_question(question, collection_name):
    context = search_vector_database(question, collection_name)
    prompt = f"""
You are a paralegal AI assistant. Your role is to assist with legal inquiries by providing clear and concise answers based on the provided question and legal context. Always maintain a highly professional tone, ensuring that your responses are well-reasoned and legally accurate.
Question: {question}
Legal Context: {context}
Please provide a detailed response based on the above information. When answering, tell the user which file name you are drawing on (the file name is usually included in the context).
"""
    return call_llm(prompt)
# Gradio interface
def gradio_interface():
    with gr.Blocks(theme='gl198976/The-Rounded') as interface:
        gr.Markdown("# rachel.ai backend")
        gr.Markdown("""
### Warning
If you encounter an error when uploading files, try changing the collection name and upload again.
Each collection name must be unique.
""")
        with gr.Tab("Document Upload and Search"):
            with gr.Row():
                with gr.Column():
                    collection_name_input = gr.Textbox(label="Collection Name", placeholder="Enter a unique name for this collection")
                    file_upload = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                    upload_btn = gr.Button("Upload, Summarize, and Process Files")
                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
                with gr.Column():
                    search_query_input = gr.Textbox(label="Search Query")
                    search_collection_name = gr.Textbox(label="Collection Name for Search", placeholder="Enter the collection name to search")
                    search_output = gr.Textbox(label="Search Results", lines=10)
                    search_btn = gr.Button("Search")
                    api_details = gr.Markdown("""
### API Endpoint Details
- **URL:** http://0.0.0.0:7860/search_vector_database
- **Method:** POST
- **Example Usage:**
```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860/")
result = client.predict(
    "search query",  # str in 'Search Query' Textbox component
    "name of collection given in ui",  # str in 'Collection Name' Textbox component
    api_name="/search_vector_database"
)
print(result)
```
""")
        with gr.Tab("Rachel.AI"):
            question_input = gr.Textbox(label="Ask a question")
            rachel_collection_name = gr.Textbox(label="Collection Name", placeholder="Enter the collection name to search")
            answer_output = gr.Textbox(label="Answer", lines=10)
            ask_btn = gr.Button("Ask Rachel.AI")
            rachel_api_details = gr.Markdown("""
### API Endpoint Details for Rachel.AI
- **URL:** http://0.0.0.0:7860/answer_question
- **Method:** POST
- **Example Usage:**
```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860/")
result = client.predict(
    "question",  # str in 'Ask a question' Textbox component
    "collection_name",  # str in 'Collection Name' Textbox component
    api_name="/answer_question"
)
print(result)
```
""")
        with gr.Tab("Collections"):
            collections_output = gr.Textbox(label="Collections and Files", lines=20)
            refresh_btn = gr.Button("Refresh Collections")
        upload_btn.click(handle_file_upload, inputs=[file_upload, collection_name_input], outputs=[upload_status])
        search_btn.click(search_vector_database, inputs=[search_query_input, search_collection_name], outputs=[search_output])
        ask_btn.click(answer_question, inputs=[question_input, rachel_collection_name], outputs=[answer_output])
        refresh_btn.click(read_collections, inputs=[], outputs=[collections_output])
    interface.launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    gradio_interface()