import os
from typing import List
from uuid import uuid4

import arxiv
import pinecone
from tqdm.auto import tqdm

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.vectorstores import Pinecone

INDEX_BATCH_LIMIT = 100


class CharacterTextSplitter:
    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        assert (
            chunk_size > chunk_overlap
        ), "Chunk size must be greater than chunk overlap"

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,        # character length of each chunk
            chunk_overlap=self.chunk_overlap,  # character overlap between consecutive chunks
            length_function=len,               # measure length in characters (Python's len())
        )

    def split(self, text: str) -> List[str]:
        return self.text_splitter.split_text(text)


class ArxivLoader:
    def __init__(self, query: str = "Nuclear Fission", max_results: int = 5, encoding: str = "utf-8"):
        """Retrieve papers from arXiv and prepare them for chunking."""
        self.query = query
        self.max_results = max_results
        self.paper_urls = []
        self.documents = []
        self.splitter = CharacterTextSplitter()

    def retrieve_urls(self):
        """Query arXiv and collect the PDF URL of each matching paper."""
        arxiv_client = arxiv.Client()
        search = arxiv.Search(
            query=self.query,
            max_results=self.max_results,
            sort_by=arxiv.SortCriterion.Relevance,
        )
        for result in arxiv_client.results(search):
            self.paper_urls.append(result.pdf_url)

    def load_documents(self):
        """Load each PDF into a list of per-page LangChain documents."""
        for paper_url in self.paper_urls:
            loader = PyPDFLoader(paper_url)
            self.documents.append(loader.load())

    def format_document(self, document):
        """Split a single page into chunks and build the per-chunk metadata."""
        metadata = {
            "source_document": document.metadata["source"],
            "page_number": document.metadata["page"],
        }

        record_texts = self.splitter.split(document.page_content)
        record_metadatas = [
            {"chunk": j, "text": text, **metadata}
            for j, text in enumerate(record_texts)
        ]

        return record_texts, record_metadatas

    def main(self):
        """Run the full retrieve-and-load pipeline."""
        self.retrieve_urls()
        self.load_documents()


class PineconeIndexer:
    def __init__(self, index_name: str = "arxiv-paper-index", metric: str = "cosine", n_dims: int = 1536):
        """Connect to Pinecone and create the index if it does not already exist."""
        pinecone.init(
            api_key=os.environ["PINECONE_API_KEY"],
            environment=os.environ["PINECONE_ENV"],
        )

        if index_name not in pinecone.list_indexes():
            # create a new index
            pinecone.create_index(
                name=index_name,
                metric=metric,
                dimension=n_dims,
            )

        self.arxiv_loader = ArxivLoader()
        self.index = pinecone.Index(index_name)

    def load_embedder(self):
        """Wrap the OpenAI embedding model in a local file-backed cache."""
        store = LocalFileStore("./cache/")
        core_embeddings_model = OpenAIEmbeddings()
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            core_embeddings_model,
            store,
            namespace=core_embeddings_model.model,
        )

    def upsert(self, texts, metadatas):
        """Embed a batch of texts and upsert them into the Pinecone index."""
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = self.embedder.embed_documents(texts)
        self.index.upsert(vectors=zip(ids, embeds, metadatas))

    def index_documents(self, documents, batch_limit: int = INDEX_BATCH_LIMIT):
        """Chunk, embed, and upsert every page of every document in batches."""
        texts = []
        metadatas = []

        # iterate through the top-level documents (one list of pages per paper)
        for i in tqdm(range(len(documents))):
            # iterate through the pages of a single document
            for page in documents[i]:
                record_texts, record_metadatas = self.arxiv_loader.format_document(page)
                texts.extend(record_texts)
                metadatas.extend(record_metadatas)

                if len(texts) >= batch_limit:
                    self.upsert(texts, metadatas)
                    texts = []
                    metadatas = []

        # flush any remaining chunks
        if len(texts) > 0:
            self.upsert(texts, metadatas)

    def get_vectorstore(self):
        """Return a LangChain vector store backed by the Pinecone index."""
        return Pinecone(self.index, self.embedder.embed_query, "text")


if __name__ == "__main__":
    print("-------------- Loading Arxiv --------------")
    axloader = ArxivLoader()
    axloader.retrieve_urls()
    axloader.load_documents()

    print("\n-------------- Splitting sample doc --------------")
    sample_doc = axloader.documents[0]
    sample_page = sample_doc[0]
    splitter = CharacterTextSplitter()
    chunks = splitter.split(sample_page.page_content)
    print(len(chunks))
    print(chunks[0])

    print("\n-------------- Testing Pinecone indexer --------------")
    pi = PineconeIndexer()
    pi.load_embedder()
    pi.index_documents(axloader.documents)
    print(pi.index.describe_index_stats())
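
    # Illustrative follow-up (not part of the original script): query the freshly populated
    # index through the LangChain vector store returned by get_vectorstore(). The query
    # string and k value below are arbitrary examples, chosen only to show the metadata
    # fields ("source_document", "page_number") attached during indexing.
    print("\n-------------- Sample similarity search --------------")
    vectorstore = pi.get_vectorstore()
    results = vectorstore.similarity_search("What is nuclear fission?", k=3)
    for doc in results:
        print(doc.metadata["source_document"], doc.metadata["page_number"])
        print(doc.page_content[:200], "...\n")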