Spaces:
Runtime error
Runtime error
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
import torch | |
import json | |
# Title and description shown at the top of the Gradio UI.
title = "AI ChatBot"
description = "A State-of-the-Art Large-scale Pretrained Response generation model (DialoGPT)"
# Example prompts offered in the UI (one inner list per example).
examples = [["How are you?"]]
# DialoGPT checkpoint; the tokenizer and the model must come from the same one.
MODEL_NAME = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Load the course catalogue consulted by predict(). Judging by how predict()
# reads it, the file presumably looks like
#     {"courses": {"<field name>": ["<course>", ...], ...}}
# NOTE(review): confirm against the actual uts_courses.json.
# An explicit encoding avoids depending on the platform's default codec.
with open("uts_courses.json", "r", encoding="utf-8") as f:
    courses_data = json.load(f)
def predict(input_text, history=None):
    """Answer one chat message, from the course catalogue or from DialoGPT.

    If the message contains "courses" and names a field listed under
    ``courses_data["courses"]`` (case-insensitive), the reply lists that
    field's courses and ``history`` is returned unchanged. Otherwise the
    message is appended to the token history and DialoGPT generates a reply.

    Args:
        input_text: The user's message.
        history: Token ids of the conversation so far, exactly as returned by
            a previous call (a nested list of ints, ``[[id, ...]]``), or
            ``None`` / ``[]`` to start a new conversation.

    Returns:
        A ``(response, history)`` tuple: the reply text and the token-id
        history to pass back in on the next call.
    """
    # A fresh list per call: a `[]` default would be shared between calls, and
    # Gradio's "state" input passes None on a session's first turn.
    if history is None:
        history = []

    lowered = input_text.lower()
    if "courses" in lowered:
        for field, courses_list in courses_data["courses"].items():
            if field.lower() in lowered:
                response = f"The available courses in {field} are: {', '.join(courses_list)}."
                return response, history

    # Not a course question: let the dialogue model answer.
    # Shape (1, n): one batch row holding the message followed by EOS.
    new_user_input_ids = tokenizer.encode(
        input_text + tokenizer.eos_token, return_tensors="pt"
    ).to(device)

    if history:
        # dtype=torch.long keeps the ids integral; the embedding layer
        # rejects float indices.
        history_ids = torch.tensor(history, dtype=torch.long, device=device)
        bot_input_ids = torch.cat([history_ids, new_user_input_ids], dim=-1)
    else:
        # torch.tensor([]) would be float32 and promote the concatenation
        # to float, so a new conversation starts from the message alone.
        bot_input_ids = new_user_input_ids

    # NOTE(review): max_length bounds prompt + reply together, so once the
    # history approaches 1000 tokens there is no room left to generate;
    # consider trimming old turns.
    output_ids = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    )
    history = output_ids.tolist()

    # generate() returns the prompt followed by the continuation; decode only
    # the newly generated tokens so the reply does not echo the conversation.
    reply_ids = output_ids[0, bot_input_ids.shape[-1]:]
    decoded = tokenizer.decode(reply_ids, skip_special_tokens=True)
    # Collapse runs of whitespace, as the original output did.
    response = " ".join(decoded.split())
    return response, history
def main():
    """Print the contents of ``uts_courses.json`` to stdout.

    The file is read again here, into a local variable independent of the
    module-level ``courses_data``, and echoed followed by a blank line.
    """
    # An explicit encoding avoids depending on the platform's default codec.
    with open("uts_courses.json", "r", encoding="utf-8") as f:
        courses_data = json.load(f)
    print("Contents of uts_courses.json:")
    print(courses_data)
    print()
if __name__ == "__main__":
    main()

# Build the Gradio UI. predict() takes (message, history) and returns
# (reply, history); the "state" input/output pair carries the token history
# between turns of one session. Gradio's Interface requires a "state" output
# to be paired with a "state" input; a plain "text" input would instead hand
# predict() the history as a string.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "state"],
    outputs=["text", "state"],
    theme="finlaymacklon/boxy_violet",
)
# Launched at module level, as before, so the app starts however this file is
# run. The `demo` name also lets Gradio's reload mode locate the app.
demo.launch()