timeki's picture
update display and fix search only
d396732
raw
history blame
7.18 kB
from typing import List, Optional

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import pyalex
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from pydantic import Field
from pyvis.network import Network

from ..engine.utils import num_tokens_from_string
# Contact e-mail stored in pyalex's global config at import time; presumably
# sent along with OpenAlex API requests to identify the caller -- confirm
# against the pyalex documentation.
pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"
def replace_nan_with_empty_dict(x):
    """Return ``x`` unchanged unless pandas considers it missing.

    Missing values (as judged by ``pd.notna``, e.g. ``None`` or ``NaN``) are
    replaced by a fresh empty dict so callers can safely use ``.get`` on the
    result.
    """
    if pd.notna(x):
        return x
    return {}
class OpenAlex():
    """Helper around the OpenAlex works search exposed by ``pyalex``.

    Provides keyword search returning a pandas DataFrame, optional reranking
    of the results, and construction / display of a citation graph between
    the returned works.
    """

    def __init__(self):
        pass

    def search(self, keywords: str, n_results=100, after=None, before=None):
        """Search OpenAlex works by keywords and return the first result page.

        Args:
            keywords: Free-text query passed to ``Works().search``.
            n_results: Page size requested from the API; only the first page
                is fetched.
            after: If given, keep only works whose publication year is
                strictly greater than this year (must be an int > 1900).
            before: If given, keep only works published strictly before this
                year (must be an int). Applied through the OpenAlex
                ``to_publication_date`` filter set to 31 December of the
                previous year.

        Returns:
            A DataFrame with one row per work. Rows without a title are
            dropped, ``abstract_inverted_index`` is replaced by a
            reconstructed ``abstract`` column, and the columns ``is_oa``,
            ``pdf_url``, ``url``, ``content``, ``num_tokens``,
            ``display_name`` and ``subtitle`` are added. An empty DataFrame
            is returned when the search yields nothing.

        Raises:
            Exception: If ``keywords`` is not a string.
            AssertionError: If ``after`` or ``before`` is not valid.
        """
        if not isinstance(keywords, str):
            raise Exception("Keywords must be a string")

        works = Works().search(keywords)

        if after is not None:
            assert isinstance(after, int), "after must be an integer"
            assert after > 1900, "after must be greater than 1900"
            works = works.filter(publication_year=f">{after}")

        if before is not None:
            assert isinstance(before, int), "before must be an integer"
            # A separate filter key is used so it cannot collide with the
            # ``publication_year`` filter set for ``after``.
            works = works.filter(to_publication_date=f"{before - 1}-12-31")

        # Only the first page is needed; default to an empty page so an
        # exhausted paginator does not leave ``page`` undefined.
        page = next(iter(works.paginate(per_page=n_results)), [])

        df_works = pd.DataFrame(page)
        if df_works.empty:
            return df_works

        df_works = df_works.dropna(subset=["title"])
        df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
        df_works["abstract"] = df_works["abstract_inverted_index"].apply(
            self.get_abstract_from_inverted_index
        ).fillna("")
        # ``open_access`` may be missing for a record; treat that as not OA.
        df_works["is_oa"] = df_works["open_access"].map(
            lambda x: x.get("is_oa", False) if isinstance(x, dict) else False
        )
        df_works["pdf_url"] = df_works["primary_location"].map(lambda x: x.get("pdf_url", None))
        df_works["url"] = df_works["id"]
        df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x: x.strip())
        df_works["num_tokens"] = df_works["content"].map(num_tokens_from_string)
        df_works = df_works.drop(columns=["abstract_inverted_index"])
        df_works["display_name"] = df_works["primary_location"].apply(self._source_display_name)
        df_works["subtitle"] = (
            df_works["title"].astype(str)
            + " - "
            + df_works["display_name"].astype(str)
            + " - "
            + df_works["publication_year"].astype(str)
        )

        return df_works

    @staticmethod
    def _source_display_name(location):
        """Return ``location["source"]["display_name"]``, or "" if either level is absent."""
        source = location["source"] if isinstance(location, dict) and "source" in location else ""
        if isinstance(source, dict) and "display_name" in source:
            return source["display_name"]
        return ""

    def rerank(self, query, df, reranker):
        """Add a ``rerank_score`` column to ``df`` and return it.

        ``reranker.rank(query, docs)`` is called with the ``content`` column
        as a list. Its ``results`` are sorted by ``document.doc_id`` before
        the scores are written back, which presumably restores the input
        order (assumes doc ids are the positional indices -- confirm against
        the reranker in use). ``df`` is modified in place.
        """
        ranking = reranker.rank(
            query,
            df["content"].tolist()
        )
        ordered = sorted(ranking.results, key=lambda x: x.document.doc_id)
        df["rerank_score"] = [result.score for result in ordered]
        return df

    def make_network(self, df):
        """Build a directed citation graph from a DataFrame of works.

        Every row becomes a node keyed by its ``id`` with all row values
        stored as node attributes. An edge ``citing -> cited`` (attribute
        ``relationship="CITING"``) is added for each entry of
        ``referenced_works`` that is itself a row of ``df``; references to
        works outside ``df`` are ignored.
        """
        G = nx.DiGraph()
        papers = [row.to_dict() for _, row in df.iterrows()]

        # First pass: register every paper, so that edges do not depend on
        # the order in which citing and cited papers appear in ``df``.
        for paper in papers:
            G.add_node(paper['id'], **paper)

        # Second pass: link papers that cite another paper of the same set.
        for paper in papers:
            references = paper['referenced_works']
            # Missing values (None / NaN) carry no references.
            if references is None or isinstance(references, (str, float)):
                continue
            for reference in references:
                if reference in G:
                    G.add_edge(paper['id'], reference, relationship="CITING")
        return G

    @staticmethod
    def _rgb_hex(rgba):
        """Convert an RGB(A) tuple of floats in [0, 1] to a ``#rrggbb`` string."""
        channels = tuple(int(max(0, min(c * 255, 255))) for c in rgba[:3])
        return '#%02x%02x%02x' % channels

    def show_network(self, G, height="750px", notebook=True, color_by="pagerank"):
        """Render the citation graph with pyvis.

        Node size is proportional to PageRank. Node colour follows the
        ``RdBu_r`` matplotlib colormap (blue = low, red = high) applied to
        the min-max normalised score selected by ``color_by``.

        Args:
            G: Graph whose nodes carry ``title``, ``publication_year`` and
                ``id`` attributes (as produced by ``make_network``).
            height: Height passed to the pyvis ``Network``.
            notebook: If true, return ``net.show("network.html")``;
                otherwise return the pyvis ``Network`` object.
            color_by: ``"pagerank"`` or ``"rerank_score"``.

        Raises:
            ValueError: If ``color_by`` is not one of the supported values.
        """
        net = Network(height=height, width="100%", bgcolor="#ffffff", font_color="black", notebook=notebook, directed=True, neighborhood_highlight=True)
        net.force_atlas_2based()

        # PageRank drives node size and, by default, node colour.
        pagerank = nx.pagerank(G)

        if color_by == "pagerank":
            color_scores = pagerank
        elif color_by == "rerank_score":
            color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}
        else:
            raise ValueError(f"Unknown color_by value: {color_by}")

        # Min-max normalise the scores to [0, 1] for the colormap.
        min_score = min(color_scores.values(), default=0)
        max_score = max(color_scores.values(), default=0)
        score_range = max_score - min_score
        if score_range:
            norm_color_scores = {
                node: (score - min_score) / score_range
                for node, score in color_scores.items()
            }
        else:
            # All scores equal (e.g. a graph without edges): avoid dividing
            # by zero and use the middle of the colormap for every node.
            norm_color_scores = {node: 0.5 for node in color_scores}

        for node in G.nodes:
            info = G.nodes[node]
            title = info["title"]
            label = title[:30] + " ..."

            title = [title, f"Year: {info['publication_year']}", f"ID: {info['id']}"]
            title = "\n".join(title)

            # RdBu_r maps 0 to blue and 1 to red.
            color = self._rgb_hex(plt.cm.RdBu_r(norm_color_scores[node]))

            net.add_node(node, title=title, size=pagerank[node]*1000, label=label, color=color)

        for edge in G.edges:
            net.add_edge(edge[0], edge[1], arrowStrikethrough=True, color="gray")

        if notebook:
            return net.show("network.html")
        else:
            return net

    def get_abstract_from_inverted_index(self, index):
        """Rebuild an abstract from an inverted index ``{token: [positions]}``.

        Each token is placed at every position listed for it and the result
        is joined with single spaces; positions that no token claims become
        empty strings. Returns "" when ``index`` is ``None``, not a dict
        (e.g. NaN), empty, or lists no position at all.
        """
        if not isinstance(index, dict) or not index:
            return ""

        all_positions = [p for positions in index.values() for p in positions]
        if not all_positions:
            return ""

        # The highest position fixes the length of the reconstructed text.
        reconstructed = [''] * (max(all_positions) + 1)

        for token, positions in index.items():
            for position in positions:
                reconstructed[position] = token

        return ' '.join(reconstructed)
class OpenAlexRetriever(BaseRetriever):
    """LangChain retriever that turns OpenAlex search results into Documents.

    Each query is forwarded to ``OpenAlex.search``; every returned row whose
    token count is within bounds becomes a ``Document`` whose
    ``page_content`` is the row's ``content`` and whose ``metadata`` is the
    full row as a dict.
    """

    # Forwarded to ``OpenAlex.search`` as ``after``.
    min_year: int = 1960
    # Forwarded to ``OpenAlex.search`` as ``before``; ``None`` leaves it unset.
    # Declared Optional so that passing ``None`` explicitly validates.
    max_year: Optional[int] = None
    # Forwarded to ``OpenAlex.search`` as ``n_results``.
    k: int = 100

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search OpenAlex for ``query`` and wrap the hits as Documents.

        Rows with fewer than 50 or more than 1000 tokens (``num_tokens``
        column) are skipped.
        """
        openalex = OpenAlex()

        df_docs = openalex.search(query, n_results=self.k, after=self.min_year, before=self.max_year)

        docs = []
        for _, row in df_docs.iterrows():
            num_tokens = row["num_tokens"]

            # Skip near-empty entries and overly long ones.
            if num_tokens < 50 or num_tokens > 1000:
                continue

            docs.append(
                Document(
                    page_content=row["content"],
                    metadata=row.to_dict(),
                )
            )
        return docs