import spacy
import wikipediaapi
import wikipedia
from wikipedia.exceptions import DisambiguationError
from transformers import TFAutoModel, AutoTokenizer
import numpy as np
import pandas as pd
import faiss
import gradio as gr

# Load the small English spaCy pipeline, downloading it on first use.
try:
  nlp = spacy.load("en_core_web_sm")
except OSError:
  # spacy.load raises OSError when the model package is not installed.
  # Anything else (KeyboardInterrupt, a broken install, ...) should
  # propagate instead of silently triggering a download.
  spacy.cli.download("en_core_web_sm")
  nlp = spacy.load("en_core_web_sm")

# Question words; a noun chunk consisting of exactly one of these is ignored.
wh_words = ['what', 'who', 'how', 'when', 'which']
def get_concepts(text):
  """Return the noun chunks found in *text*.

  The text is lower-cased before parsing, so the returned chunk strings are
  lower-case. Chunks whose whole text is one of ``wh_words`` are skipped.
  """
  parsed = nlp(text.lower())
  return [
    noun_chunk.text
    for noun_chunk in parsed.noun_chunks
    if noun_chunk.text not in wh_words
  ]

def get_passages(text, k=100):
    """Split *text* into passages made of whole sentences.

    Sentences are accumulated in order; when adding the next sentence would
    bring the running token count to ``k`` or more, the passage collected so
    far is emitted and the sentence starts a new passage. A single sentence
    that is itself ``k`` tokens or longer becomes a passage on its own.

    Args:
        text: Raw text to segment (parsed with the module-level ``nlp``).
        k: Token threshold at which a new passage is started.

    Returns:
        A list of passage strings; sentences inside a passage are joined by a
        single space. Empty for text that contains no sentences.
    """
    passages = []
    current = []      # sentence texts of the passage being built
    current_len = 0   # number of tokens held in `current`
    for sent in nlp(text).sents:
        n_tokens = len(sent)  # len() of a spaCy span is its token count
        # Only flush when something has been collected, so an over-long first
        # sentence does not produce an empty passage.
        if current and current_len + n_tokens >= k:
            passages.append(" ".join(current))
            current = []
            current_len = 0
        current.append(sent.text)
        current_len += n_tokens
    # Always emit the trailing passage; previously the last sentence was lost
    # whenever it was the one that crossed the threshold.
    if current:
        passages.append(" ".join(current))
    return passages

def get_dicts_for_dpr(concepts, n_results=20, k=100):
  """Collect Wikipedia passages for every concept.

  For each concept, the top ``n_results`` Wikipedia search hits are fetched
  and their content is split with ``get_passages``.

  Args:
    concepts: Iterable of search strings.
    n_results: Maximum number of search hits to fetch per concept.
    k: Token threshold forwarded to ``get_passages``.

  Returns:
    A list of ``{'text': passage, 'title': page_title}`` dicts, in search
    order. Titles that resolve to a disambiguation page or to no page at all
    are skipped.
  """
  dicts = []
  for concept in concepts:
    wikis = wikipedia.search(concept, results=n_results)
    print(concept, "No of Wikis: ",len(wikis))
    for wiki in wikis:
      try:
        html_page = wikipedia.page(title = wiki, auto_suggest = False)
      except (DisambiguationError, wikipedia.exceptions.PageError):
        # A search hit that cannot be loaded should not abort the whole
        # query; move on to the next title.
        continue

      dicts.extend(
        {'text': passage, 'title': wiki}
        for passage in get_passages(html_page.content, k=k)
      )
  return dicts

# Checkpoint names shared by each encoder and its matching tokenizer.
_CTX_CHECKPOINT = "nlpconnect/dpr-ctx_encoder_bert_uncased_L-2_H-128_A-2"
_QUESTION_CHECKPOINT = "nlpconnect/dpr-question_encoder_bert_uncased_L-2_H-128_A-2"

# Encoders are loaded first, then tokenizers, passage side before query side.
passage_encoder = TFAutoModel.from_pretrained(_CTX_CHECKPOINT)
query_encoder = TFAutoModel.from_pretrained(_QUESTION_CHECKPOINT)
p_tokenizer = AutoTokenizer.from_pretrained(_CTX_CHECKPOINT)
q_tokenizer = AutoTokenizer.from_pretrained(_QUESTION_CHECKPOINT)

def get_title_text_combined(passage_dicts):
    """Return one ``(title, text)`` tuple per passage dict, in input order."""
    return [(entry['title'], entry['text']) for entry in passage_dicts]
    
def extracted_passage_embeddings(processed_passages, max_length=156):
    """Run the passage encoder over *processed_passages*.

    Args:
        processed_passages: Batch handed to ``p_tokenizer.batch_encode_plus``;
            ``search`` passes the ``(title, text)`` tuples built by
            ``get_title_text_combined``.
        max_length: Every input is truncated/padded to this many tokens.

    Returns:
        Whatever ``passage_encoder.predict`` returns for the batch.
    """
    encoded = p_tokenizer.batch_encode_plus(
        processed_passages,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=True,
    )
    # The encoder is fed the three arrays positionally, in exactly this order.
    model_inputs = [
        np.array(encoded[field])
        for field in ('input_ids', 'attention_mask', 'token_type_ids')
    ]
    return passage_encoder.predict(model_inputs, batch_size=64, verbose=1)

def extracted_query_embeddings(queries, max_length=64):
    """Run the question encoder over *queries*.

    Args:
        queries: Batch of query strings handed to
            ``q_tokenizer.batch_encode_plus``.
        max_length: Every query is truncated/padded to this many tokens.

    Returns:
        Whatever ``query_encoder.predict`` returns for the batch.
    """
    encoded = q_tokenizer.batch_encode_plus(
        queries,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=True,
    )
    # The encoder is fed the three arrays positionally, in exactly this order.
    model_inputs = [
        np.array(encoded[field])
        for field in ('input_ids', 'attention_mask', 'token_type_ids')
    ]
    return query_encoder.predict(model_inputs, batch_size=1, verbose=1)
    
#Wikipedia API:

def get_pagetext(page):
  """Return *page* converted to ``str`` with tab characters removed.

  The previous pattern was the two-character sequence "/t" (slash + t), which
  left tabs in place and instead mangled ordinary text such as URLs.
  """
  return str(page).replace("\t", "")

def get_wiki_summary(search):
    """Look up a Wikipedia page and describe it as a two-column table.

    Args:
        search: Page title to look up on the English Wikipedia.

    Returns:
        A ``pandas.DataFrame`` with columns ``Entity`` and ``Value`` and one
        row each for URL, Title, Summary (first 60 characters), Text,
        Backlinks, Links and Categories. When the page does not exist every
        value is ``"Not found"``; a DataFrame is returned in both cases
        because the result is bound to a Gradio Dataframe output.
    """
    entities = ["URL", "Title", "Summary", "Text", "Backlinks", "Links", "Categories"]

    wiki_wiki = wikipediaapi.Wikipedia('en')
    page = wiki_wiki.page(search)

    if not page.exists():
        return pd.DataFrame({
            'Entity': entities,
            'Value': ["Not found"] * len(entities),
        })

    def joined_titles(pages_by_title):
        # Each title is followed by the " ,  " separator, including the last
        # one, to keep the rendered cell text unchanged.
        return "".join(title + " ,  " for title in pages_by_title)

    values = [
        page.fullurl,
        page.title,
        page.summary[0:60],
        get_pagetext(page.text),
        joined_titles(page.backlinks),
        joined_titles(page.links),
        joined_titles(page.categories),
    ]
    return pd.DataFrame({'Entity': entities, 'Value': values})
      
def search(question):
  """Retrieve Wikipedia passages for *question*, ordered by embedding distance.

  Noun-chunk concepts are extracted from the question, the top Wikipedia hit
  for each concept is split into passages, and passages and question are
  embedded with their respective encoders. Passages are then ranked by L2
  distance to the question embedding using a flat FAISS index.

  Args:
    question: Free-text question.

  Returns:
    A ``pandas.DataFrame`` with one row per passage (columns ``text`` and
    ``title``), nearest first, or an empty DataFrame when no passages were
    found.
  """
  concepts = get_concepts(question)
  print("concepts: ",concepts)
  dicts = get_dicts_for_dpr(concepts, n_results=1)
  lendicts = len(dicts)
  print("dicts len: ", lendicts)
  if lendicts == 0:
    return pd.DataFrame()
  processed_passages = get_title_text_combined(dicts)
  passage_vectors = extracted_passage_embeddings(processed_passages).pooler_output
  query_vectors = extracted_query_embeddings([question]).pooler_output
  # Size the index from the embeddings rather than hard-coding 128, so that
  # swapping the encoder checkpoint does not break the index.
  faiss_index = faiss.IndexFlatL2(int(passage_vectors.shape[1]))
  faiss_index.add(passage_vectors)
  # The first return value holds L2 distances (not probabilities); only the
  # ranking is needed here.
  _distances, index = faiss_index.search(query_vectors, k=lendicts)
  return pd.DataFrame([dicts[i] for i in index[0]])

# AI UI SOTA - Gradio blocks with UI formatting, and event driven UI
with gr.Blocks() as demo:     # Block documentation on event listeners, start here:  https://gradio.app/blocks_and_event_listeners/
  gr.Markdown("<h1><center>🍰 Ultimate Wikipedia AI 🎨</center></h1>")
  gr.Markdown("""<div align="center">Search and Find Anything Then Use in AI!  <a href="https://www.mediawiki.org/wiki/API:Main_page">MediaWiki - API for Wikipedia</a>.  <a href="https://paperswithcode.com/datasets?q=wikipedia&v=lst&o=newest">Papers,Code,Datasets for SOTA w/ Wikipedia</a>""")
  with gr.Row(): # inputs and buttons
    # Blocks components take their initial content via `value`; the old
    # `default` keyword is not a Textbox parameter, so the box started empty.
    inp = gr.Textbox(lines=1, value="Syd Mead", label="Question")
  with gr.Row(): # inputs and buttons
    b3 = gr.Button("Search AI Summaries")
    b4 = gr.Button("Search Web Live")
  with gr.Row(): # outputs DF1
    out = gr.Dataframe(label="Answers", type="pandas")
  with gr.Row(): # output DF2
    out_DF = gr.Dataframe(wrap=True, max_rows=1000, overflow_row_behaviour= "paginate", datatype = ["markdown", "markdown"], headers=['Entity', 'Value'])
  # Event wiring: pressing Enter in the textbox and the "Search Web Live"
  # button both fill the page-summary table; "Search AI Summaries" runs the
  # passage retrieval.
  inp.submit(fn=get_wiki_summary, inputs=inp, outputs=out_DF)
  b3.click(fn=search, inputs=inp, outputs=out)
  b4.click(fn=get_wiki_summary, inputs=inp, outputs=out_DF )
demo.launch(debug=True, show_error=True)