from typing import List, Optional

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import pyalex
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from pydantic import Field
from pyvis.network import Network

from ..engine.utils import num_tokens_from_string

# Configure the contact email pyalex uses for its OpenAlex API requests
# (presumably to be served from OpenAlex's "polite pool" — confirm against
# the pyalex docs).
# NOTE(review): this is a hard-coded personal address; consider reading it
# from configuration or an environment variable instead.
pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"


def replace_nan_with_empty_dict(x):
    """Normalise a possibly-missing nested record to a dict.

    Values that pandas treats as missing (``None``/``NaN``) become a fresh
    empty dict; anything else (typically a dict) is returned untouched.
    """
    if pd.isna(x):
        return {}
    return x

class OpenAlex():
    """Helper around the OpenAlex ``Works`` endpoint (through pyalex).

    Provides keyword search returning a pandas DataFrame, optional reranking
    of the results, and construction / pyvis rendering of the citation graph
    between the returned works.
    """

    # OpenAlex only accepts ``per_page`` values between 1 and 200; larger
    # requests are split over several pages by ``_fetch_records``.
    MAX_PER_PAGE = 200

    # Columns of an empty ``search`` result, so callers can still index the
    # derived columns when OpenAlex returns nothing.
    _RESULT_COLUMNS = [
        "id", "title", "publication_year", "referenced_works",
        "abstract", "is_oa", "pdf_url", "url", "content", "num_tokens",
    ]

    def __init__(self):
        pass

    def search(self, keywords: str, n_results=100, after=None, before=None):
        """Search OpenAlex works matching ``keywords``.

        Args:
            keywords: Free-text search query.
            n_results: Maximum number of works to fetch. Values above
                ``MAX_PER_PAGE`` are fetched over several pages.
            after: If given, only works with ``publication_year > after``.
            before: If given, only works published before year ``before``
                (i.e. up to December 31st of ``before - 1``).

        Returns:
            A DataFrame with one row per work that has a title. It includes the
            raw OpenAlex fields (minus ``abstract_inverted_index``) plus the
            derived columns ``abstract``, ``is_oa``, ``pdf_url``, ``url``,
            ``content`` (title + newline + abstract, stripped) and
            ``num_tokens``.

        Raises:
            TypeError: if ``keywords`` is not a string, or if ``after`` or
                ``before`` is not an integer.
            ValueError: if ``after`` or ``before`` is not greater than 1900.
        """
        if not isinstance(keywords, str):
            raise TypeError("Keywords must be a string")

        works = Works().search(keywords)
        if after is not None:
            self._check_year(after, "after")
            works = works.filter(publication_year=f">{after}")
        if before is not None:
            self._check_year(before, "before")
            # pyalex keeps filters in a dict keyed by field name, so a second
            # ``publication_year`` filter would replace the ``after`` bound.
            # Express the upper bound as a publication date instead.
            works = works.filter(to_publication_date=f"{before - 1}-12-31")

        records = self._fetch_records(works, n_results)
        if not records:
            return pd.DataFrame(columns=self._RESULT_COLUMNS)

        df_works = pd.DataFrame(records)
        df_works = df_works.dropna(subset=["title"])
        df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
        df_works["abstract"] = df_works["abstract_inverted_index"].apply(self.get_abstract_from_inverted_index).fillna("")
        df_works["is_oa"] = df_works["open_access"].map(lambda x: replace_nan_with_empty_dict(x).get("is_oa", False))
        df_works["pdf_url"] = df_works["primary_location"].map(lambda x: x.get("pdf_url", None))
        df_works["url"] = df_works["id"]
        df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(str.strip)
        df_works["num_tokens"] = df_works["content"].map(num_tokens_from_string)

        return df_works.drop(columns=["abstract_inverted_index"])

    @staticmethod
    def _check_year(value, name):
        """Validate a year bound passed to ``search`` (``name`` is used in messages)."""
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an integer")
        if value <= 1900:
            raise ValueError(f"{name} must be greater than 1900")

    def _fetch_records(self, works, n_results):
        """Collect up to ``n_results`` raw records from ``works.paginate``, paging as needed."""
        per_page = max(1, min(n_results, self.MAX_PER_PAGE))
        records = []
        for page in works.paginate(per_page=per_page):
            records.extend(page)
            # A short (or empty) page means there is nothing more to fetch.
            if len(records) >= n_results or len(page) < per_page:
                break
        return records[:n_results]

    def rerank(self, query, df, reranker):
        """Score each row's ``content`` against ``query`` into ``df["rerank_score"]``.

        ``reranker.rank(query, documents, top_k=...)`` must return one dict
        per document with ``corpus_id`` (position in ``documents``) and
        ``score`` keys. This presumably matches sentence-transformers'
        ``CrossEncoder.rank``; confirm for other rerankers. Results are
        re-sorted by ``corpus_id`` so scores line up with the rows of ``df``.

        ``df`` is modified in place and also returned.
        """
        if df.empty:
            df["rerank_score"] = pd.Series(dtype=float)
            return df

        results = reranker.rank(
            query,
            df["content"].tolist(),
            top_k=len(df),
        )
        ordered = sorted(results, key=lambda r: r["corpus_id"])
        df["rerank_score"] = [r["score"] for r in ordered]
        return df

    def make_network(self, df):
        """Build a directed citation graph restricted to the works in ``df``.

        Each row becomes a node keyed by its ``id``, with every row field as a
        node attribute. An edge ``paper -> reference`` with
        ``relationship="CITING"`` is added for every entry of
        ``referenced_works`` that is itself a node. References to works
        outside ``df`` are ignored.
        """
        G = nx.DiGraph()
        papers = [row.to_dict() for _, row in df.iterrows()]

        # Add all nodes before any edge so that a citation to a paper that
        # appears later in ``df`` is not dropped.
        for paper in papers:
            G.add_node(paper["id"], **paper)

        for paper in papers:
            references = paper.get("referenced_works")
            if references is None or isinstance(references, float):  # missing (None / NaN)
                continue
            for reference in references:
                if reference in G:
                    G.add_edge(paper["id"], reference, relationship="CITING")
        return G

    @staticmethod
    def _min_max_normalize(scores):
        """Min-max scale a ``{node: score}`` mapping to [0, 1].

        Constant scores (including a single node) map to 0.5, the colormap
        midpoint, instead of dividing by zero. An empty mapping returns ``{}``.
        """
        if not scores:
            return {}
        low = min(scores.values())
        high = max(scores.values())
        span = high - low
        if span == 0:
            return {node: 0.5 for node in scores}
        return {node: (value - low) / span for node, value in scores.items()}

    @staticmethod
    def _to_hex(rgba):
        """Convert a matplotlib RGBA tuple (floats in [0, 1]) to a ``#rrggbb`` string."""
        r, g, b = (int(max(0, min(channel * 255, 255))) for channel in rgba[:3])
        return f"#{r:02x}{g:02x}{b:02x}"

    def show_network(self, G, height="750px", notebook=True, color_by="pagerank"):
        """Render ``G`` (as built by ``make_network``) with pyvis.

        Node size is proportional to PageRank. Node colour maps ``color_by``
        (``"pagerank"``, or ``"rerank_score"`` read from node attributes,
        default 0) through the reversed RdBu colormap after min-max scaling.

        Returns:
            The result of ``net.show("network.html")`` when ``notebook`` is
            True, otherwise the pyvis ``Network`` object.

        Raises:
            ValueError: for an unknown ``color_by``.
        """
        if color_by not in ("pagerank", "rerank_score"):
            raise ValueError(f"Unknown color_by value: {color_by}")

        net = Network(height=height, width="100%", bgcolor="#ffffff", font_color="black", notebook=notebook, directed=True, neighborhood_highlight=True)
        net.force_atlas_2based()

        # PageRank always drives node size; it also drives colour by default.
        pagerank = nx.pagerank(G)
        if color_by == "pagerank":
            color_scores = pagerank
        else:
            color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}
        norm_color_scores = self._min_max_normalize(color_scores)

        for node in G.nodes:
            info = G.nodes[node]
            full_title = info["title"]
            # Only mark the label as truncated when it actually is.
            label = full_title if len(full_title) <= 30 else full_title[:30] + " ..."
            hover = "\n".join([full_title, f"Year: {info['publication_year']}", f"ID: {info['id']}"])

            # Reversed RdBu colormap: blue for low scores, red for high ones.
            color = self._to_hex(plt.cm.RdBu_r(norm_color_scores[node]))

            net.add_node(node, title=hover, size=pagerank[node] * 1000, label=label, color=color)

        for source, target in G.edges:
            net.add_edge(source, target, arrowStrikethrough=True, color="gray")

        if notebook:
            return net.show("network.html")
        return net

    def get_abstract_from_inverted_index(self, index):
        """Rebuild a plain-text abstract from OpenAlex's inverted index.

        ``index`` maps each token to the list of word positions where it
        occurs. Positions not covered by any token become empty strings in
        the space-joined output. Returns ``""`` when ``index`` is ``None``,
        ``NaN`` (pandas' fill for missing values) or empty.
        """
        if index is None or isinstance(index, float) or not index:
            return ""

        max_index = max((max(positions) for positions in index.values() if positions), default=-1)
        reconstructed = [""] * (max_index + 1)
        for token, positions in index.items():
            for position in positions:
                reconstructed[position] = token
        return " ".join(reconstructed)
        


class OpenAlexRetriever(BaseRetriever):
    """LangChain retriever that searches OpenAlex works by keyword.

    Every search hit whose ``num_tokens`` lies within
    ``[min_tokens, max_tokens]`` (inclusive) is returned as a ``Document``.
    The ``content`` column is the page content and the full result row is
    the metadata.
    """

    # Lower year bound forwarded to ``OpenAlex.search`` as ``after``.
    min_year: int = 1960
    # Optional upper year bound forwarded to ``OpenAlex.search`` as ``before``.
    max_year: Optional[int] = None
    # Number of results requested from OpenAlex (before token filtering).
    k: int = 100
    # Hits whose token count falls outside this inclusive range are discarded.
    min_tokens: int = 50
    max_tokens: int = 1000

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search OpenAlex for ``query`` and wrap suitably-sized hits as Documents."""
        openalex = OpenAlex()

        df_docs = openalex.search(query, n_results=self.k, after=self.min_year, before=self.max_year)

        return [
            Document(page_content=row["content"], metadata=row.to_dict())
            for _, row in df_docs.iterrows()
            if self.min_tokens <= row["num_tokens"] <= self.max_tokens
        ]