from typing import List, Optional

import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from pyvis.network import Network
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from pydantic import Field
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
import pyalex

from ..engine.utils import num_tokens_from_string

pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"


def replace_nan_with_empty_dict(x):
    """Return ``x`` unchanged, or ``{}`` when pandas considers it missing (None/NaN)."""
    return x if pd.notna(x) else {}


def _source_display_name(location):
    """Return ``location["source"]["display_name"]``, or "" if either level is not a dict
    or the key is absent."""
    source = location.get("source") if isinstance(location, dict) else None
    if not isinstance(source, dict):
        return ""
    return source.get("display_name", "")


class OpenAlex:
    """Helpers around the OpenAlex ``Works`` endpoint (through ``pyalex``).

    Provides search, reranking of the results, and construction and
    rendering of the citation network among them.
    """

    def __init__(self):
        pass

    def search(self, keywords: str, n_results=100, after=None, before=None):
        """Search OpenAlex works matching ``keywords`` and return them as a DataFrame.

        Only the first result page is fetched, so at most ``n_results`` works
        are returned.

        Args:
            keywords: Full-text search query.
            n_results: Page size requested from OpenAlex.
            after: If given, keep only works with ``publication_year > after``.
            before: If given, keep only works published strictly before the
                year ``before``.

        Returns:
            The raw OpenAlex fields (without ``abstract_inverted_index``), plus
            these derived columns: ``abstract``, ``is_oa``, ``pdf_url``,
            ``url``, ``content``, ``num_tokens``, ``display_name`` and
            ``subtitle``. Rows without a title are dropped. The DataFrame is
            returned as-is (possibly empty, without derived columns) when
            OpenAlex returns nothing.

        Raises:
            TypeError: If ``keywords`` is not a string, or ``after``/``before``
                is not an integer.
            ValueError: If ``after`` is not greater than 1900.
        """
        if not isinstance(keywords, str):
            raise TypeError("Keywords must be a string")

        works = Works().search(keywords)

        # Explicit raises instead of ``assert``: asserts are stripped under ``python -O``.
        if after is not None:
            if not isinstance(after, int):
                raise TypeError("after must be an integer")
            if after <= 1900:
                raise ValueError("after must be greater than 1900")
            works = works.filter(publication_year=f">{after}")

        if before is not None:
            if not isinstance(before, int):
                raise TypeError("before must be an integer")
            # Filters are keyed by field name, so a second ``publication_year``
            # filter would presumably replace the ``after`` bound (verify against
            # pyalex). Use the distinct ``to_publication_date`` filter instead:
            # "year < before" == "published on or before Dec 31 of before - 1".
            works = works.filter(to_publication_date=f"{before - 1:04d}-12-31")

        # Only the first page is needed. The previous ``for page in ...: break``
        # left ``page`` unbound (NameError) when pagination yielded nothing.
        page = next(iter(works.paginate(per_page=n_results)), [])

        df_works = pd.DataFrame(page)
        if df_works.empty:
            return df_works

        # Copy so the column assignments below don't act on a filtered view.
        df_works = df_works.dropna(subset=["title"]).copy()
        df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
        df_works["abstract"] = (
            df_works["abstract_inverted_index"]
            .apply(self.get_abstract_from_inverted_index)
            .fillna("")
        )
        # ``open_access`` may be missing (NaN) for some works; treat that as "not OA".
        df_works["is_oa"] = (
            df_works["open_access"]
            .map(replace_nan_with_empty_dict)
            .map(lambda oa: oa.get("is_oa", False))
        )
        df_works["pdf_url"] = df_works["primary_location"].map(lambda loc: loc.get("pdf_url", None))
        df_works["url"] = df_works["id"]
        df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(str.strip)
        df_works["num_tokens"] = df_works["content"].map(num_tokens_from_string)
        df_works = df_works.drop(columns=["abstract_inverted_index"])
        df_works["display_name"] = df_works["primary_location"].apply(_source_display_name)
        df_works["subtitle"] = (
            df_works["title"].astype(str)
            + " - "
            + df_works["display_name"].astype(str)
            + " - "
            + df_works["publication_year"].astype(str)
        )
        return df_works

    def rerank(self, query, df, reranker):
        """Score each row's ``content`` against ``query`` into a ``rerank_score`` column.

        ``reranker.rank(query, docs)`` must return an object whose ``results``
        items expose ``.score`` and ``.document.doc_id``. This presumably
        matches the ``rerankers`` library API; confirm. Results are sorted by
        ``doc_id`` so that scores line up positionally with the rows of ``df``.
        That assumes doc ids are the 0..n-1 input positions; verify.

        ``df`` is modified in place and also returned.
        """
        ranked = reranker.rank(query, df["content"].tolist())
        ordered = sorted(ranked.results, key=lambda result: result.document.doc_id)
        df["rerank_score"] = [result.score for result in ordered]
        return df

    def make_network(self, df):
        """Build a directed citation graph restricted to the works in ``df``.

        Each row becomes a node keyed by its ``id``, with every row field as a
        node attribute. An edge ``paper -> reference`` with
        ``relationship="CITING"`` is added whenever a paper's
        ``referenced_works`` contains another work present in ``df``.
        References to works outside ``df`` are ignored.
        """
        G = nx.DiGraph()
        papers = [row.to_dict() for _, row in df.iterrows()]

        # Add every node before any edge. Otherwise a citation to a paper that
        # appears later in ``df`` is silently dropped and the graph depends on
        # row order.
        for paper in papers:
            G.add_node(paper["id"], **paper)

        for paper in papers:
            for reference in paper["referenced_works"]:
                if reference in G:
                    G.add_edge(paper["id"], reference, relationship="CITING")
        return G

    def show_network(self, G, height="750px", notebook=True, color_by="pagerank"):
        """Render ``G`` with pyvis.

        Node size is proportional to PageRank. Node colour maps the
        min-max-normalized ``color_by`` score onto the reversed RdBu colormap
        (blue = low, red = high).

        Args:
            G: Graph whose nodes carry ``title``, ``publication_year`` and
                ``id`` attributes (as built by ``make_network``).
            height: CSS height of the canvas.
            notebook: If True, return ``net.show("network.html")``; otherwise
                return the ``Network`` object.
            color_by: ``"pagerank"`` or ``"rerank_score"``. Nodes without a
                rerank score count as 0.

        Raises:
            ValueError: If ``color_by`` is not a supported value.
        """
        net = Network(height=height, width="100%", bgcolor="#ffffff", font_color="black", notebook=notebook, directed=True, neighborhood_highlight=True)
        net.force_atlas_2based()

        pagerank = nx.pagerank(G)
        if color_by == "pagerank":
            color_scores = pagerank
        elif color_by == "rerank_score":
            color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}
        else:
            raise ValueError(f"Unknown color_by value: {color_by}")

        # Min-max normalize to [0, 1]. ``default`` covers an empty graph. A zero
        # span (single node, or all scores equal) previously divided by zero;
        # such nodes are mapped to the low end of the scale.
        min_score = min(color_scores.values(), default=0.0)
        max_score = max(color_scores.values(), default=0.0)
        span = max_score - min_score
        norm_color_scores = {
            node: (score - min_score) / span if span else 0.0
            for node, score in color_scores.items()
        }

        def to_byte(channel):
            # Matplotlib colours are floats in [0, 1]; clamp into a 0-255 int.
            return int(max(0, min(channel * 255, 255)))

        for node in G.nodes:
            info = G.nodes[node]
            full_title = info["title"]
            # Only add an ellipsis when the title is actually truncated.
            label = full_title if len(full_title) <= 30 else full_title[:30] + " ..."
            tooltip = "\n".join([full_title, f"Year: {info['publication_year']}", f"ID: {info['id']}"])

            rgba = plt.cm.RdBu_r(norm_color_scores[node])  # reversed RdBu: blue (low) -> red (high)
            color = "#%02x%02x%02x" % tuple(to_byte(channel) for channel in rgba[:3])

            net.add_node(node, title=tooltip, size=pagerank[node] * 1000, label=label, color=color)

        for source, target in G.edges:
            net.add_edge(source, target, arrowStrikethrough=True, color="gray")

        if notebook:
            return net.show("network.html")
        return net

    def get_abstract_from_inverted_index(self, index):
        """Rebuild abstract text from an OpenAlex ``abstract_inverted_index``.

        ``index`` maps each token to the list of word positions where it
        occurs. Tokens are placed at their positions and joined with single
        spaces. Positions without a token become empty strings, so gaps
        produce extra spaces.

        Returns "" when ``index`` is missing (None, or NaN as filled in by
        pandas), empty, or has no positions.
        """
        if not isinstance(index, dict) or not index:
            return ""

        max_index = max((pos for positions in index.values() for pos in positions), default=-1)
        reconstructed = [""] * (max_index + 1)
        for token, positions in index.items():
            for position in positions:
                reconstructed[position] = token
        return " ".join(reconstructed)


class OpenAlexRetriever(BaseRetriever):
    """LangChain retriever that searches OpenAlex and wraps each work as a ``Document``.

    A work is returned only if its ``num_tokens`` (title + abstract) lies in
    ``[min_tokens, max_tokens]``, inclusive. ``metadata`` holds the full
    result row.
    """

    min_year: Optional[int] = 1960  # forwarded as ``after`` (strict lower bound on year)
    max_year: Optional[int] = None  # forwarded as ``before`` (strict upper bound); None = unbounded
    k: int = 100  # page size requested from OpenAlex
    min_tokens: int = 50  # works with fewer tokens are skipped
    max_tokens: int = 1000  # works with more tokens are skipped

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        openalex = OpenAlex()
        df_docs = openalex.search(query, n_results=self.k, after=self.min_year, before=self.max_year)
        return [
            Document(page_content=row["content"], metadata=row.to_dict())
            for _, row in df_docs.iterrows()
            if self.min_tokens <= row["num_tokens"] <= self.max_tokens
        ]