JustKiddo committed
Commit d1d88d6
1 Parent(s): 0097125

Update app.py

Files changed (1)
  1. app.py +82 -26
app.py CHANGED
@@ -1,33 +1,83 @@
  import streamlit as st
  import requests

- st.set_page_config(
-     page_title="IOGPT",
-     page_icon="🤖",
-     menu_items={}  # This helps hide the menu
- )
-
- # Hide Streamlit menu and footer
- hide_menu_style = """
-     <style>
-         #MainMenu {visibility: hidden;}
-         footer {visibility: hidden;}
-     </style>
- """
- st.markdown(hide_menu_style, unsafe_allow_html=True)

  class VietnameseChatbot:
      def __init__(self):
-         self.api_key = st.secrets["GROQ_API_KEY"]  # Store your API key in Huggingface Secrets
          self.api_url = "https://api.groq.com/openai/v1/chat/completions"
          self.headers = {
-             "Content-Type": "application/json",
              "Authorization": f"Bearer {self.api_key}"
          }
-
      def get_response(self, user_query):
          try:
-             # Add a system message to guide the model's response
              payload = {
                  "model": "llama-3.2-3b-preview",
                  "messages": [
@@ -38,19 +88,25 @@ class VietnameseChatbot:
                      {"role": "user", "content": user_query}
                  ]
              }
              response = requests.post(
-                 self.api_url, headers=self.headers, json=payload
              )
              if response.status_code == 200:
                  return response.json()['choices'][0]['message']['content']
              else:
                  print(f"API Error: {response.status_code}")
                  print(f"Response: {response.text}")
                  return "Đã xảy ra lỗi khi kết nối với API. Xin vui lòng thử lại."
          except Exception as e:
              print(f"Response generation error: {e}")
              return "Đã xảy ra lỗi. Xin vui lòng thử lại."

  @st.cache_resource
  def initialize_chatbot():
      return VietnameseChatbot()
@@ -61,32 +117,32 @@ def main():

      # Initialize chatbot using cached initialization
      chatbot = initialize_chatbot()
-
      # Chat history in session state
      if 'messages' not in st.session_state:
          st.session_state.messages = []
-
      # Display chat messages
      for message in st.session_state.messages:
          with st.chat_message(message["role"]):
              st.markdown(message["content"])
-
      # User input
      if prompt := st.chat_input("Hãy nói gì đó..."):
          # Add user message to chat history
          st.session_state.messages.append({"role": "user", "content": prompt})
-
          # Display user message
          with st.chat_message("user"):
              st.markdown(prompt)
-
          # Get chatbot response
          response = chatbot.get_response(prompt)
-
          # Display chatbot response
          with st.chat_message("assistant"):
              st.markdown(response)
-
          # Add assistant message to chat history
          st.session_state.messages.append({"role": "assistant", "content": response})
 
  import streamlit as st
  import requests
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
+ import faiss

+ class CompanyKnowledgeBase:
+     def __init__(self, dataset_name="JustKiddo/IODataset"):
+         # Load dataset from Hugging Face
+         try:
+             self.dataset = load_dataset(dataset_name)['train']
+
+             # Initialize semantic search
+             self.model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
+
+             # Prepare embeddings for all questions
+             self.embeddings = self.model.encode([
+                 q for entry in self.dataset
+                 for q in entry['questions']
+             ])
+
+             # Create FAISS index for efficient similarity search
+             self.index = faiss.IndexFlatL2(self.embeddings.shape[1])
+             self.index.add(self.embeddings)
+
+             # Prepare a mapping of embeddings to answers
+             self.question_to_answer = {}
+             for entry in self.dataset:
+                 for question in entry['questions']:
+                     self.question_to_answer[question] = entry['answer']
+
+         except Exception as e:
+             st.error(f"Error loading knowledge base: {e}")
+             self.dataset = None
+             self.embeddings = None
+             self.index = None
+             self.question_to_answer = {}
+
+     def find_answer(self, query, threshold=0.8):
+         if not self.dataset:
+             return None
+
+         try:
+             # Embed the query
+             query_embedding = self.model.encode([query])
+
+             # Search for similar questions
+             D, I = self.index.search(query_embedding, 1)
+
+             # If similarity is high enough, return the corresponding answer
+             if D[0][0] < threshold:
+                 # Find the matched question
+                 matched_question = list(self.question_to_answer.keys())[I[0][0]]
+                 return self.question_to_answer[matched_question]
+
+         except Exception as e:
+             st.error(f"Error in semantic search: {e}")
+
+         return None

  class VietnameseChatbot:
      def __init__(self):
+         self.api_key = st.secrets["GROQ_API_KEY"]
          self.api_url = "https://api.groq.com/openai/v1/chat/completions"
          self.headers = {
+             "Content-Type": "application/json",
              "Authorization": f"Bearer {self.api_key}"
          }
+         # Initialize company knowledge base
+         self.company_kb = CompanyKnowledgeBase()
+
      def get_response(self, user_query):
+         # First, check company knowledge base
+         company_answer = self.company_kb.find_answer(user_query)
+         if company_answer:
+             return company_answer
+
+         # If no company-specific answer, proceed with original API call
          try:
              payload = {
                  "model": "llama-3.2-3b-preview",
                  "messages": [
@@ -38,19 +88,25 @@ class VietnameseChatbot:
                      {"role": "user", "content": user_query}
                  ]
              }
+
              response = requests.post(
+                 self.api_url,
+                 headers=self.headers,
+                 json=payload
              )
+
              if response.status_code == 200:
                  return response.json()['choices'][0]['message']['content']
              else:
                  print(f"API Error: {response.status_code}")
                  print(f"Response: {response.text}")
                  return "Đã xảy ra lỗi khi kết nối với API. Xin vui lòng thử lại."  # "An error occurred while connecting to the API. Please try again."
+
          except Exception as e:
              print(f"Response generation error: {e}")
              return "Đã xảy ra lỗi. Xin vui lòng thử lại."  # "An error occurred. Please try again."

+ # Cached initialization of chatbot
  @st.cache_resource
  def initialize_chatbot():
      return VietnameseChatbot()
@@ -61,32 +117,32 @@ def main():

      # Initialize chatbot using cached initialization
      chatbot = initialize_chatbot()
+
      # Chat history in session state
      if 'messages' not in st.session_state:
          st.session_state.messages = []
+
      # Display chat messages
      for message in st.session_state.messages:
          with st.chat_message(message["role"]):
              st.markdown(message["content"])
+
      # User input
      if prompt := st.chat_input("Hãy nói gì đó..."):  # "Say something..."
          # Add user message to chat history
          st.session_state.messages.append({"role": "user", "content": prompt})
+
          # Display user message
          with st.chat_message("user"):
              st.markdown(prompt)
+
          # Get chatbot response
          response = chatbot.get_response(prompt)
+
          # Display chatbot response
          with st.chat_message("assistant"):
              st.markdown(response)
+
          # Add assistant message to chat history
          st.session_state.messages.append({"role": "assistant", "content": response})
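
The new CompanyKnowledgeBase reads entry['questions'] and entry['answer'] from each row of JustKiddo/IODataset; the dataset itself is not part of this commit. A minimal sketch of the record shape the code assumes (the field names come from the code above, the example values are hypothetical):

# Hypothetical JustKiddo/IODataset row; only the field names are implied by this commit.
sample_entry = {
    "questions": [
        "Giờ làm việc của công ty là gì?",   # "What are the company's working hours?"
        "Văn phòng mở cửa lúc mấy giờ?",     # "What time does the office open?"
    ],
    "answer": "Văn phòng làm việc 9:00-18:00, thứ Hai đến thứ Sáu.",  # "The office is open 9:00-18:00, Monday to Friday."
}

Note that find_answer() maps a FAISS hit back to an answer through list(self.question_to_answer.keys())[I[0][0]]; that lookup only lines up with the embedding order if every question string is unique across rows, since duplicate questions collapse into a single dict key and shift the indices.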
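
faiss.IndexFlatL2 returns squared L2 distances, so the check D[0][0] < threshold in find_answer is a distance cutoff (smaller means closer), not a 0-to-1 similarity score, and its scale depends on the unnormalized MiniLM embeddings. If a cosine-style threshold is preferred, one option is to normalize the embeddings and search an inner-product index; a minimal sketch under that assumption (the corpus, query, and 0.8 cutoff are illustrative, not taken from the commit):

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

# Tiny illustrative corpus (hypothetical questions, not from JustKiddo/IODataset).
questions = [
    "Giờ làm việc của công ty là gì?",   # "What are the company's working hours?"
    "Làm sao để liên hệ bộ phận IT?",    # "How do I contact the IT department?"
]
embeddings = model.encode(questions, normalize_embeddings=True)  # unit-length vectors

index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product equals cosine on unit vectors
index.add(np.asarray(embeddings, dtype="float32"))

query = model.encode(["Công ty mở cửa lúc mấy giờ?"], normalize_embeddings=True)  # "What time does the company open?"
D, I = index.search(np.asarray(query, dtype="float32"), 1)

if D[0][0] > 0.8:  # cosine-similarity cutoff: higher means more similar
    print(questions[I[0][0]])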
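
The commit also adds imports for datasets, sentence_transformers, numpy, and faiss, but shows no change to the Space's dependency list; those packages (for example faiss-cpu on a CPU-only Space) would presumably need to be declared separately for the updated app.py to start.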