import os
from typing import List
from uuid import uuid4

import arxiv
import pinecone
from tqdm.auto import tqdm

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.vectorstores import Pinecone

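# Pinecone's docs suggest upserting in small batches; roughly 100 vectors per
# request is a commonly used default.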
INDEX_BATCH_LIMIT = 100


class CharacterTextSplitter:
    """Thin wrapper around RecursiveCharacterTextSplitter with validated defaults."""

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        assert (
            chunk_size > chunk_overlap
        ), "Chunk size must be greater than chunk overlap"

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
        )

    def split(self, text: str) -> List[str]:
        """Split raw text into overlapping chunks."""
        return self.text_splitter.split_text(text)


class ArxivLoader:
    """Searches arXiv for papers matching a query and loads them as PDF documents."""

    def __init__(self, query: str = "Nuclear Fission", max_results: int = 5):
        """Store the search parameters and prepare a splitter for chunking."""
        self.query = query
        self.max_results = max_results

        self.paper_urls = []
        self.documents = []
        self.splitter = CharacterTextSplitter()

    def retrieve_urls(self):
        """Query the arXiv API and collect the PDF URL of each result."""
        arxiv_client = arxiv.Client()
        search = arxiv.Search(
            query=self.query,
            max_results=self.max_results,
            sort_by=arxiv.SortCriterion.Relevance,
        )

        for result in arxiv_client.results(search):
            self.paper_urls.append(result.pdf_url)

    def load_documents(self):
        """Download each PDF and load it as a list of per-page Documents."""
        for paper_url in self.paper_urls:
            loader = PyPDFLoader(paper_url)
            self.documents.append(loader.load())

    def format_document(self, document):
        """Chunk a single page and attach source/page metadata to every chunk."""
        metadata = {
            "source_document": document.metadata["source"],
            "page_number": document.metadata["page"],
        }

        record_texts = self.splitter.split(document.page_content)
        record_metadatas = [
            {"chunk": j, "text": text, **metadata}
            for j, text in enumerate(record_texts)
        ]

        return record_texts, record_metadatas

    def main(self):
        """Convenience entry point: fetch URLs, then load the documents."""
        self.retrieve_urls()
        self.load_documents()


class PineconeIndexer:
    """Embeds document chunks and upserts them into a Pinecone index."""

    def __init__(self, index_name: str = "arxiv-paper-index", metric: str = "cosine", n_dims: int = 1536):
        """Connect to Pinecone and create the index if it does not already exist."""
        pinecone.init(
            api_key=os.environ["PINECONE_API_KEY"],
            environment=os.environ["PINECONE_ENV"]
        )

        # 1536 dimensions matches OpenAI's text-embedding-ada-002 output.
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name,
                metric=metric,
                dimension=n_dims
            )

        self.arxiv_loader = ArxivLoader()
        self.index = pinecone.Index(index_name)

    def load_embedder(self):
        """Build an OpenAI embedder whose results are cached on local disk."""
        store = LocalFileStore("./cache/")

        core_embeddings_model = OpenAIEmbeddings()

        # Namespacing the cache by model name keeps vectors from different
        # embedding models from colliding.
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            core_embeddings_model,
            store,
            namespace=core_embeddings_model.model
        )

    def upsert(self, texts, metadatas):
        """Embed a batch of texts and upsert (id, vector, metadata) tuples."""
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = self.embedder.embed_documents(texts)
        self.index.upsert(vectors=list(zip(ids, embeds, metadatas)))

    def index_documents(self, documents, batch_limit: int = INDEX_BATCH_LIMIT):
        """Chunk every page of every document and upsert in batches."""
        texts = []
        metadatas = []

        for document in tqdm(documents):
            for page in document:
                record_texts, record_metadatas = self.arxiv_loader.format_document(page)

                texts.extend(record_texts)
                metadatas.extend(record_metadatas)

                # Flush once the accumulated batch reaches the limit.
                if len(texts) >= batch_limit:
                    self.upsert(texts, metadatas)

                    texts = []
                    metadatas = []

        # Flush any remainder smaller than a full batch.
        if len(texts) > 0:
            self.upsert(texts, metadatas)

    def get_vectorstore(self):
        """Expose the index as a LangChain vectorstore keyed on the 'text' field."""
        return Pinecone(self.index, self.embedder.embed_query, "text")
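
# Note: the returned vectorstore also plugs into LangChain retrieval chains,
# e.g. (illustrative) get_vectorstore().as_retriever(search_kwargs={"k": 3}).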


if __name__ == "__main__":

    print("-------------- Loading arXiv --------------")
    axloader = ArxivLoader()
    axloader.retrieve_urls()
    axloader.load_documents()

    print("\n-------------- Splitting sample doc --------------")
    sample_doc = axloader.documents[0]
    sample_page = sample_doc[0]

    splitter = CharacterTextSplitter()
    chunks = splitter.split(sample_page.page_content)
    print(len(chunks))
    print(chunks[0])

    print("\n-------------- Testing Pinecone indexer --------------")

    pi = PineconeIndexer()
    pi.load_embedder()
    pi.index_documents(axloader.documents)

    print(pi.index.describe_index_stats())
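
    # A minimal retrieval smoke test (illustrative: the query string and k are
    # arbitrary, and this assumes the upserts above have completed).
    print("\n-------------- Testing similarity search --------------")
    vectorstore = pi.get_vectorstore()
    for doc in vectorstore.similarity_search("nuclear fission", k=3):
        print(doc.metadata["source_document"], doc.metadata["page_number"])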