import functools

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer

def get_questions(paragraph, tokenizer, model, device,
                  bt_levels=('Remember', 'Understand', 'Apply', 'Analyse', 'Evaluate', 'Create')):
    """Generate one question per Bloom's Taxonomy level for a paragraph.

    For each level, the prompt ``"<level>: <paragraph> <eos_token>"`` is
    encoded (truncated/padded to 512 tokens) and a 4-beam search produces up
    to 128 output tokens. The decoded text has leading newlines stripped and
    its first space-delimited token removed.

    Args:
        paragraph: Source text to generate questions about.
        tokenizer: Tokenizer matching ``model`` (called with ``return_tensors='pt'``
            and used to ``decode`` the generated ids).
        model: Seq2seq model exposing ``eval()`` and ``generate(...)``.
        device: torch device the model lives on; encoded inputs are moved there.
        bt_levels: Level names to generate for, in output order. A tuple, so
            the default is immutable and safe to share across calls.

    Returns:
        dict mapping each level name to its generated question text.
    """
    # Inference mode (disables dropout). Loop-invariant, so set it once.
    model.eval()
    questions = {}
    for level in bt_levels:
        # NOTE(review): the tokenizer also appends EOS by default, so EOS appears
        # twice; kept unchanged, presumably to match the fine-tuning prompt format.
        prompt = f'{level}: {paragraph} {tokenizer.eos_token}'
        encoded = tokenizer(prompt, max_length=512, padding='max_length',
                            truncation=True, return_tensors='pt').to(device)
        # Pass attention_mask explicitly so the padding positions are masked
        # instead of relying on generate() inferring the mask from pad tokens.
        generated = model.generate(**encoded, max_length=128, num_beams=4,
                                   early_stopping=True)
        # [0] selects the single sequence; unlike squeeze() it never collapses
        # a length-1 output to a 0-d tensor.
        text = tokenizer.decode(generated[0], skip_special_tokens=True).lstrip('\n')
        # Drop the first space-delimited token (presumably a label the model was
        # trained to emit — confirm). With no space, keep the text instead of
        # raising IndexError as split(' ', 1)[1] did.
        head, sep, rest = text.partition(' ')
        questions[level] = rest if sep else head
    return questions
        

@functools.lru_cache(maxsize=1)
def _load_model(model_dir='./save_model'):
    """Load the fine-tuned T5 model and tokenizer once, move the model to the best device, and cache the result."""
    model = T5ForConditionalGeneration.from_pretrained(model_dir)
    tokenizer = T5Tokenizer.from_pretrained(model_dir)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    return tokenizer, model, device


def main(paragraph):
    """Gradio handler: return a dict of Bloom's-level name -> generated question.

    The model and tokenizer are loaded from ``./save_model`` on the first call
    and reused afterwards. Previously they were reloaded from disk on every
    call, which with a live interface meant on every edit of the input box.

    Args:
        paragraph: Text entered by the user.

    Returns:
        dict as produced by ``get_questions``, or an empty dict for blank input.
    """
    # Skip six beam searches on an empty prompt (e.g. when the box is cleared).
    if not paragraph or not paragraph.strip():
        return {}
    tokenizer, model, device = _load_model()
    return get_questions(paragraph, tokenizer, model, device)

if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse get_questions)
    # does not start a web server as a side effect.
    gr.Interface(
        fn=main,
        inputs="textbox",
        # main returns a dict (level -> question); the JSON component renders
        # it readably instead of showing the dict's Python repr in a textbox.
        outputs="json",
        live=True,
    ).launch()