import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from transformers import AutoModel
import json
from numpy.linalg import norm
import sqlite3
import urllib.request  # bare "import urllib" does not expose urlopen
from django.conf import settings
import Levenshtein

# This module acts as a singleton: the model, embedding function, and
# database handles below are shared module-level state.

class JinaAIEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a HuggingFace Jina embedding model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        # encode() returns a numpy array; Chroma expects plain Python lists
        embeddings = self.model.encode(input)
        return embeddings.tolist()

# shared instance of the embedding model (downloaded to ./models on first use)
embedding_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
                                            trust_remote_code=True,
                                            cache_dir='models')

# shared instance of JinaAIEmbeddingFunction
ef = JinaAIEmbeddingFunction(embedding_model)
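
# A minimal usage sketch (vector values depend on the model weights; Jina v2
# base produces 768-dimensional embeddings):
#
#   vectors = ef(["attention is all you need"])
#   # vectors is a list containing one 768-element list of floats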

# topics and their embeddings, loaded once at import time
with open("topic_descriptions.txt") as f:
    topic_descriptions = json.load(f)
topics = list(topic_descriptions.keys())
embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topic_descriptions]

def cos_sim(a, b):
    """Cosine similarity between two 1-D numpy vectors."""
    return (a @ b.T) / (norm(a) * norm(b))

def lev_sim(a, b):
    """Levenshtein edit distance between two strings (lower = more similar)."""
    return Levenshtein.distance(a, b)

def choose_topic(summary):
    """Return the topic whose description embedding is most similar to the summary."""
    embed = embedding_model.encode(summary)
    topic = ""
    max_sim = 0.
    for i, key in enumerate(topics):
        sim = cos_sim(embed, embeddings[i])
        if sim > max_sim:
            topic = key
            max_sim = sim
    return topic
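
# Hypothetical example (available labels come from topic_descriptions.txt,
# so the returned topic depends entirely on that file's keys):
#
#   choose_topic("We present a novel attention mechanism for sequence modeling.")
#   # -> e.g. "Machine Learning"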

def authors_list_to_str(authors):
    """Input a list of authors, return a comma-separated author string."""
    # join() avoids the off-by-one slicing bug of building the string manually
    return ", ".join(authors)

def authors_str_to_list(string):
    """Input an 'A and B and C'-style author string, return a list of authors."""
    authors = []
    # split on the delimiter word only, so names containing "and" stay intact
    for author in string.split(" and "):
        author = author.strip()
        if author and author != "et al.":
            authors.append(author)
    return authors
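
# Round-trip sketch for the two author helpers above:
#
#   authors_list_to_str(["Ashish Vaswani", "Noam Shazeer"])
#   # -> "Ashish Vaswani, Noam Shazeer"
#   authors_str_to_list("Ashish Vaswani and Noam Shazeer and et al.")
#   # -> ["Ashish Vaswani", "Noam Shazeer"]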

def chunk_texts(text, max_char=400):
    """
    Chunk a long text into chunks of at most max_char characters,
    making sure no word is cut in half.
    Args:
        text: The long text to be chunked.
        max_char: The maximum number of characters per chunk (default: 400).
    Returns:
        A list of chunks.
    """
    chunks = []
    current_chunk = ""
    for word in text.split():
        if len(current_chunk) + len(word) + 1 > max_char:
            # flush the full chunk and start the next one with the current
            # word, which the original version silently dropped
            chunks.append(current_chunk.strip())
            current_chunk = word
        else:
            current_chunk += " " + word
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
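
# Quick sanity sketch:
#
#   chunk_texts("one two three", max_char=8)
#   # -> ["one two", "three"]   (words are never split in half)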

def trimming(txt):
    """Trim txt down to its outermost {...} block and flatten newlines."""
    start = txt.find("{")
    end = txt.rfind("}")
    return txt[start:end+1].replace("\n", " ")
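
# Sketch: handy for pulling a JSON object out of a noisy LLM reply.
#
#   trimming('Sure! Here you go:\n{"topic": "ML"}\nHope that helps.')
#   # -> '{"topic": "ML"}'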

# crawl data

def extract_tag(txt, tagname):
    """Return the text between the first <tagname> and </tagname> pair."""
    return txt[txt.find("<"+tagname+">")+len(tagname)+2:txt.find("</"+tagname+">")]

def get_record(extract):
    id = extract_tag(extract, "id")
    updated = extract_tag(extract, "updated")
    published = extract_tag(extract, "published")
    title = extract_tag(extract, "title").replace("\n ", "").strip()
    summary = extract_tag(extract, "summary").replace("\n", "").strip()
    authors = []
    while extract.find("<author>") != -1:
        author = extract_tag(extract, "name")
        extract = extract[extract.find("</author>")+9:]
        authors.append(author)
    pattern = '<link title="pdf" href="'
    link_start = extract.find(pattern)
    link = extract[link_start+len(pattern):extract.find("rel=", link_start)-2]
    return [id, updated, published, title, authors, link, summary]
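
# Shape of a parsed record (values are illustrative, not real API output):
#
#   ["http://arxiv.org/abs/1706.03762v7",    # id
#    "2023-08-02T00:41:18Z",                 # updated
#    "2017-06-12T17:57:34Z",                 # published
#    "Attention Is All You Need",            # title
#    ["Ashish Vaswani", "..."],              # authors
#    "http://arxiv.org/pdf/1706.03762v7",    # pdf link
#    "The dominant sequence transduction models..."]  # summary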

def crawl_exact_paper(title, author, max_results=3):
    """Query the arXiv API for a specific paper by title and author list."""
    authors = authors_list_to_str(author)
    records = []
    url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(
        title=title, author=authors, max_results=max_results)
    url = url.replace(" ", "%20")
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])  # classify by summary
            records.append([topic, *extract])
        return records
    except Exception as e:
        # callers must check for a str return value, which signals an error
        return "Error: " + str(e)

def crawl_arxiv(keyword_list, max_results=100):
    """Query the arXiv API for papers matching any keyword in keyword_list."""
    baseurl = 'http://export.arxiv.org/api/query?search_query='
    records = []
    url = baseurl  # avoid an unbound url if keyword_list is empty
    for i, keyword in enumerate(keyword_list):
        if i == 0:
            url = baseurl + 'all:' + keyword
        else:
            url = url + '+OR+' + 'all:' + keyword
    url = url + '&max_results=' + str(max_results)
    url = url.replace(' ', '%20')
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])  # classify by summary
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)

# This class acts as a module: all state is class-level and shared
class ArxivChroma:
    """
    Interface to arxivdb, which only supports querying and insertion.
    Updates and deletions are not supported.
    """
    client = None
    model = None
    collection = None
    
    @staticmethod
    def connect(table="arxiv_records", name="arxivdb/"):
        ArxivChroma.client = chromadb.PersistentClient(name)
        ArxivChroma.model = embedding_model
        ArxivChroma.collection = ArxivChroma.client.get_or_create_collection(
            table,
            embedding_function=JinaAIEmbeddingFunction(model=ArxivChroma.model)
        )

    @staticmethod
    def query_relevant(keywords, query_texts, n_results=3):
        """
        Perform a query using a list of keywords (str),
        or using a relavant string
        """
        contains = []
        for keyword in keywords:
            contains.append({"$contains":keyword.lower()})
        return ArxivChroma.collection.query(
            query_texts=query_texts,
            where_document={
                "$or":contains
            },
            n_results=n_results,
        )

    @staticmethod
    def query_exact(id):
        # each paper is stored as up to 10 chunks, ids "<paper_id>_0".."<paper_id>_9"
        ids = ["{}_{}".format(id, j) for j in range(0, 10)]
        return ArxivChroma.collection.get(ids=ids)

    @staticmethod
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return the final length of the database table.
        """
        for record in crawl_records:
            embed_text = """
                Topic: {},
                Title: {},
                Summary: {}
            """.format(record[0], record[4], record[7])
            chunks = chunk_texts(embed_text)
            # record[1][21:] strips the "http://arxiv.org/abs/" prefix from the id
            paper_id = record[1][21:]
            ids = [paper_id + "_" + str(j) for j in range(len(chunks))]
            paper_ids = [{"paper_id": paper_id} for _ in range(len(chunks))]
            ArxivChroma.collection.add(
                documents=chunks,
                metadatas=paper_ids,
                ids=ids
            )
        return ArxivChroma.collection.count()

    @staticmethod
    def close_connection():
        # chromadb.PersistentClient persists automatically; nothing to release
        pass
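
# Usage sketch for ArxivChroma (defaults as defined above; records come from
# the crawlers in this module):
#
#   ArxivChroma.connect()
#   records = crawl_arxiv(["transformers"], max_results=10)
#   ArxivChroma.add(records)
#   hits = ArxivChroma.query_relevant(["transformer"], ["efficient attention"])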

# This class acts as a module: all state is class-level and shared
class ArxivSQL:
    table = "arxivsql"
    con = None
    cur = None

    @staticmethod
    def connect(name="db.sqlite3"):
        ArxivSQL.con = sqlite3.connect(name, check_same_thread=False)
        ArxivSQL.cur = ArxivSQL.con.cursor()

    @staticmethod
    def query(title="", author=[], threshold=15):
        if len(author) > 0:
            query_author = " OR ".join([f"author LIKE '%{a}%'" for a in author])
        else:
            query_author = "1=1"
        # note: values are interpolated directly into the SQL string, so this
        # is only safe for trusted input
        query = f"select * from {ArxivSQL.table} where {query_author}"
        results = ArxivSQL.cur.execute(query).fetchall()
        if len(title) == 0:
            return results
        else:
            # rank candidate rows by edit distance to the requested title
            sim_score = {}
            for row in results:
                row_title = row[2]
                row_id = row[0]
                score = lev_sim(title, row_title)
                if score < threshold:
                    sim_score[row_id] = score
            sorted_results = sorted(sim_score.items(), key=lambda x: x[1])
            # query_id expects plain ids, not (id, score) pairs
            return ArxivSQL.query_id([row_id for row_id, _ in sorted_results])
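
    # Usage sketch (values are illustrative):
    #
    #   ArxivSQL.connect()
    #   rows = ArxivSQL.query(title="Attention Is All You Need",
    #                         author=["Vaswani"])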

    @staticmethod
    def query_id(ids=[]):
        try:
            if len(ids) == 0:
                return None
            # build "id in ('a','b',...)" from the id list
            query = "select * from {} where id in (".format(ArxivSQL.table)
            for id in ids:
                query += "'" + id + "',"
            query = query[:-1] + ")"
            result = ArxivSQL.cur.execute(query)
            return result.fetchall()
        except Exception as e:
            print(e)
            print("Error query: ", query)

    @staticmethod
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return an error log string (empty if every insert succeeded).
        """
        results = ""
        for record in crawl_records:
            try:
                query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
                    record[1][21:],
                    record[0],
                    record[4].replace('"', "'"),
                    authors_list_to_str(record[5]),
                    record[2][:10],
                    record[3][:10],
                    record[6]
                )
                ArxivSQL.cur.execute(query)
                ArxivSQL.con.commit()
            except Exception as e:
                results += str(e)
                results += "\n" + query + "\n"
        # the original returned inside a finally block, which exited the loop
        # after the first record; return once all records are processed
        return results
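
# End-to-end sketch tying the module together (requires network access and
# both backing stores; a str return from the crawler signals an error):
#
#   ArxivChroma.connect()
#   ArxivSQL.connect()
#   records = crawl_arxiv(["retrieval augmented generation"], max_results=10)
#   if not isinstance(records, str):
#       ArxivChroma.add(records)    # chunked summaries -> vector store
#       ArxivSQL.add(records)       # paper metadata -> relational store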