# Captured from a Hugging Face Space (Space status at capture time: Sleeping).
import os

import gradio as gr
from langchain.schema import Document as LangchainDocument
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.vectorstores import Chroma
from llama_parse import LlamaParse
# In-memory registry of the vector stores built during this process, keyed by
# collection name. mariela_upload_and_parse adds entries; mariela_retrieve
# looks them up.
vs_dict = {}
def mariela_parse(files):
    """Parse the given files with LlamaParse and return the parsed documents.

    Args:
        files: Iterable of files to parse. Each item may be an object that
            exposes its filesystem path as ``.name`` (as the original code
            required), an ``os.PathLike``, or a plain path string.

    Returns:
        A flat list containing, in input order, every document returned by
        ``LlamaParse.load_data`` for each file. An empty list when ``files``
        is empty or ``None``.
    """
    if not files:
        # Nothing to parse: avoid building the API client for no work.
        return []

    parser = LlamaParse(
        api_key=os.getenv("LLAMA_API_KEY"),
        result_type="markdown",
        verbose=True,
    )

    parsed_documents = []
    for file in files:
        if isinstance(file, os.PathLike):
            # For pathlib paths ``.name`` is only the basename, so resolve the
            # full path explicitly instead of reading that attribute.
            path = os.fspath(file)
        elif hasattr(file, "name"):
            path = file.name
        else:
            # Plain path string.
            path = file
        parsed_documents.extend(parser.load_data(path))
    return parsed_documents
def mariela_create_vector_database(
    parsed_documents,
    collection_name,
    *,
    chunk_size=5000,
    chunk_overlap=100,
    embedding_model_name="BAAI/bge-base-en-v1.5",
    persist_directory="chroma_db",
):
    """Chunk parsed documents, embed them, and store them in a Chroma collection.

    Args:
        parsed_documents: Iterable of parsed documents; each must expose
            ``.text`` and ``.metadata``.
        collection_name: Name of the Chroma collection to write to.
        chunk_size: ``chunk_size`` given to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` given to the same splitter.
        embedding_model_name: ``model_name`` given to ``FastEmbedEmbeddings``.
        persist_directory: Directory handed to Chroma as ``persist_directory``.

    Returns:
        The vector store returned by ``Chroma.from_documents``.
    """
    # Re-wrap each parsed document as a LangChain document so the LangChain
    # splitter and vector store can consume it.
    langchain_docs = [
        LangchainDocument(page_content=doc.text, metadata=doc.metadata)
        for doc in parsed_documents
    ]

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    docs = text_splitter.split_documents(langchain_docs)

    embed_model = FastEmbedEmbeddings(model_name=embedding_model_name)
    vs = Chroma.from_documents(
        documents=docs,
        embedding=embed_model,
        persist_directory=persist_directory,
        collection_name=collection_name,
    )
    return vs
def mariela_upload_and_parse(files, collection_name):
    """Parse uploaded files and index them under ``collection_name``.

    Args:
        files: Uploaded files as delivered by the "Upload Files" component.
        collection_name: Name of the collection to store the files in.
            Surrounding whitespace is ignored.

    Returns:
        A status message: a prompt to fix the input when no files or no
        collection name were given, a notice when parsing produced nothing,
        or a success message naming the collection.
    """
    if not files:
        return "Please upload at least one file."

    # An empty or whitespace-only name cannot identify a collection, so reject
    # it here with a readable message instead of passing it on to the store.
    collection_name = (collection_name or "").strip()
    if not collection_name:
        return "Please provide a collection name."

    parsed_documents = mariela_parse(files)
    if not parsed_documents:
        # Do not report success when there is nothing to index.
        return "No content could be parsed from the uploaded files."

    # Item assignment mutates the module-level registry; no ``global`` needed.
    vs_dict[collection_name] = mariela_create_vector_database(
        parsed_documents, collection_name
    )
    return f"Files uploaded, parsed, and stored successfully in collection: {collection_name}"
def mariela_retrieve(question, collection_name):
    """Return the passages most similar to ``question`` from a collection.

    Args:
        question: Query text to search for.
        collection_name: Name of a collection previously registered in
            ``vs_dict``. Surrounding whitespace is ignored.

    Returns:
        A string with up to four results, each showing the passage text and
        its metadata, separated by blank lines. A short explanatory message
        is returned instead when the query is blank, the collection is
        unknown, or the search yields nothing.
    """
    if not question or not question.strip():
        return "Please enter a query."

    collection_name = (collection_name or "").strip()
    if collection_name not in vs_dict:
        return f"Collection '{collection_name}' not found. Please upload and parse files for this collection first."

    results = vs_dict[collection_name].similarity_search(question, k=4)
    if not results:
        # Joining an empty list would give "", which looks like a broken UI.
        return "No relevant passages found."

    return "\n\n".join(
        f"Result {i}:\n{doc.page_content}\n\nMetadata: {doc.metadata}\n"
        for i, doc in enumerate(results, 1)
    )
# Markdown text rendered on the "Supported Document Types" tab.
supported_file_types = """
Supported Document Types:
- Base types: pdf
- Documents and presentations: 602, abw, cgm, cwk, doc, docx, docm, dot, dotm, hwp, key, lwp, mw, mcw, pages, pbd, ppt, pptm, pptx, pot, potm, potx, rtf, sda, sdd, sdp, sdw, sgl, sti, sxi, sxw, stw, sxg, txt, uof, uop, uot, vor, wpd, wps, xml, zabw, epub
- Images: jpg, jpeg, png, gif, bmp, svg, tiff, webp, web, htm, html
- Spreadsheets: xlsx, xls, xlsm, xlsb, xlw, csv, dif, sylk, slk, prn, numbers, et, ods, fods, uos1, uos2, dbf, wk1, wk2, wk3, wk4, wks, 123, wq1, wq2, wb1, wb2, wb3, qpw, xlr, eth, tsv
"""
# API usage notes shown on the "Upload and Parse Files" tab. Raw string so the
# shell line-continuation backslashes in the curl example are kept verbatim.
_UPLOAD_API_DOCS = r"""
### API Documentation
1. **Confirm that you have cURL installed on your system.**
```bash
$ curl --version
```
2. **Find the API endpoint below corresponding to your desired function in the app.**
**API Name: `/mariela_upload`**
```bash
curl -X POST {url_of_gradio_app}/call/mariela_upload -s -H "Content-Type: application/json" -d '{
"data": [
[{"path": "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf", "meta": {"_type": "gradio.FileData"}}],
"Hello!!"
]}' \
| awk -F'"' '{ print $4}' \
| read EVENT_ID; curl -N {url_of_gradio_app}/call/mariela_upload/$EVENT_ID
```
**Accepts 2 parameters:**
- **[0] any (Required):** The input value that is provided in the "Upload Files" File component.
- **[1] string (Required):** The input value that is provided in the "Collection Name" Textbox component.
**Returns 1 element:**
- **string:** The output value that appears in the "Status" Textbox component.
"""

# API usage notes shown on the "Retrieval" tab (raw string for the same reason).
_RETRIEVE_API_DOCS = r"""
### API Documentation
1. **Confirm that you have cURL installed on your system.**
```bash
$ curl --version
```
2. **Find the API endpoint below corresponding to your desired function in the app.**
**API Name: `/mariela_retrieve`**
```bash
curl -X POST {url_of_gradio_app}/call/mariela_retrieve -s -H "Content-Type: application/json" -d '{
"data": [
"Hello!!",
"Hello!!"
]}' \
| awk -F'"' '{ print $4}' \
| read EVENT_ID; curl -N {url_of_gradio_app}/call/mariela_retrieve/$EVENT_ID
```
**Accepts 2 parameters:**
- **[0] string (Required):** The input value that is provided in the "Enter a query to retrieve relevant passages" Textbox component.
- **[1] string (Required):** The input value that is provided in the "Collection Name" Textbox component.
**Returns 1 element:**
- **string:** The output value that appears in the "Retrieved Passages" Textbox component.
"""

# Build the three-tab Gradio interface: upload/parse, retrieval, and the list
# of supported document types.
with gr.Blocks() as demo:
    gr.Markdown("# Mariela: Multi-Action Retrieval and Intelligent Extraction Learning Assistant")
    gr.Markdown("This application allows you to upload documents, parse them, and then ask questions to retrieve relevant information.")

    with gr.Tab("Upload and Parse Files"):
        gr.Markdown("## Upload and Parse Files")
        gr.Markdown("Upload your documents here to create a searchable knowledge base.")
        gr.Markdown(_UPLOAD_API_DOCS)
        file_input = gr.File(label="Upload Files", file_count="multiple")
        collection_name_input = gr.Textbox(label="Collection Name")
        upload_button = gr.Button("Upload and Parse")
        upload_output = gr.Textbox(label="Status")
        # Name the endpoint explicitly so it matches the documented
        # `/mariela_upload` route rather than the handler's function name.
        upload_button.click(
            mariela_upload_and_parse,
            inputs=[file_input, collection_name_input],
            outputs=upload_output,
            api_name="mariela_upload",
        )

    with gr.Tab("Retrieval"):
        gr.Markdown("## Retrieval")
        gr.Markdown("Ask questions about your uploaded documents here.")
        gr.Markdown(_RETRIEVE_API_DOCS)
        collection_name_retrieval = gr.Textbox(label="Collection Name")
        question_input = gr.Textbox(label="Enter a query to retrieve relevant passages")
        retrieval_output = gr.Textbox(label="Retrieved Passages")
        retrieval_button = gr.Button("Retrieve")
        # Input order (question, collection) matches the documented parameters.
        retrieval_button.click(
            mariela_retrieve,
            inputs=[question_input, collection_name_retrieval],
            outputs=retrieval_output,
            api_name="mariela_retrieve",
        )

    with gr.Tab("Supported Document Types"):
        gr.Markdown("## Supported Document Types")
        gr.Markdown(supported_file_types)

demo.launch(debug=True)