"""Mariela: a Gradio app that parses uploaded documents with LlamaParse,
indexes them in a persistent Chroma vector store, and retrieves relevant
passages by similarity search."""

import functools
import logging
import os

import gradio as gr
from llama_parse import LlamaParse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.schema import Document as LangchainDocument

logger = logging.getLogger(__name__)

# Initialize global variables.
# Maps collection name -> Chroma vector store built during this process's
# lifetime. It is in-memory only: collections persisted to disk by a previous
# run are not registered here after a restart.
vs_dict = {}

# Embedding model used for indexing (and therefore for querying, since the
# stored Chroma handle embeds queries with the same model).
EMBED_MODEL_NAME = "BAAI/bge-base-en-v1.5"
# Directory in which Chroma persists collections.
PERSIST_DIRECTORY = "chroma_db"


@functools.lru_cache(maxsize=4)
def _get_embed_model(model_name=EMBED_MODEL_NAME):
    """Return a FastEmbed embedding model, built at most once per model name.

    Construction presumably downloads/loads model weights inside fastembed,
    so the instance is reused across uploads instead of being rebuilt each
    time.
    """
    return FastEmbedEmbeddings(model_name=model_name)


def _file_path(file):
    """Return the filesystem path for one Gradio upload.

    Older Gradio versions pass tempfile-like objects exposing ``.name``.
    Gradio 4+ ``gr.File`` (default ``type="filepath"``) passes plain path
    strings. Both forms are accepted.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


# Helper function to load and parse the input data
def mariela_parse(files):
    """Parse uploaded files with LlamaParse into markdown documents.

    Args:
        files: Iterable of Gradio uploads (path strings or objects with a
            ``.name`` attribute holding the path).

    Returns:
        A flat list of the documents returned by ``LlamaParse.load_data``
        across all files. Each exposes ``.text`` and ``.metadata``, which
        ``mariela_create_vector_database`` reads.
    """
    parser = LlamaParse(
        api_key=os.getenv("LLAMA_API_KEY"),
        result_type="markdown",
        verbose=True,
    )
    parsed_documents = []
    for file in files:
        parsed_documents.extend(parser.load_data(_file_path(file)))
    return parsed_documents


# Create vector database
def mariela_create_vector_database(
    parsed_documents,
    collection_name,
    *,
    chunk_size=5000,
    chunk_overlap=100,
    persist_directory=PERSIST_DIRECTORY,
):
    """Chunk, embed and store parsed documents in a persistent Chroma collection.

    Args:
        parsed_documents: Documents with ``.text`` and ``.metadata`` attributes
            (as produced by ``mariela_parse``).
        collection_name: Name of the Chroma collection to write to.
        chunk_size: Maximum characters per chunk for the text splitter.
        chunk_overlap: Characters shared between consecutive chunks.
        persist_directory: Directory where Chroma persists the collection.

    Returns:
        The Chroma vector store holding the embedded chunks.

    Raises:
        ValueError: If splitting yields no chunks (nothing was extracted).
    """
    langchain_docs = [
        LangchainDocument(page_content=doc.text, metadata=doc.metadata)
        for doc in parsed_documents
    ]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = text_splitter.split_documents(langchain_docs)
    if not docs:
        # Fail with a clear message instead of handing Chroma an empty batch.
        raise ValueError("No text could be extracted from the uploaded files.")
    # NOTE(review): re-using an existing collection name presumably appends to
    # the persisted collection rather than replacing it — confirm intended.
    return Chroma.from_documents(
        documents=docs,
        embedding=_get_embed_model(),
        persist_directory=persist_directory,
        collection_name=collection_name,
    )


# Function to handle file upload and parsing
def mariela_upload_and_parse(files, collection_name):
    """Gradio handler: parse uploaded files and index them under a collection.

    Args:
        files: Uploads from the ``gr.File`` component (may be ``None``/empty).
        collection_name: Collection name typed by the user. Surrounding
            whitespace is stripped.

    Returns:
        A status message for the UI. Failures are reported here rather than
        raised, so the user sees what went wrong in the Status box.
    """
    if not files:
        return "Please upload at least one file."
    collection_name = (collection_name or "").strip()
    if not collection_name:
        return "Please enter a collection name."
    try:
        parsed_documents = mariela_parse(files)
        vs = mariela_create_vector_database(parsed_documents, collection_name)
    except Exception as exc:  # UI boundary: log it and report it to the user.
        logger.exception("Failed to ingest files into collection %r", collection_name)
        return f"Failed to process files for collection '{collection_name}': {exc}"
    vs_dict[collection_name] = vs
    return f"Files uploaded, parsed, and stored successfully in collection: {collection_name}"


# Function to handle retrieval
def mariela_retrieve(question, collection_name, k=4):
    """Gradio handler: return the top-``k`` passages most similar to ``question``.

    Args:
        question: Free-text query.
        collection_name: Name of a collection created during this session.
            Whitespace is stripped to match ``mariela_upload_and_parse``.
        k: Number of passages to retrieve.

    Returns:
        The numbered passages with their metadata, separated by blank lines,
        or a human-readable message when the collection or query is missing
        or nothing matched.
    """
    collection_name = (collection_name or "").strip()
    vs = vs_dict.get(collection_name)
    if vs is None:
        return f"Collection '{collection_name}' not found. Please upload and parse files for this collection first."
    if not (question or "").strip():
        return "Please enter a query."
    results = vs.similarity_search(question, k=k)
    if not results:
        return "No relevant passages found."
    return "\n\n".join(
        f"Result {i}:\n{doc.page_content}\n\nMetadata: {doc.metadata}\n"
        for i, doc in enumerate(results, 1)
    )


# Supported file types list (rendered as markdown in the UI)
supported_file_types = """
Supported Document Types:
- Base types: pdf
- Documents and presentations: 602, abw, cgm, cwk, doc, docx, docm, dot, dotm, hwp, key, lwp, mw, mcw, pages, pbd, ppt, pptm, pptx, pot, potm, potx, rtf, sda, sdd, sdp, sdw, sgl, sti, sxi, sxw, stw, sxg, txt, uof, uop, uot, vor, wpd, wps, xml, zabw, epub
- Images: jpg, jpeg, png, gif, bmp, svg, tiff, webp, web, htm, html
- Spreadsheets: xlsx, xls, xlsm, xlsb, xlw, csv, dif, sylk, slk, prn, numbers, et, ods, fods, uos1, uos2, dbf, wk1, wk2, wk3, wk4, wks, 123, wq1, wq2, wb1, wb2, wb3, qpw, xlr, eth, tsv
"""

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Mariela: Multi-Action Retrieval and Intelligent Extraction Learning Assistant")
    gr.Markdown("This application allows you to upload documents, parse them, and then ask questions to retrieve relevant information.")

    with gr.Tab("Upload and Parse Files"):
        gr.Markdown("## Upload and Parse Files")
        gr.Markdown("Upload your documents here to create a searchable knowledge base.")
        file_input = gr.File(label="Upload Files", file_count="multiple")
        collection_name_input = gr.Textbox(label="Collection Name")
        upload_button = gr.Button("Upload and Parse")
        upload_output = gr.Textbox(label="Status")
        upload_button.click(
            mariela_upload_and_parse,
            inputs=[file_input, collection_name_input],
            outputs=upload_output,
        )

    with gr.Tab("Retrieval"):
        gr.Markdown("## Retrieval")
        gr.Markdown("Ask questions about your uploaded documents here.")
        collection_name_retrieval = gr.Textbox(label="Collection Name")
        question_input = gr.Textbox(label="Enter a query to retrieve relevant passages")
        retrieval_output = gr.Textbox(label="Retrieved Passages")
        retrieval_button = gr.Button("Retrieve")
        retrieval_button.click(
            mariela_retrieve,
            inputs=[question_input, collection_name_retrieval],
            outputs=retrieval_output,
        )

    with gr.Tab("Supported Document Types"):
        gr.Markdown("## Supported Document Types")
        gr.Markdown(supported_file_types)

demo.launch(debug=True)