import math
import re
import string
from collections import Counter
from itertools import combinations

from nltk.stem import PorterStemmer
from tqdm import tqdm


# The 25 most common words in English, plus "wikipedia":
# https://en.wikipedia.org/wiki/Most_common_words_in_English
stop_words = set(['the', 'be', 'to', 'of', 'and', 'a', 'in', 'that', 'have',
                 'i', 'it', 'for', 'not', 'on', 'with', 'he', 'as', 'you',
                 'do', 'at', 'this', 'but', 'his', 'by', 'from', 'wikipedia'])
punct = re.compile(f'[{re.escape(string.punctuation)}]')

# Share one stemmer across calls instead of instantiating it per call
stemmer = PorterStemmer()

def tokenize(text):
    # Split text on whitespace
    return text.split()

def lowercase_filter(tokens):
    # Make tokens lowercase
    return [token.lower() for token in tokens]

def punctuation_filter(tokens):
    # Strip punctuation from each token
    return [punct.sub('', token) for token in tokens]

def stopword_filter(tokens):
    # Remove stopwords
    return [token for token in tokens if token not in stop_words]

def stem_filter(tokens):
    # Stem words with the shared Porter stemmer
    return [stemmer.stem(token) for token in tokens]

def analyze(text):
    tokens = tokenize(text)
    tokens = lowercase_filter(tokens)
    tokens = punctuation_filter(tokens)
    tokens = stopword_filter(tokens)
    tokens = stem_filter(tokens)

    return [token for token in tokens if token]
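
# A quick sanity check of the full pipeline (hypothetical input; the exact
# stemmed forms below assume NLTK's PorterStemmer behaviour):
#   analyze('The "Crystal Structures" of COVID-19 Proteins!')
#   -> ['crystal', 'structur', 'covid19', 'protein']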


# Set up an inverted index, a document store and corpus term counts
def index_documents(df):
    ind = {}
    doc = {}
    for i in tqdm(range(df.shape[0])):
        doc_id = df['ID'].iloc[i]
        if doc_id in doc:
            # Skip duplicate IDs so tokens are never indexed under stale text
            continue
        doc[doc_id] = df.iloc[i]
        full_text = ' '.join([df['title'].iloc[i], df['abstract'].iloc[i]])
        for token in analyze(full_text):
            if token not in ind:
                ind[token] = set()
            ind[token].add(doc_id)
    # Corpus-wide term frequencies over all titles and abstracts
    # (note: this adds a 'title_abs' column to the caller's DataFrame)
    df['title_abs'] = df['title'] + ' ' + df['abstract']
    all_text = ' '.join(df['title_abs'])
    term_frequencies = Counter(analyze(all_text))
    return ind, doc, term_frequencies
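
# Sketch of what index_documents() returns (hypothetical values):
#   ind: inverted index, stemmed token -> set of document IDs,
#        e.g. {'viru': {'p1', 'p3'}, 'vaccin': {'p2'}, ...}
#   doc: document store, document ID -> its DataFrame row
#   term_frequencies: corpus-wide counts over all titles and abstracts,
#        e.g. Counter({'viru': 12, 'vaccin': 7, ...})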


def rank(termfreq, doc, ind, analyzed_query, documents):
    """
    Rank documents by TF-IDF. Term frequency is counted within each document
    (using the corpus-wide `termfreq` counts here would give every document
    an identical score); inverse document frequency comes from the index.
    """
    results = []
    if not documents:
        return results
    for document in documents:
        # Count query terms inside this document's own title + abstract
        doc_counts = Counter(analyze(' '.join([document['title'],
                                               document['abstract']])))
        score = 0.0
        for token in analyzed_query:
            postings = ind.get(token, set())
            if len(postings) == 0:
                continue
            tf = doc_counts.get(token, 0)
            idf = math.log10(len(doc) / len(postings))
            score += tf * idf
        results.append((document, score))
    return sorted(results, key=lambda pair: pair[1], reverse=True)
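
# Worked example of the scoring (hypothetical numbers): with len(doc) == 1000
# documents and a token whose postings set covers 10 of them,
# idf = log10(1000 / 10) = 2.0; a document containing that token three times
# contributes 3 * 2.0 = 6.0 to its score.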



def search(tf, doc, ind, query, search_type='AND', ranking=False):
    """
    Search the index; returns the documents that contain words from the
    query, ranked if requested (sets are fast, but unordered).

    Query syntax: prefix a word with '-' to exclude documents containing it;
    wrap words in double quotes to require them in every result.

    Parameters:
        - tf: the term frequencies, taken from index_documents()
        - doc: the document store, taken from index_documents()
        - ind: the inverted index, taken from index_documents()
        - query: the query string
        - search_type: ('AND', 'OR') whether all query terms must match, or just one
        - ranking: (True, False) if True, rank results by TF-IDF score
    """
    if search_type not in ('AND', 'OR'):
        return []

    analyzed_query = analyze(query)
    # Words prefixed with '-' exclude any document containing them
    minus_query = [x[1:] for x in query.split() if x.startswith('-')]
    minus_query = [q for mq in minus_query for q in analyze(mq)]
    
    # Words wrapped in double quotes must appear in every result
    # (findall's capture group already strips the quotes themselves)
    specific_query = ' '.join(re.findall('"([^"]*)"', query))
    specific_query = analyze(specific_query)
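    # Example decomposition (hypothetical query): '"spike protein" -mouse virus'
    # yields analyzed_query ['spike', 'protein', 'mous', 'viru'],
    # minus_query ['mous'] and specific_query ['spike', 'protein']
    # (the stemmed forms assume NLTK's PorterStemmer).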
    
    results = [ind.get(token, set()) for token in analyzed_query]
    minus_results = [ind.get(token, set()) for token in minus_query] 
    specific_results = [ind.get(token, set()) for token in specific_query]
    
    if len(minus_results) > 0:
        excluded = set.union(*minus_results)
        results = [r - excluded for r in results]
    results = [r for r in results if len(r) > 0]

    if len(results) > 0:
        if search_type == 'AND':
            # Deal with users who use "" to get specific results
            if len(specific_results) > 0:
                # Quoted tokens are part of analyzed_query, so the full
                # intersection already requires them
                documents = [doc[doc_id] for doc_id in set.intersection(*results)]
                if len(documents) == 0:
                    # Relax the query: try combinations of terms, longest
                    # first, but always intersect with the quoted tokens
                    combo_len_list = []
                    for x in range(len(results), 1, -1):
                        combo_len_list = []
                        all_combos = list(combinations(results, x))
                        for combo in all_combos:
                            combo_len_list.append(len(set.intersection(*combo, *specific_results)))
                        if combo_len_list and max(combo_len_list) > 0:
                            break
                    if combo_len_list and max(combo_len_list) > 0:
                        max_index = combo_len_list.index(max(combo_len_list))
                        documents = [doc[doc_id] for doc_id in
                                     set.intersection(*all_combos[max_index], *specific_results)]
            else:
                # All tokens must be in the document
                documents = [doc[doc_id] for doc_id in set.intersection(*results)]
                if len(documents) == 0:
                    # Iterate from the length of the search query backwards
                    # until some documents are returned, looking at all
                    # combinations of query terms
                    for x in range(len(results), 1, -1):
                        combo_len_list = []
                        all_combos = list(combinations(results, x))
                        for combo in all_combos:
                            combo_len_list.append(len(set.intersection(*combo)))
                        if max(combo_len_list) > 0:
                            break
                    max_index = combo_len_list.index(max(combo_len_list))
                    documents = [doc[doc_id] for doc_id in set.intersection(*all_combos[max_index])]
                    if len(documents) == 0:
                        # Last resort: match any single term
                        documents = [doc[doc_id] for doc_id in set.union(*results)]
        if search_type == 'OR':
            # Only one token has to be in the document
            documents = [doc[doc_id] for doc_id in set.union(*results)]

        if ranking:
            return rank(tf, doc, ind, analyzed_query, documents)
    else:
        documents = []
    return documents
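

if __name__ == '__main__':
    # Minimal end-to-end sketch on made-up data (assumes pandas is installed);
    # real usage would load a corpus with 'ID', 'title' and 'abstract' columns.
    import pandas as pd

    papers = pd.DataFrame({
        'ID': ['p1', 'p2', 'p3'],
        'title': ['Spike protein structure',
                  'Vaccine trial results',
                  'Bacterial protein folding'],
        'abstract': ['We report the structure of the spike protein.',
                     'Results from a randomized vaccine trial.',
                     'Folding dynamics of bacterial proteins.'],
    })
    ind, doc, tf = index_documents(papers)

    # AND search with a quoted phrase and an excluded term, ranked by TF-IDF
    for row, score in search(tf, doc, ind, '"spike protein" -bacterial structure',
                             search_type='AND', ranking=True):
        print(round(score, 3), row['title'])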