|
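"""Vector-store utilities for the Stock Sentiment Analysis app.

Maintains a persistent Chroma collection of Tesla 10-K filings (2019-2023) and
builds throwaway FAISS indexes for ad-hoc documents, both embedded with Azure
OpenAI.
"""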
import os
import shutil
import warnings
from typing import Any, Dict, List

import chromadb
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS, Chroma
from langchain_openai import AzureOpenAIEmbeddings

warnings.filterwarnings("ignore")
|
# Chroma persistence paths: the default project path, and an alternate path
# used when running on Hugging Face Spaces (hugg=True).
CHROMA_DB_PATH = os.path.join(os.getcwd(), "Stock Sentiment Analysis", "chroma_db")
CHROMA_DB_PATH_HUGG = os.path.join(os.getcwd(), "chroma_db")

tesla_10k_collection = 'tesla-10k-2019-to-2023'


def get_embedding_model():
    """Build the Azure OpenAI embedding client from environment variables."""
    return AzureOpenAIEmbeddings(
        model=os.getenv("AZURE_OPENAI_EMBEDDING_NAME"),
        api_key=os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"),
        azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"),
    )
|
def clear_db(hugg=False):
    """Delete the persisted Chroma database (and the Spaces copy when hugg=True)."""
    check_and_delete(CHROMA_DB_PATH)
    if hugg:
        check_and_delete(CHROMA_DB_PATH_HUGG)
|
class DBStorage:
    """Wrapper around a persistent Chroma collection, with FAISS helpers for
    ad-hoc document sets."""

    def __init__(self, hugg=False):
        self.hugg = hugg
        self.CHROMA_PATH = CHROMA_DB_PATH_HUGG if hugg else CHROMA_DB_PATH
        self.vector_store = None
        self.client = chromadb.PersistentClient(path=self.CHROMA_PATH)
        print(self.client.list_collections())
        self.collection = self.client.get_or_create_collection(name=tesla_10k_collection)
        print(f"Collection '{tesla_10k_collection}' holds {self.collection.count()} records")
|
    def chunk_data(self, data, chunk_size=10000):
        """Split documents into chunks of roughly chunk_size characters (no overlap)."""
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
        return text_splitter.split_documents(data)
|
    def create_embeddings(self, chunks):
        """Embed the chunks with Azure OpenAI and persist them to Chroma."""
        embeddings = get_embedding_model()
        self.vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=embeddings,
            collection_name=tesla_10k_collection,
            persist_directory=self.CHROMA_PATH,
        )
        self.collection = self.client.get_or_create_collection(name=tesla_10k_collection)
        print(f"Indexed records: {self.collection.count()}")
|
    def create_vector_store(self, chunks):
        """Build an in-memory FAISS index over the chunks."""
        embeddings = get_embedding_model()
        return FAISS.from_documents(chunks, embedding=embeddings)
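    # Design note: Chroma provides the durable, on-disk store for the 10-K
    # collection, while FAISS acts as a lightweight index for ad-hoc documents
    # that embed_vectors()/load_vectors() rebuild and reload as needed.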
|
    def load_embeddings(self):
        """Reopen the persisted Chroma collection for querying."""
        embeddings = get_embedding_model()
        self.vector_store = Chroma(
            collection_name=tesla_10k_collection,
            persist_directory=self.CHROMA_PATH,
            embedding_function=embeddings,
        )
        print(f"Loaded vector store: {self.vector_store}")
|
    def load_vectors(self, FAISS_DB_PATH):
        """Load a FAISS index previously written by embed_vectors()."""
        embeddings = get_embedding_model()
        self.vector_store = FAISS.load_local(
            folder_path=FAISS_DB_PATH,
            embeddings=embeddings,
            # The index is produced by this app, so deserializing it is acceptable.
            allow_dangerous_deserialization=True,
        )
|
    def fetch_documents(self, metadata_filter: Dict[str, Any]) -> List[Document]:
        """Return documents from the raw Chroma collection matching a metadata filter."""
        results = self.collection.get(
            where=metadata_filter,
            include=["documents", "metadatas"],
        )
        # collection.get() returns flat lists, unlike query(), which nests one
        # list per query embedding.
        return [
            Document(page_content=content, metadata=metadata)
            for content, metadata in zip(results["documents"], results["metadatas"])
        ]
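    # Example with a hypothetical metadata key (assumes chunks were stored with
    # a "source" field):
    #     db.fetch_documents({"source": "tesla-10k-2021"})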
|
    def get_context_for_query(self, question, k=3):
        """Return the k most relevant chunks for the question, joined into one string."""
        if self.vector_store is None:
            raise RuntimeError("Call load_embeddings() or load_vectors() before querying.")
        relevant_document_chunks = self.vector_store.similarity_search(question, k=k)
        context_list = [d.page_content for d in relevant_document_chunks]
        context_for_query = ". ".join(context_list)
        print(f"context_for_query length: {len(context_for_query)}")
        return context_for_query
|
    def embed_vectors(self, social_media_document, FAISS_DB_PATH):
        """Chunk the documents, build a FAISS index, and save it to FAISS_DB_PATH."""
        chunks = self.chunk_data(social_media_document)
        print(f"Split into {len(chunks)} chunks")
        vector_store = self.create_vector_store(chunks)
        check_and_delete(FAISS_DB_PATH, self.hugg)  # clear any stale index first
        vector_store.save_local(FAISS_DB_PATH)
|
def check_and_delete(PATH, hugg=False):
    """Delete the directory at PATH if it exists."""
    if os.path.isdir(PATH):
        if hugg:
            shutil.rmtree(PATH)
        else:
            # Read-only files (common on Windows) make rmtree fail; make the
            # offending path writable and retry. Note: onexc requires Python 3.12+.
            def make_writable_and_retry(func, path, exc):
                os.chmod(path, 0o777)
                func(path)

            shutil.rmtree(PATH, onexc=make_writable_and_retry)
        print(f"Deleted {PATH}")
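
# Illustrative sketch (not executed): the FAISS round trip, assuming `docs` is
# a list of langchain Documents and the Azure env vars above are set:
#
#     db = DBStorage()
#     db.embed_vectors(docs, "faiss_index")     # chunk, embed, save locally
#     db.load_vectors("faiss_index")            # reload the saved index
#     print(db.get_context_for_query("placeholder question", k=3))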
|
if __name__ == "__main__":
    # Smoke test: index a tiny in-memory document end to end.
    qa_system = DBStorage()
    social_media_document = [Document(page_content="Placeholder text for a quick smoke test.")]
    chunks = qa_system.chunk_data(social_media_document)
    qa_system.create_embeddings(chunks)
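
    # Read path: reopen the persisted collection and run a placeholder query.
    qa_system.load_embeddings()
    print(qa_system.get_context_for_query("What risks did the filing highlight?", k=3))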
|