import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import nltk 
from nltk.stem.porter import PorterStemmer
import json
import numpy as np 
import random
import streamlit as st

# Tokenizer data required by nltk.word_tokenize (used by ExecuteQuery).
# NLTK >= 3.9 loads the "punkt_tab" resource instead of the pickled "punkt"
# one, so fetch both to keep tokenisation working on old and new NLTK.
nltk.download('punkt')
nltk.download('punkt_tab')
   
class NeuralNet(nn.Module):
    """Three-layer fully-connected intent classifier.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear. ``forward``
    returns raw logits (no softmax); callers apply softmax themselves.
    The attribute names ``l1``/``l2``/``l3`` must match the keys in the
    checkpoint's ``model_state`` for ``load_state_dict`` to succeed.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.l1(x))
        out = self.relu(self.l2(out))
        return self.l3(out)


def _bag_of_words(tokenized_sentence, words, stem):
    """Return a float32 vector with 1.0 where ``words[i]`` occurs among the
    stemmed tokens and 0.0 elsewhere.

    ``stem`` maps a raw token to the normalised form stored in ``words``.
    """
    # A set makes each membership test O(1) instead of rescanning the tokens.
    stemmed = {stem(token) for token in tokenized_sentence}
    bag = np.zeros(len(words), dtype=np.float32)
    for idx, w in enumerate(words):
        if w in stemmed:
            bag[idx] = 1
    return bag


def _choose_reply(intents, tag, confidence, threshold):
    """Return ``(reply, tag)`` for the predicted intent, or the fallback.

    A random entry of the matching intent's ``responses`` is returned only
    when ``confidence >= threshold``. Otherwise, or when no intent carries
    ``tag``, the sentinel pair ``("opencosmo", "opencosmo")`` is returned,
    so the caller always gets a value.
    """
    if confidence >= threshold:
        for intent in intents['intents']:
            if intent["tag"] == tag:
                return random.choice(intent["responses"]), tag
    return "opencosmo", "opencosmo"


def ExecuteQuery(query, threshold=0.96):
    """Classify ``query`` with the trained intent model and pick a reply.

    Reads ``files/intents.json`` and the ``files/intents.pth`` checkpoint
    (paths relative to the working directory) on every call.

    Args:
        query: User input; converted with ``str()`` before tokenising.
        threshold: Minimum softmax probability for the predicted intent to
            be trusted. Defaults to 0.96.

    Returns:
        ``(reply, tag, prob)``. When ``prob >= threshold`` and the predicted
        tag exists in the intents file, ``reply`` is a random response for
        that tag. Otherwise ``reply`` and ``tag`` are both ``"opencosmo"``.
        ``prob`` is the predicted class's probability as a Python float.
        A tuple is always returned (never ``None``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with open('files/intents.json', 'r', encoding='utf-8') as json_data:
        intents = json.load(json_data)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    data = torch.load("files/intents.pth", map_location=device)

    all_words = data["all_words"]
    tags = data["tags"]

    model = NeuralNet(
        data["input_size"], data["hidden_size"], data["output_size"]
    ).to(device)
    model.load_state_dict(data["model_state"])
    model.eval()

    stemmer = PorterStemmer()
    tokens = nltk.word_tokenize(str(query))
    bag = _bag_of_words(tokens, all_words, lambda w: stemmer.stem(w.lower()))
    X = torch.from_numpy(bag.reshape(1, -1)).to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(X)

    predicted = torch.argmax(output, dim=1).item()
    tag = tags[predicted]
    prob = torch.softmax(output, dim=1)[0, predicted].item()

    reply, tag = _choose_reply(intents, tag, prob, threshold)
    return reply, tag, prob




# Streamlit entry point: show a text box and, once the user has typed
# something, render the chatbot reply; tag and probability go to stdout.
query = st.text_input("Enter your query: ")
if query:
    reply = ExecuteQuery(query)
    answer, tag, prob = reply
    st.write(answer)
    print(f"Tag: {tag}")
    print(f"Prob: {prob}")