from datasets import load_dataset
import numpy as np
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler  # for the optional scaling step in the fold loop
from sklearn.metrics import (classification_report, confusion_matrix,
                             fbeta_score, precision_score, recall_score)
from sklearn.model_selection import KFold
import nltk
from nltk.corpus import stopwords, names, gazetteers
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from string import punctuation


nltk.download('stopwords')
nltk.download('gazetteers')
nltk.download('names')

STOPWORDS = set(stopwords.words('english'))
PUNCT = set(punctuation)

# Gazetteer lookups used as binary per-token features
places = set(gazetteers.words())
people = set(names.words())
countries = set(gazetteers.words('countries.txt'))
nationalities = set(gazetteers.words('nationalities.txt'))

pos_tags = [ 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS',
                'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD',
                'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB'
            ]
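# The Hugging Face conll2003 tag vocabulary (which this code assumes) assigns
# ids 0-9 to punctuation-like tags ('"', "''", '#', '$', '(', ')', ',', '.',
# ':', '``'), so the 37 tags above correspond to dataset ids 10-46.
# create_data() below subtracts 10 so that, e.g., NNP lands at index 12 and
# the verb tags VB..VBZ span indices 27-32.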

def feature_vector(word, scaled_position, current_word_pos_tag):
    """Alternative 12-dim featurizer (unused by create_data, which calls
    feature_vector2): orthography, gazetteer lookups, POS flags, and a
    scaled token-position value."""
    features = [
        int(word.lower() in STOPWORDS),
        int(word.isupper()),
        int(word in PUNCT),
        int(word.istitle()),
        int(word.isdigit()),
        # len(word),  # word-length feature (currently disabled)
        int(word in places),
        int(word in people),
        int(word in countries),
        int(word in nationalities),
        int(current_word_pos_tag in (12, 13)),  # NNP or NNPS (proper noun)
        scaled_position,
        int(27 <= current_word_pos_tag <= 32),  # VB..VBZ (verb)
    ]
    return np.asarray(features, dtype=np.float32)
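# Illustrative call (exact values depend on the loaded NLTK corpora): for a
# mid-sentence proper noun, feature_vector("London", 0.4, 12) should set the
# istitle and proper-noun slots (plus the places slot if "London" appears in
# the gazetteer) and store 0.4 in the position slot.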


def feature_vector2(word, prev_word_pos_tag, next_word_pos_tag, current_word_pos_tag):
    """9-dim per-token featurizer: orthography, stopword membership, length,
    and gazetteer lookups. The three POS arguments are currently unused; the
    commented block shows how they could be one-hot encoded (the sentinel -11
    means "no usable tag"), which would also require growing the vector and
    moving the gazetteer slots."""
    vec = np.zeros(9, dtype=np.float32)
    vec[0] = word.istitle()
    vec[1] = word.lower() in STOPWORDS
    vec[2] = word.isupper()
    vec[3] = len(word)
    vec[4] = word.isdigit()
    # POS ids run -11, then 0..36 (after the -10 shift in create_data):
    # if prev_word_pos_tag != -11:
    #     vec[5 + prev_word_pos_tag] = 1
    # if next_word_pos_tag != -11:
    #     vec[42 + next_word_pos_tag] = 1
    # if current_word_pos_tag != -11:
    #     vec[79 + current_word_pos_tag] = 1
    vec[5] = word in places
    vec[6] = word in people
    vec[7] = word in countries
    vec[8] = word in nationalities
    return vec
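# Illustrative call: feature_vector2("Dutch", -11, 2, 6) returns a (9,) float32
# vector with vec[0] (istitle) set, vec[3] = 5 (length), and vec[8] set if
# "Dutch" appears in the NLTK nationalities gazetteer; the POS args are ignored.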


def create_data(data):
    """Build the flat dataset: one feature vector per token, with binary
    label 1 if the token is inside any named entity, else 0."""
    x_data, y_data = [], []
    for sent in data:
        for i in range(len(sent['tokens'])):
            # Map punctuation-like POS ids (< 10) and out-of-range neighbours
            # to the sentinel -1 before shifting
            prev_pos = -1 if i == 0 or sent['pos_tags'][i-1] < 10 else sent['pos_tags'][i-1]
            next_pos = -1 if i == len(sent['tokens'])-1 or sent['pos_tags'][i+1] < 10 else sent['pos_tags'][i+1]
            current_pos = -1 if sent['pos_tags'][i] < 10 else sent['pos_tags'][i]
            # Shift by -10 into the pos_tags index space (sentinel becomes -11)
            word_vec = feature_vector2(sent['tokens'][i], prev_pos-10, next_pos-10, current_pos-10)
            x_data.append(word_vec)
            y_data.append(1 if sent['ner_tags'][i] != 0 else 0)
    return x_data, y_data
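# Sketch of the expected output: features and labels align one-to-one over
# every token in the split, each feature being a (9,) float32 vector:
#   x, y = create_data(d_train)
#   assert len(x) == len(y) and x[0].shape == (9,)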

def evaluate_overall_metrics(predictions, folds):
    """Print precision/recall/F-beta scores averaged over `folds` folds."""
    precision, recall, f0_5_score, f1_score, f2_score = 0, 0, 0, 0, 0

    for test_label_flat, y_pred_flat in predictions:
        # Accumulate weighted scores for this fold
        f0_5_score += fbeta_score(test_label_flat, y_pred_flat, beta=0.5, average='weighted')
        f1_score += fbeta_score(test_label_flat, y_pred_flat, beta=1, average='weighted')
        f2_score += fbeta_score(test_label_flat, y_pred_flat, beta=2, average='weighted')
        precision += precision_score(test_label_flat, y_pred_flat, average='weighted')
        recall += recall_score(test_label_flat, y_pred_flat, average='weighted')

    # Averaging across folds
    f0_5_score /= folds
    f1_score /= folds
    f2_score /= folds
    precision /= folds
    recall /= folds

    print('Overall Metrics:')
    print(f'Precision : {precision:.3f}')
    print(f'Recall : {recall:.3f}')
    print(f'F0.5 Score : {f0_5_score:.3f}')
    print(f'F1 Score : {f1_score:.3f}')
    print(f'F2 Score : {f2_score:.3f}\n')
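
# Note: with 'weighted' averaging the majority non-entity class dominates the
# overall scores above; the per-class metrics below are more informative for
# the entity class.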

def evaluate_per_class_metrics(predictions, labels):
    """Print one-vs-rest precision/recall/F1 for each label, pooled across folds."""
    combined_true = []
    combined_pred = []

    # Pool every fold's true and predicted labels into flat lists
    for test_label, y_pred in predictions:
        combined_true.extend(test_label)
        combined_pred.extend(y_pred)

    for tag in labels:
        # Binarize: the current tag is the positive class, everything else negative
        true_binary = [1 if t == tag else 0 for t in combined_true]
        pred_binary = [1 if p == tag else 0 for p in combined_pred]

        # Calculate metrics for the tag
        precision = precision_score(true_binary, pred_binary, average='binary', zero_division=0)
        recall = recall_score(true_binary, pred_binary, average='binary', zero_division=0)
        f1_score = fbeta_score(true_binary, pred_binary, beta=1, average='binary', zero_division=0)

        print(f"Metrics for {tag}:")
        print(f'Precision : {precision:.3f}')
        print(f'Recall : {recall:.3f}')
        print(f'F1 Score : {f1_score:.3f}\n')

def plot_confusion_matrix(predictions, labels):
    """Plot a row-normalized confusion matrix summed over all evaluated folds."""
    matrix = None
    for i, (test_label_flat, y_pred_flat) in enumerate(predictions):
        # Accumulate this fold's confusion matrix
        cm = confusion_matrix(test_label_flat, y_pred_flat, labels=labels)
        matrix = cm if i == 0 else matrix + cm

    # Row-normalize so each true-label row sums to 1. (Dividing by the fold
    # count first would cancel out under this normalization, so it is omitted.)
    matrix = matrix.astype('float')
    matrix = matrix / np.sum(matrix, axis=1, keepdims=True)

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=".2f", cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Normalized Confusion Matrix for NER')
    plt.show()

if __name__ == "__main__":
    data = load_dataset("conll2003", trust_remote_code=True)
    d_train = data['train']
    d_validation = data['validation']
    d_test = data['test']

    x_train, y_train = create_data(d_train)
    x_val, y_val = create_data(d_validation)
    x_test, y_test = create_data(d_test)
    # Pool all three splits; the K-fold loop below re-partitions the full corpus
    all_X_train = np.concatenate((x_train, x_val, x_test))
    all_y_train = np.concatenate((y_train, y_val, y_test))

    # K-fold cross-validation over the pooled corpus
    num_fold = 5
    kf = KFold(n_splits=num_fold, random_state=42, shuffle=True)
    indices = np.arange(len(all_X_train))
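    # Note: shuffle=True mixes tokens from all three original splits, so fold
    # scores are not comparable to results on the standard CoNLL-2003 test set.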

    predictions = []
    all_models = []

    for i, (train_index, test_index) in enumerate(kf.split(indices)):
        print(f"Fold {i} Train Length: {len(train_index)} Test Length: {len(test_index)}")
        X_train = all_X_train[train_index]
        y_train = all_y_train[train_index]

        X_test = all_X_train[test_index]
        y_test = all_y_train[test_index]

        # Optional: standardize features so each contributes comparably to the
        # SVM's kernel computation. Fit on the training fold only, then apply
        # the same statistics to the held-out fold:
        # scaler = StandardScaler()
        # X_train = scaler.fit_transform(X_train)
        # X_test = scaler.transform(X_test)

        model = SVC(random_state=42, verbose=True)
        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)

        print("-------" * 6)
        print(classification_report(y_true=y_test, y_pred=y_pred))
        print("-------" * 6)

        # Persist this fold's model
        with open(f"ner_svm_{i}.pkl", "wb") as f:
            pickle.dump(model, f)
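        # To reload a saved model later (sketch):
        #   with open(f"ner_svm_{i}.pkl", "rb") as f:
        #       model = pickle.load(f)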

        predictions.append((y_test, y_pred))
        all_models.append(model)
        break  # NOTE: trains only the first fold for speed; remove to run all num_fold folds


    # Average over the folds actually evaluated (the loop above breaks early)
    folds_run = len(predictions)
    labels = sorted(model.classes_)
    evaluate_overall_metrics(predictions, folds_run)
    evaluate_per_class_metrics(predictions, labels)
    plot_confusion_matrix(predictions, labels)