GeorgiosIoannouCoder commited on
Commit
d5ba9de
1 Parent(s): 0f86120

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -177
app.py CHANGED
@@ -3,32 +3,29 @@
3
  # Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
4
  # Author : Georgios Ioannou
5
  #
6
- # TODO: Add code for Google Gemma 7b and 7b-it.
7
  # Copyright © 2024 by Georgios Ioannou
8
  #############################################################################################################################
9
  # Import libraries.
10
-
11
- import os # Load environment variable(s).
12
- import requests # Send HTTP GET request to Hugging Face models for inference.
13
- import streamlit as st # Build the GUI of the application.
14
  import streamlit.components.v1 as components
15
 
16
  from dataclasses import dataclass
17
- from dotenv import find_dotenv, load_dotenv # Read local .env file.
 
18
  from langchain.callbacks import get_openai_callback
19
  from langchain.chains import ConversationChain
20
  from langchain.llms import OpenAI
21
  from policies import complex_policy, simple_policy
22
- from transformers import pipeline # Access to Hugging Face models.
23
  from typing import Literal
24
 
25
-
26
  #############################################################################################################################
27
  # Load environment variable(s).
28
 
29
- # HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
30
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
31
-
32
 
33
  #############################################################################################################################
34
 
@@ -42,7 +39,46 @@ class Message:
42
 
43
 
44
  #############################################################################################################################
45
- # Remeber (Keep history) user's message in each rendering of the application.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  def initialize_session_state():
@@ -52,14 +88,21 @@ def initialize_session_state():
52
  st.session_state.token_count = 0
53
  if "policy" not in st.session_state:
54
  st.session_state.policy = "No Policy"
55
- if "is_gemma_model" not in st.session_state:
56
- st.session_state.is_gemma_model = False
57
  if "selected_ai_icon" not in st.session_state:
58
  st.session_state.selected_ai_icon = "a1.png"
59
  if "selected_user_icon" not in st.session_state:
60
  st.session_state.selected_user_icon = "man.png"
61
 
62
 
 
 
 
 
 
 
 
 
 
63
  #############################################################################################################################
64
 
65
 
@@ -68,9 +111,7 @@ def on_click_callback():
68
  human_prompt = st.session_state.human_prompt
69
 
70
  # Policy + User Query Text.
71
-
72
- role = "User" # This can be an Agent.
73
-
74
  task = f"""
75
  Task: Check if there is unsafe content in
76
  '{role}' messages in conversations
@@ -84,7 +125,6 @@ def on_click_callback():
84
  """
85
 
86
  query = human_prompt
87
-
88
  conversation = f"""
89
  <BEGIN CONVERSATION>
90
  User: {query}
@@ -92,55 +132,55 @@ def on_click_callback():
92
  """
93
 
94
  if st.session_state.policy == "Simple Policy":
95
- prompt = f"""
96
- {task}
97
- {simple_policy}
98
- {conversation}
99
- {output_format}
100
- """
101
  elif st.session_state.policy == "Complex Policy":
102
- prompt = f"""
103
- {task}
104
- {complex_policy}
105
- {conversation}
106
- {output_format}
107
- """
108
- elif st.session_state.policy == "No Policy":
109
  prompt = human_prompt
110
 
111
- # Getting the llm response for safety check 1.
112
- # "https://api-inference.huggingface.co/models/meta-llama/Llama-Guard-3-8B"
113
- if st.session_state.is_gemma_model:
114
- pass
115
- else:
116
  llm_response_safety_check_1 = st.session_state.conversation.run(prompt)
117
- st.session_state.history.append(Message("human", human_prompt))
 
 
 
 
 
118
  st.session_state.token_count += cb.total_tokens
119
 
120
- # Checking if response is safe. Safety Check 1. Checking what goes in (user input).
121
- if (
122
- "unsafe" in llm_response_safety_check_1.lower()
123
- ): # If respone is unsafe return unsafe.
124
  st.session_state.history.append(Message("ai", llm_response_safety_check_1))
125
  return
126
- else: # If respone is safe answer the question.
127
- if st.session_state.is_gemma_model:
128
- pass
129
- else:
130
- conversation_chain = ConversationChain(
131
- llm=OpenAI(
132
- temperature=0.2,
133
- openai_api_key=OPENAI_API_KEY,
134
- model_name=st.session_state.model,
135
- ),
 
136
  )
137
- llm_response = conversation_chain.run(human_prompt)
138
- # st.session_state.history.append(Message("ai", llm_response))
139
- st.session_state.token_count += cb.total_tokens
 
 
 
 
 
 
140
 
141
- # Policy + LLM Response.
142
  query = llm_response
143
-
144
  conversation = f"""
145
  <BEGIN CONVERSATION>
146
  User: {query}
@@ -148,34 +188,26 @@ def on_click_callback():
148
  """
149
 
150
  if st.session_state.policy == "Simple Policy":
151
- prompt = f"""
152
- {task}
153
- {simple_policy}
154
- {conversation}
155
- {output_format}
156
- """
157
  elif st.session_state.policy == "Complex Policy":
158
- prompt = f"""
159
- {task}
160
- {complex_policy}
161
- {conversation}
162
- {output_format}
163
- """
164
- elif st.session_state.policy == "No Policy":
165
  prompt = llm_response
166
 
167
- # Getting the llm response for safety check 2.
168
- # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
169
- if st.session_state.is_gemma_model:
170
- pass
171
- else:
172
  llm_response_safety_check_2 = st.session_state.conversation.run(prompt)
173
  st.session_state.token_count += cb.total_tokens
 
 
 
 
 
 
174
 
175
- # Checking if response is safe. Safety Check 2. Checking what goes out (llm output).
176
- if (
177
- "unsafe" in llm_response_safety_check_2.lower()
178
- ): # If respone is unsafe return.
179
  st.session_state.history.append(
180
  Message(
181
  "ai",
@@ -187,22 +219,9 @@ def on_click_callback():
187
 
188
 
189
  #############################################################################################################################
190
- # Function to apply local CSS.
191
-
192
-
193
- def local_css(file_name):
194
- with open(file_name) as f:
195
- st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
196
-
197
-
198
- #############################################################################################################################
199
-
200
-
201
- # Main function to create the Streamlit web application.
202
 
203
 
204
  def main():
205
- # try:
206
  initialize_session_state()
207
 
208
  # Page title and favicon.
@@ -217,14 +236,14 @@ def main():
217
  st.markdown(title, unsafe_allow_html=True)
218
 
219
  # Subtitle 1.
220
- title = f"""<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
221
  Showcase the importance of Responsible AI in LLMs Using Policies</h3>"""
222
- st.markdown(title, unsafe_allow_html=True)
223
 
224
  # Subtitle 2.
225
- title = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
226
  CUNY Tech Prep Tutorial 6</h2>"""
227
- st.markdown(title, unsafe_allow_html=True)
228
 
229
  # Image.
230
  image = "./static/ctp.png"
@@ -234,30 +253,21 @@ def main():
234
 
235
  # Sidebar dropdown menu for Models.
236
  models = [
237
- "gpt-4-turbo",
238
- "gpt-4",
239
  "gpt-3.5-turbo",
240
  "gpt-3.5-turbo-instruct",
241
- "gemma-7b",
242
- "gemma-7b-it",
 
 
243
  ]
244
  selected_model = st.sidebar.selectbox("Select Model:", models)
245
- st.sidebar.markdown( f"<span style='color: white;'>Current Model: {selected_model}</span>", unsafe_allow_html=True )
246
-
247
- if selected_model == "gpt-4-turbo":
248
- st.session_state.model = "gpt-4-turbo"
249
- elif selected_model == "gpt-4":
250
- st.session_state.model = "gpt-4"
251
- elif selected_model == "gpt-3.5-turbo":
252
- st.session_state.model = "gpt-3.5-turbo"
253
- elif selected_model == "gpt-3.5-turbo-instruct":
254
- st.session_state.model = "gpt-3.5-turbo-instruct"
255
- elif selected_model == "gemma-7b":
256
- st.session_state.model = "gemma-7b"
257
- elif selected_model == "gemma-7b-it":
258
- st.session_state.model = "gemma-7b-it"
259
-
260
- if "gpt" in st.session_state.model:
261
  st.session_state.conversation = ConversationChain(
262
  llm=OpenAI(
263
  temperature=0.2,
@@ -265,27 +275,24 @@ def main():
265
  model_name=st.session_state.model,
266
  ),
267
  )
268
- elif "gemma" in st.session_state.model:
269
- # Load model from Hugging Face.
270
- st.session_state.is_gemma_model = True
271
- pass
272
 
273
  # Sidebar dropdown menu for Policies.
274
  policies = ["No Policy", "Complex Policy", "Simple Policy"]
275
  selected_policy = st.sidebar.selectbox("Select Policy:", policies)
276
- st.sidebar.markdown( f"<span style='color: white;'>Current Policy: {selected_policy}</span>", unsafe_allow_html=True )
 
 
 
277
 
278
- if selected_policy == "No Policy":
279
- st.session_state.policy = "No Policy"
280
- elif selected_policy == "Complex Policy":
281
- st.session_state.policy = "Complex Policy"
282
- elif selected_policy == "Simple Policy":
283
- st.session_state.policy = "Simple Policy"
284
 
285
  # Sidebar dropdown menu for AI Icons.
286
  ai_icons = ["AI 1", "AI 2"]
287
  selected_ai_icon = st.sidebar.selectbox("AI Icon:", ai_icons)
288
- st.sidebar.markdown( f"<span style='color: white;'>Current AI Icon: {selected_ai_icon}</span>", unsafe_allow_html=True )
 
 
 
289
 
290
  if selected_ai_icon == "AI 1":
291
  st.session_state.selected_ai_icon = "ai1.png"
@@ -295,33 +302,35 @@ def main():
295
  # Sidebar dropdown menu for User Icons.
296
  user_icons = ["Man", "Woman"]
297
  selected_user_icon = st.sidebar.selectbox("User Icon:", user_icons)
298
- st.sidebar.markdown( f"<span style='color: white;'>Current User Icon: {selected_user_icon}</span>", unsafe_allow_html=True )
 
 
 
299
 
300
  if selected_user_icon == "Man":
301
  st.session_state.selected_user_icon = "man.png"
302
  elif selected_user_icon == "Woman":
303
  st.session_state.selected_user_icon = "woman.png"
304
 
305
- # Placeholder for the chat messages.
306
  chat_placeholder = st.container()
307
- # Placeholder for the user input.
308
  prompt_placeholder = st.form("chat-form")
309
  token_placeholder = st.empty()
310
 
311
  with chat_placeholder:
312
  for chat in st.session_state.history:
313
  div = f"""
314
- <div class="chat-row
315
- {'' if chat.origin == 'ai' else 'row-reverse'}">
316
- <img class="chat-icon" src="app/static/{
317
- st.session_state.selected_ai_icon if chat.origin == 'ai'
318
- else st.session_state.selected_user_icon}"
319
- width=32 height=32>
320
- <div class="chat-bubble
321
- {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
322
- &#8203;{chat.message}
323
- </div>
324
- </div>
325
  """
326
  st.markdown(div, unsafe_allow_html=True)
327
 
@@ -332,67 +341,54 @@ def main():
332
  with prompt_placeholder:
333
  st.markdown("**Chat**")
334
  cols = st.columns((6, 1))
335
-
336
- # Large text input in the left column.
337
  cols[0].text_input(
338
  "Chat",
339
  placeholder="What is your question?",
340
  label_visibility="collapsed",
341
  key="human_prompt",
342
  )
343
- # Red button in the right column.
344
  cols[1].form_submit_button(
345
  "Submit",
346
  type="primary",
347
  on_click=on_click_callback,
348
  )
349
 
350
- token_placeholder.caption(
351
- f"""
352
- Used {st.session_state.token_count} tokens \n
353
- """
354
- )
355
-
356
- # GitHub repository of author.
357
 
 
358
  st.markdown(
359
  f"""
360
- <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
361
- <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
362
- </p>
363
- """,
364
  unsafe_allow_html=True,
365
  )
366
 
367
- # Use the Enter key in the keyborad to click on the Submit button.
368
  components.html(
369
  """
370
- <script>
371
- const streamlitDoc = window.parent.document;
372
-
373
- const buttons = Array.from(
374
- streamlitDoc.querySelectorAll('.stButton > button')
375
- );
376
- const submitButton = buttons.find(
377
- el => el.innerText === 'Submit'
378
- );
379
-
380
- streamlitDoc.addEventListener('keydown', function(e) {
381
- switch (e.key) {
382
- case 'Enter':
383
- submitButton.click();
384
- break;
385
- }
386
- });
387
- </script>
388
- """,
389
  height=0,
390
  width=0,
391
  )
392
 
393
 
394
- #############################################################################################################################
395
-
396
-
397
  if __name__ == "__main__":
398
  main()
 
3
  # Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
4
  # Author : Georgios Ioannou
5
  #
 
6
  # Copyright © 2024 by Georgios Ioannou
7
  #############################################################################################################################
8
  # Import libraries.
9
+ import os
10
+ import requests
11
+ import streamlit as st
 
12
  import streamlit.components.v1 as components
13
 
14
  from dataclasses import dataclass
15
+ from dotenv import find_dotenv, load_dotenv
16
+ from huggingface_hub import InferenceClient
17
  from langchain.callbacks import get_openai_callback
18
  from langchain.chains import ConversationChain
19
  from langchain.llms import OpenAI
20
  from policies import complex_policy, simple_policy
 
21
  from typing import Literal
22
 
 
23
  #############################################################################################################################
24
  # Load environment variable(s).
25
 
26
+ load_dotenv(find_dotenv())
27
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
28
+ HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
29
 
30
  #############################################################################################################################
31
 
 
39
 
40
 
41
  #############################################################################################################################
42
+ # Initialize Hugging Face clients.
43
+
44
+
45
+ def initialize_hf_clients():
46
+ client = InferenceClient(api_key=HUGGINGFACE_API_KEY)
47
+ gpt2_api_url = "https://api-inference.huggingface.co/models/openai-community/gpt2"
48
+ headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}
49
+
50
+ return client, gpt2_api_url, headers
51
+
52
+
53
+ #############################################################################################################################
54
+ # Hugging Face model inference functions.
55
+
56
+
57
+ def qwen_inference(prompt):
58
+ client, _, _ = initialize_hf_clients()
59
+ messages = [{"role": "user", "content": prompt}]
60
+
61
+ try:
62
+ response = client.chat.completions.create(
63
+ model="Qwen/Qwen2.5-1.5B-Instruct", messages=messages, max_tokens=500
64
+ )
65
+ return response.choices[0].message.content
66
+ except Exception as e:
67
+ return f"Error with Qwen inference: {str(e)}"
68
+
69
+
70
+ def gpt2_inference(prompt):
71
+ _, api_url, headers = initialize_hf_clients()
72
+
73
+ try:
74
+ response = requests.post(api_url, headers=headers, json={"inputs": prompt})
75
+ return response.json()[0]["generated_text"]
76
+ except Exception as e:
77
+ return f"Error with GPT-2 inference: {str(e)}"
78
+
79
+
80
+ #############################################################################################################################
81
+ # Initialize session state.
82
 
83
 
84
  def initialize_session_state():
 
88
  st.session_state.token_count = 0
89
  if "policy" not in st.session_state:
90
  st.session_state.policy = "No Policy"
 
 
91
  if "selected_ai_icon" not in st.session_state:
92
  st.session_state.selected_ai_icon = "a1.png"
93
  if "selected_user_icon" not in st.session_state:
94
  st.session_state.selected_user_icon = "man.png"
95
 
96
 
97
+ #############################################################################################################################
98
+ # Function to apply local CSS.
99
+
100
+
101
+ def local_css(file_name):
102
+ with open(file_name) as f:
103
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
104
+
105
+
106
  #############################################################################################################################
107
 
108
 
 
111
  human_prompt = st.session_state.human_prompt
112
 
113
  # Policy + User Query Text.
114
+ role = "User"
 
 
115
  task = f"""
116
  Task: Check if there is unsafe content in
117
  '{role}' messages in conversations
 
125
  """
126
 
127
  query = human_prompt
 
128
  conversation = f"""
129
  <BEGIN CONVERSATION>
130
  User: {query}
 
132
  """
133
 
134
  if st.session_state.policy == "Simple Policy":
135
+ prompt = f"{task}\n{simple_policy}\n{conversation}\n{output_format}"
 
 
 
 
 
136
  elif st.session_state.policy == "Complex Policy":
137
+ prompt = f"{task}\n{complex_policy}\n{conversation}\n{output_format}"
138
+ else:
 
 
 
 
 
139
  prompt = human_prompt
140
 
141
+ # Safety check 1 - Input check.
142
+ if (
143
+ "gpt" in st.session_state.model.lower()
144
+ and "gpt2" not in st.session_state.model.lower()
145
+ ):
146
  llm_response_safety_check_1 = st.session_state.conversation.run(prompt)
147
+ st.session_state.token_count += cb.total_tokens
148
+ elif "qwen" in st.session_state.model.lower():
149
+ llm_response_safety_check_1 = qwen_inference(prompt)
150
+ st.session_state.token_count += cb.total_tokens
151
+ else: # gpt2.
152
+ llm_response_safety_check_1 = gpt2_inference(prompt)
153
  st.session_state.token_count += cb.total_tokens
154
 
155
+ st.session_state.history.append(Message("human", human_prompt))
156
+
157
+ if "unsafe" in llm_response_safety_check_1.lower():
 
158
  st.session_state.history.append(Message("ai", llm_response_safety_check_1))
159
  return
160
+
161
+ # Get model response.
162
+ if (
163
+ "gpt" in st.session_state.model.lower()
164
+ and "gpt2" not in st.session_state.model.lower()
165
+ ):
166
+ conversation_chain = ConversationChain(
167
+ llm=OpenAI(
168
+ temperature=0.2,
169
+ openai_api_key=OPENAI_API_KEY,
170
+ model_name=st.session_state.model,
171
  )
172
+ )
173
+ llm_response = conversation_chain.run(human_prompt)
174
+ st.session_state.token_count += cb.total_tokens
175
+ elif "qwen" in st.session_state.model.lower():
176
+ llm_response = qwen_inference(human_prompt)
177
+ st.session_state.token_count += cb.total_tokens
178
+ else: # gpt2.
179
+ llm_response = gpt2_inference(human_prompt)
180
+ st.session_state.token_count += cb.total_tokens
181
 
182
+ # Safety check 2 - Output check.
183
  query = llm_response
 
184
  conversation = f"""
185
  <BEGIN CONVERSATION>
186
  User: {query}
 
188
  """
189
 
190
  if st.session_state.policy == "Simple Policy":
191
+ prompt = f"{task}\n{simple_policy}\n{conversation}\n{output_format}"
 
 
 
 
 
192
  elif st.session_state.policy == "Complex Policy":
193
+ prompt = f"{task}\n{complex_policy}\n{conversation}\n{output_format}"
194
+ else:
 
 
 
 
 
195
  prompt = llm_response
196
 
197
+ if (
198
+ "gpt" in st.session_state.model.lower()
199
+ and "gpt2" not in st.session_state.model.lower()
200
+ ):
 
201
  llm_response_safety_check_2 = st.session_state.conversation.run(prompt)
202
  st.session_state.token_count += cb.total_tokens
203
+ elif "qwen" in st.session_state.model.lower():
204
+ llm_response_safety_check_2 = qwen_inference(prompt)
205
+ st.session_state.token_count += cb.total_tokens
206
+ else: # gpt2.
207
+ llm_response_safety_check_2 = gpt2_inference(prompt)
208
+ st.session_state.token_count += cb.total_tokens
209
 
210
+ if "unsafe" in llm_response_safety_check_2.lower():
 
 
 
211
  st.session_state.history.append(
212
  Message(
213
  "ai",
 
219
 
220
 
221
  #############################################################################################################################
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
 
224
  def main():
 
225
  initialize_session_state()
226
 
227
  # Page title and favicon.
 
236
  st.markdown(title, unsafe_allow_html=True)
237
 
238
  # Subtitle 1.
239
+ subtitle1 = f"""<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
240
  Showcase the importance of Responsible AI in LLMs Using Policies</h3>"""
241
+ st.markdown(subtitle1, unsafe_allow_html=True)
242
 
243
  # Subtitle 2.
244
+ subtitle2 = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
245
  CUNY Tech Prep Tutorial 6</h2>"""
246
+ st.markdown(subtitle2, unsafe_allow_html=True)
247
 
248
  # Image.
249
  image = "./static/ctp.png"
 
253
 
254
  # Sidebar dropdown menu for Models.
255
  models = [
 
 
256
  "gpt-3.5-turbo",
257
  "gpt-3.5-turbo-instruct",
258
+ "gpt-4-turbo",
259
+ "gpt-4",
260
+ "Qwen2.5-1.5B-Instruct",
261
+ "gpt2",
262
  ]
263
  selected_model = st.sidebar.selectbox("Select Model:", models)
264
+ st.sidebar.markdown(
265
+ f"<span style='color: white;'>Current Model: {selected_model}</span>",
266
+ unsafe_allow_html=True,
267
+ )
268
+
269
+ st.session_state.model = selected_model
270
+ if "gpt" in selected_model.lower() and "gpt2" not in selected_model.lower():
 
 
 
 
 
 
 
 
 
271
  st.session_state.conversation = ConversationChain(
272
  llm=OpenAI(
273
  temperature=0.2,
 
275
  model_name=st.session_state.model,
276
  ),
277
  )
 
 
 
 
278
 
279
  # Sidebar dropdown menu for Policies.
280
  policies = ["No Policy", "Complex Policy", "Simple Policy"]
281
  selected_policy = st.sidebar.selectbox("Select Policy:", policies)
282
+ st.sidebar.markdown(
283
+ f"<span style='color: white;'>Current Policy: {selected_policy}</span>",
284
+ unsafe_allow_html=True,
285
+ )
286
 
287
+ st.session_state.policy = selected_policy
 
 
 
 
 
288
 
289
  # Sidebar dropdown menu for AI Icons.
290
  ai_icons = ["AI 1", "AI 2"]
291
  selected_ai_icon = st.sidebar.selectbox("AI Icon:", ai_icons)
292
+ st.sidebar.markdown(
293
+ f"<span style='color: white;'>Current AI Icon: {selected_ai_icon}</span>",
294
+ unsafe_allow_html=True,
295
+ )
296
 
297
  if selected_ai_icon == "AI 1":
298
  st.session_state.selected_ai_icon = "ai1.png"
 
302
  # Sidebar dropdown menu for User Icons.
303
  user_icons = ["Man", "Woman"]
304
  selected_user_icon = st.sidebar.selectbox("User Icon:", user_icons)
305
+ st.sidebar.markdown(
306
+ f"<span style='color: white;'>Current User Icon: {selected_user_icon}</span>",
307
+ unsafe_allow_html=True,
308
+ )
309
 
310
  if selected_user_icon == "Man":
311
  st.session_state.selected_user_icon = "man.png"
312
  elif selected_user_icon == "Woman":
313
  st.session_state.selected_user_icon = "woman.png"
314
 
315
+ # Chat interface.
316
  chat_placeholder = st.container()
 
317
  prompt_placeholder = st.form("chat-form")
318
  token_placeholder = st.empty()
319
 
320
  with chat_placeholder:
321
  for chat in st.session_state.history:
322
  div = f"""
323
+ <div class="chat-row
324
+ {'' if chat.origin == 'ai' else 'row-reverse'}">
325
+ <img class="chat-icon" src="app/static/{
326
+ st.session_state.selected_ai_icon if chat.origin == 'ai'
327
+ else st.session_state.selected_user_icon}"
328
+ width=32 height=32>
329
+ <div class="chat-bubble
330
+ {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
331
+ &#8203;{chat.message}
332
+ </div>
333
+ </div>
334
  """
335
  st.markdown(div, unsafe_allow_html=True)
336
 
 
341
  with prompt_placeholder:
342
  st.markdown("**Chat**")
343
  cols = st.columns((6, 1))
 
 
344
  cols[0].text_input(
345
  "Chat",
346
  placeholder="What is your question?",
347
  label_visibility="collapsed",
348
  key="human_prompt",
349
  )
 
350
  cols[1].form_submit_button(
351
  "Submit",
352
  type="primary",
353
  on_click=on_click_callback,
354
  )
355
 
356
+ token_placeholder.caption(f"Used {st.session_state.token_count} tokens\n")
 
 
 
 
 
 
357
 
358
+ # GitHub repository link.
359
  st.markdown(
360
  f"""
361
+ <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
362
+ <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
363
+ </p>
364
+ """,
365
  unsafe_allow_html=True,
366
  )
367
 
368
+ # Enter key handler.
369
  components.html(
370
  """
371
+ <script>
372
+ const streamlitDoc = window.parent.document;
373
+ const buttons = Array.from(
374
+ streamlitDoc.querySelectorAll('.stButton > button')
375
+ );
376
+ const submitButton = buttons.find(
377
+ el => el.innerText === 'Submit'
378
+ );
379
+ streamlitDoc.addEventListener('keydown', function(e) {
380
+ switch (e.key) {
381
+ case 'Enter':
382
+ submitButton.click();
383
+ break;
384
+ }
385
+ });
386
+ </script>
387
+ """,
 
 
388
  height=0,
389
  width=0,
390
  )
391
 
392
 
 
 
 
393
  if __name__ == "__main__":
394
  main()