import os
from dotenv import load_dotenv
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Neo4jVector
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.graphs import Neo4jGraph
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
import streamlit as st
import tempfile
from neo4j import GraphDatabase
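
# Note: launch with Streamlit, e.g. `streamlit run app.py` (the file name app.py is an
# assumption). Requires streamlit, python-dotenv, neo4j, pypdf, langchain,
# langchain-experimental and LangChain's OpenAI integration; exact package names can
# vary across LangChain versions.
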
def main():
    st.set_page_config(
        layout="wide",
        page_title="Graphy v1",
        page_icon=":graph:"
    )
    st.sidebar.image('GRAP.png', use_column_width=True)
    with st.sidebar.expander("Expand Me"):
        st.markdown("""
This application allows you to upload a PDF file, extract its content into a Neo4j graph database, and perform queries using natural language.
It leverages LangChain and OpenAI's GPT models to generate Cypher queries that interact with the Neo4j database in real time.
""")
    st.title("Graphy: Realtime GraphRAG App")
    load_dotenv()

    # Set OpenAI API key
    if 'OPENAI_API_KEY' not in st.session_state:
        st.sidebar.subheader("OpenAI API Key")
        openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type='password')
        if openai_api_key:
            os.environ['OPENAI_API_KEY'] = openai_api_key
            st.session_state['OPENAI_API_KEY'] = openai_api_key
            st.sidebar.success("OpenAI API Key set successfully.")
            embeddings = OpenAIEmbeddings()
            llm = ChatOpenAI(model_name="gpt-4o")  # Use a model that supports function calling
            st.session_state['embeddings'] = embeddings
            st.session_state['llm'] = llm
    else:
        embeddings = st.session_state['embeddings']
        llm = st.session_state['llm']
    # Initialize variables
    neo4j_url = None
    neo4j_username = None
    neo4j_password = None
    graph = None

    # Set Neo4j connection details
    if 'neo4j_connected' not in st.session_state:
        st.sidebar.subheader("Connect to Neo4j Database")
        neo4j_url = st.sidebar.text_input("Neo4j URL:", value="neo4j+s://<your-neo4j-url>")
        neo4j_username = st.sidebar.text_input("Neo4j Username:", value="neo4j")
        neo4j_password = st.sidebar.text_input("Neo4j Password:", type='password')
        connect_button = st.sidebar.button("Connect")
        if connect_button and neo4j_password:
            try:
                graph = Neo4jGraph(
                    url=neo4j_url,
                    username=neo4j_username,
                    password=neo4j_password
                )
                st.session_state['graph'] = graph
                st.session_state['neo4j_connected'] = True
                # Store connection parameters for later use
                st.session_state['neo4j_url'] = neo4j_url
                st.session_state['neo4j_username'] = neo4j_username
                st.session_state['neo4j_password'] = neo4j_password
                st.sidebar.success("Connected to Neo4j database.")
            except Exception as e:
                st.error(f"Failed to connect to Neo4j: {e}")
    else:
        graph = st.session_state['graph']
        neo4j_url = st.session_state['neo4j_url']
        neo4j_username = st.session_state['neo4j_username']
        neo4j_password = st.session_state['neo4j_password']
    # Ensure that the Neo4j connection is established before proceeding
    if graph is not None:
        # File uploader
        uploaded_file = st.file_uploader("Please select a PDF file.", type="pdf")

        if uploaded_file is not None and 'qa' not in st.session_state:
            with st.spinner("Processing the PDF..."):
                # Save uploaded file to temporary file
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                    tmp_file.write(uploaded_file.read())
                    tmp_file_path = tmp_file.name

                # Load and split the PDF
                loader = PyPDFLoader(tmp_file_path)
                pages = loader.load_and_split()
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
                docs = text_splitter.split_documents(pages)
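                # Normalize the chunks: strip newlines and tag each one with the uploaded file's name as its source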
                lc_docs = []
                for doc in docs:
                    lc_docs.append(Document(page_content=doc.page_content.replace("\n", ""),
                                            metadata={'source': uploaded_file.name}))

                # Clear the graph database
                cypher = """
                    MATCH (n)
                    DETACH DELETE n;
                """
                graph.query(cypher)

                # Define allowed nodes and relationships
                allowed_nodes = ["Patient", "Disease", "Medication", "Test", "Symptom", "Doctor"]
                allowed_relationships = ["HAS_DISEASE", "TAKES_MEDICATION", "UNDERWENT_TEST", "HAS_SYMPTOM", "TREATED_BY"]

                # Transform documents into graph documents
                transformer = LLMGraphTransformer(
                    llm=llm,
                    allowed_nodes=allowed_nodes,
                    allowed_relationships=allowed_relationships,
                    node_properties=False,
                    relationship_properties=False
                )

                graph_documents = transformer.convert_to_graph_documents(lc_docs)
                graph.add_graph_documents(graph_documents, include_source=True)
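                # Build a hybrid (vector + keyword) retrieval index over the Patient nodes of the graph that was just written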
                # Use the stored connection parameters
                index = Neo4jVector.from_existing_graph(
                    embedding=embeddings,
                    url=neo4j_url,
                    username=neo4j_username,
                    password=neo4j_password,
                    database="neo4j",
                    node_label="Patient",  # Adjust node_label as needed
                    text_node_properties=["id", "text"],
                    embedding_node_property="embedding",
                    index_name="vector_index",
                    keyword_index_name="entity_index",
                    search_type="hybrid"
                )

                st.success(f"{uploaded_file.name} preparation is complete.")
                # Retrieve the graph schema
                schema = graph.get_schema

                # Set up the QA chain
                template = """
                Task: Generate a Cypher statement to query the graph database.
                Instructions:
                Use only relationship types and properties provided in schema.
                Do not use other relationship types or properties that are not provided.
                schema:
                {schema}
                Note: Do not include explanations or apologies in your answers.
                Do not answer questions that ask anything other than creating Cypher statements.
                Do not include any text other than generated Cypher statements.
                Question: {question}"""

                question_prompt = PromptTemplate(
                    template=template,
                    input_variables=["schema", "question"]
                )

                qa = GraphCypherQAChain.from_llm(
                    llm=llm,
                    graph=graph,
                    cypher_prompt=question_prompt,
                    verbose=True,
                    allow_dangerous_requests=True
                )

                st.session_state['qa'] = qa
    else:
        st.warning("Please connect to the Neo4j database before you can upload a PDF.")
    if 'qa' in st.session_state:
        st.subheader("Ask a Question")
        with st.form(key='question_form'):
            question = st.text_input("Enter your question:")
            submit_button = st.form_submit_button(label='Submit')

        if submit_button and question:
            with st.spinner("Generating answer..."):
                res = st.session_state['qa'].invoke({"query": question})
                st.write("\n**Answer:**\n" + res['result'])


if __name__ == "__main__":
    main()