william4416 commited on
Commit
dcfc2ef
·
verified ·
1 Parent(s): 96cd07a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -7
app.py CHANGED
@@ -1,25 +1,45 @@
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import gradio as gr
3
  import torch
 
4
 
5
 
6
- title = "????AI ChatBot"
7
  description = "A State-of-the-Art Large-scale Pretrained Response generation model (DialoGPT)"
8
  examples = [["How are you?"]]
9
 
10
 
11
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
12
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
 
 
 
 
 
 
 
13
 
14
 
15
  def predict(input, history=[]):
 
 
 
 
 
 
 
 
 
 
 
 
16
  # tokenize the new input sentence
17
  new_user_input_ids = tokenizer.encode(
18
  input + tokenizer.eos_token, return_tensors="pt"
19
- )
20
 
21
  # append the new user input tokens to the chat history
22
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
23
 
24
  # generate a response
25
  history = model.generate(
@@ -27,15 +47,21 @@ def predict(input, history=[]):
27
  ).tolist()
28
 
29
  # convert the tokens to text, and then split the responses into lines
30
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
31
- # print('decoded_response-->>'+str(response))
32
  response = [
33
  (response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
34
  ] # convert to tuples of list
35
- # print('response-->>'+str(response))
36
  return response, history
37
 
38
 
 
 
 
 
 
 
 
 
39
  gr.Interface(
40
  fn=predict,
41
  title=title,
@@ -44,4 +70,4 @@ gr.Interface(
44
  inputs=["text", "state"],
45
  outputs=["chatbot", "state"],
46
  theme="finlaymacklon/boxy_violet",
47
- ).launch()
 
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import gradio as gr
3
  import torch
4
+ import json
5
 
6
 
7
+ title = "AI ChatBot"
8
  description = "A State-of-the-Art Large-scale Pretrained Response generation model (DialoGPT)"
9
  examples = [["How are you?"]]
10
 
11
 
12
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
13
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ model.to(device)
17
+
18
+ # Load courses data from JSON file
19
+ with open("uts_courses.json", "r") as f:
20
+ courses_data = json.load(f)
21
 
22
 
23
  def predict(input, history=[]):
24
+ # Check if the input question is about courses
25
+ if "courses" in input.lower():
26
+ # Check if the input question contains a specific field (e.g., Engineering, Information Technology, etc.)
27
+ for field in courses_data["courses"]:
28
+ if field.lower() in input.lower():
29
+ # Get the list of courses for the specified field
30
+ courses_list = courses_data["courses"][field]
31
+ # Format the response
32
+ response = f"The available courses in {field} are: {', '.join(courses_list)}."
33
+ return response, history
34
+
35
+ # If the input question is not about courses, use the dialogue model to generate a response
36
  # tokenize the new input sentence
37
  new_user_input_ids = tokenizer.encode(
38
  input + tokenizer.eos_token, return_tensors="pt"
39
+ ).to(device)
40
 
41
  # append the new user input tokens to the chat history
42
+ bot_input_ids = torch.cat([torch.tensor(history).to(device), new_user_input_ids], dim=-1)
43
 
44
  # generate a response
45
  history = model.generate(
 
47
  ).tolist()
48
 
49
  # convert the tokens to text, and then split the responses into lines
50
+ response = tokenizer.decode(history[0]).split("")
 
51
  response = [
52
  (response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
53
  ] # convert to tuples of list
 
54
  return response, history
55
 
56
 
57
+ def main():
58
+ pass
59
+
60
+
61
+ if __name__ == "__main__":
62
+ main()
63
+
64
+
65
  gr.Interface(
66
  fn=predict,
67
  title=title,
 
70
  inputs=["text", "state"],
71
  outputs=["chatbot", "state"],
72
  theme="finlaymacklon/boxy_violet",
73
+ ).launch()