GeorgiosIoannouCoder commited on
Commit
8a9d49c
·
verified ·
1 Parent(s): f3a1823

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +400 -0
  2. policies.py +55 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #############################################################################################################################
2
+ # Filename : app.py
3
+ # Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
4
+ # Author : Georgios Ioannou
5
+ #
6
+ # TODO: Add code for Google Gemma 7b and 7b-it.
7
+ # TODO: Write code documentation.
8
+ # Copyright © 2024 by Georgios Ioannou
9
+ #############################################################################################################################
10
+ # Import libraries.
11
+
12
+ import os # Load environment variable(s).
13
+ import requests # Send HTTP GET request to Hugging Face models for inference.
14
+ import streamlit as st # Build the GUI of the application.
15
+ import streamlit.components.v1 as components
16
+
17
+ from dataclasses import dataclass
18
+ from dotenv import find_dotenv, load_dotenv # Read local .env file.
19
+ from langchain.callbacks import get_openai_callback
20
+ from langchain.chains import ConversationChain
21
+ from langchain.llms import OpenAI
22
+ from policies import complex_policy, simple_policy
23
+ from transformers import pipeline # Access to Hugging Face models.
24
+ from typing import Literal
25
+
26
+
27
+ #############################################################################################################################
28
+ # Load environment variable(s).
29
+
30
+ load_dotenv(find_dotenv()) # Read local .env file.
31
+ HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
32
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
33
+
34
+
35
+ #############################################################################################################################
36
+
37
+
38
+ @dataclass
39
+ class Message:
40
+ """Class for keeping track of a chat message."""
41
+
42
+ origin: Literal["human", "ai"]
43
+ message: str
44
+
45
+
46
+ #############################################################################################################################
47
+ # Remeber (Keep history) user's message in each rendering of the application.
48
+
49
+
50
+ def initialize_session_state():
51
+ if "history" not in st.session_state:
52
+ st.session_state.history = []
53
+ if "token_count" not in st.session_state:
54
+ st.session_state.token_count = 0
55
+ if "policy" not in st.session_state:
56
+ st.session_state.policy = "No Policy"
57
+ if "is_gemma_model" not in st.session_state:
58
+ st.session_state.is_gemma_model = False
59
+ if "selected_ai_icon" not in st.session_state:
60
+ st.session_state.selected_ai_icon = "a1.png"
61
+ if "selected_user_icon" not in st.session_state:
62
+ st.session_state.selected_user_icon = "man.png"
63
+
64
+
65
+ #############################################################################################################################
66
+
67
+
68
+ def on_click_callback():
69
+ with get_openai_callback() as cb:
70
+ human_prompt = st.session_state.human_prompt
71
+
72
+ # Policy + User Query Text.
73
+
74
+ role = "User" # This can be an Agent.
75
+
76
+ task = f"""
77
+ Task: Check if there is unsafe content in
78
+ '{role}' messages in conversations
79
+ according our safety policy with the below categories.
80
+ """
81
+
82
+ output_format = f"""
83
+ Provide your safety assessment for '{role}' in the above conversation:
84
+ - First line must read 'safe' or 'unsafe'.
85
+ - If unsafe, a second line must include a comma-separated list of violated categories.
86
+ """
87
+
88
+ query = human_prompt
89
+
90
+ conversation = f"""
91
+ <BEGIN CONVERSATION>
92
+ User: {query}
93
+ <END CONVERSATION>
94
+ """
95
+
96
+ if st.session_state.policy == "Simple Policy":
97
+ prompt = f"""
98
+ {task}
99
+ {simple_policy}
100
+ {conversation}
101
+ {output_format}
102
+ """
103
+ elif st.session_state.policy == "Complex Policy":
104
+ prompt = f"""
105
+ {task}
106
+ {complex_policy}
107
+ {conversation}
108
+ {output_format}
109
+ """
110
+ elif st.session_state.policy == "No Policy":
111
+ prompt = human_prompt
112
+
113
+ # Getting the llm response for safety check 1.
114
+ # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
115
+ if st.session_state.is_gemma_model:
116
+ pass
117
+ else:
118
+ llm_response_safety_check_1 = st.session_state.conversation.run(prompt)
119
+ st.session_state.history.append(Message("human", human_prompt))
120
+ st.session_state.token_count += cb.total_tokens
121
+
122
+ # Checking if response is safe. Safety Check 1. Checking what goes in (user input).
123
+ if (
124
+ "unsafe" in llm_response_safety_check_1.lower()
125
+ ): # If respone is unsafe return unsafe.
126
+ st.session_state.history.append(Message("ai", llm_response_safety_check_1))
127
+ return
128
+ else: # If respone is safe answer the question.
129
+ if st.session_state.is_gemma_model:
130
+ pass
131
+ else:
132
+ conversation_chain = ConversationChain(
133
+ llm=OpenAI(
134
+ temperature=0.2,
135
+ openai_api_key=OPENAI_API_KEY,
136
+ model_name=st.session_state.model,
137
+ ),
138
+ )
139
+ llm_response = conversation_chain.run(human_prompt)
140
+ # st.session_state.history.append(Message("ai", llm_response))
141
+ st.session_state.token_count += cb.total_tokens
142
+
143
+ # Policy + LLM Response.
144
+ query = llm_response
145
+
146
+ conversation = f"""
147
+ <BEGIN CONVERSATION>
148
+ User: {query}
149
+ <END CONVERSATION>
150
+ """
151
+
152
+ if st.session_state.policy == "Simple Policy":
153
+ prompt = f"""
154
+ {task}
155
+ {simple_policy}
156
+ {conversation}
157
+ {output_format}
158
+ """
159
+ elif st.session_state.policy == "Complex Policy":
160
+ prompt = f"""
161
+ {task}
162
+ {complex_policy}
163
+ {conversation}
164
+ {output_format}
165
+ """
166
+ elif st.session_state.policy == "No Policy":
167
+ prompt = llm_response
168
+
169
+ # Getting the llm response for safety check 2.
170
+ # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
171
+ if st.session_state.is_gemma_model:
172
+ pass
173
+ else:
174
+ llm_response_safety_check_2 = st.session_state.conversation.run(prompt)
175
+ st.session_state.token_count += cb.total_tokens
176
+
177
+ # Checking if response is safe. Safety Check 2. Checking what goes out (llm output).
178
+ if (
179
+ "unsafe" in llm_response_safety_check_2.lower()
180
+ ): # If respone is unsafe return.
181
+ st.session_state.history.append(
182
+ Message(
183
+ "ai",
184
+ "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
185
+ )
186
+ )
187
+ else:
188
+ st.session_state.history.append(Message("ai", llm_response))
189
+
190
+
191
+ #############################################################################################################################
192
+ # Function to apply local CSS.
193
+
194
+
195
+ def local_css(file_name):
196
+ with open(file_name) as f:
197
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
198
+
199
+
200
+ #############################################################################################################################
201
+
202
+
203
+ # Main function to create the Streamlit web application.
204
+
205
+
206
+ def main():
207
+ # try:
208
+ initialize_session_state()
209
+
210
+ # Page title and favicon.
211
+ st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
212
+
213
+ # Load CSS.
214
+ local_css("./static/styles/styles.css")
215
+
216
+ # Title.
217
+ title = f"""<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
218
+ Responsible AI</h1>"""
219
+ st.markdown(title, unsafe_allow_html=True)
220
+
221
+ # Subtitle 1.
222
+ title = f"""<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
223
+ Showcase the importance of Responsible AI in LLMs</h3>"""
224
+ st.markdown(title, unsafe_allow_html=True)
225
+
226
+ # Subtitle 2.
227
+ title = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
228
+ CUNY Tech Prep Tutorial 6</h2>"""
229
+ st.markdown(title, unsafe_allow_html=True)
230
+
231
+ # Image.
232
+ image = "./static/ctp.png"
233
+ left_co, cent_co, last_co = st.columns(3)
234
+ with cent_co:
235
+ st.image(image=image)
236
+
237
+ # Sidebar dropdown menu for Models.
238
+ models = [
239
+ "gpt-4-turbo",
240
+ "gpt-4",
241
+ "gpt-3.5-turbo",
242
+ "gpt-3.5-turbo-instruct",
243
+ "gemma-7b",
244
+ "gemma-7b-it",
245
+ ]
246
+ selected_model = st.sidebar.selectbox("Select Model:", models)
247
+ st.sidebar.write(f"Current Model: {selected_model}")
248
+
249
+ if selected_model == "gpt-4-turbo":
250
+ st.session_state.model = "gpt-4-turbo"
251
+ elif selected_model == "gpt-4":
252
+ st.session_state.model = "gpt-4"
253
+ elif selected_model == "gpt-3.5-turbo":
254
+ st.session_state.model = "gpt-3.5-turbo"
255
+ elif selected_model == "gpt-3.5-turbo-instruct":
256
+ st.session_state.model = "gpt-3.5-turbo-instruct"
257
+ elif selected_model == "gemma-7b":
258
+ st.session_state.model = "gemma-7b"
259
+ elif selected_model == "gemma-7b-it":
260
+ st.session_state.model = "gemma-7b-it"
261
+
262
+ if "gpt" in st.session_state.model:
263
+ st.session_state.conversation = ConversationChain(
264
+ llm=OpenAI(
265
+ temperature=0.2,
266
+ openai_api_key=OPENAI_API_KEY,
267
+ model_name=st.session_state.model,
268
+ ),
269
+ )
270
+ elif "gemma" in st.session_state.model:
271
+ # Load model from Hugging Face.
272
+ st.session_state.is_gemma_model = True
273
+ pass
274
+
275
+ # Sidebar dropdown menu for Policies.
276
+ policies = ["No Policy", "Complex Policy", "Simple Policy"]
277
+ selected_policy = st.sidebar.selectbox("Select Policy:", policies)
278
+ st.sidebar.write(f"Current Policy: {selected_policy}")
279
+
280
+ if selected_policy == "No Policy":
281
+ st.session_state.policy = "No Policy"
282
+ elif selected_policy == "Complex Policy":
283
+ st.session_state.policy = "Complex Policy"
284
+ elif selected_policy == "Simple Policy":
285
+ st.session_state.policy = "Simple Policy"
286
+
287
+ # Sidebar dropdown menu for AI Icons.
288
+ ai_icons = ["AI 1", "AI 2"]
289
+ selected_ai_icon = st.sidebar.selectbox("AI Icon:", ai_icons)
290
+ st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
291
+
292
+ if selected_ai_icon == "AI 1":
293
+ st.session_state.selected_ai_icon = "ai1.png"
294
+ elif selected_ai_icon == "AI 2":
295
+ st.session_state.selected_ai_icon = "ai2.png"
296
+
297
+ # Sidebar dropdown menu for User Icons.
298
+ user_icons = ["Man", "Woman"]
299
+ selected_user_icon = st.sidebar.selectbox("User Icon:", user_icons)
300
+ st.sidebar.write(f"Current User Icon: {selected_user_icon}")
301
+
302
+ if selected_user_icon == "Man":
303
+ st.session_state.selected_user_icon = "man.png"
304
+ elif selected_user_icon == "Woman":
305
+ st.session_state.selected_user_icon = "woman.png"
306
+
307
+ # Placeholder for the chat messages.
308
+ chat_placeholder = st.container()
309
+ # Placeholder for the user input.
310
+ prompt_placeholder = st.form("chat-form")
311
+ token_placeholder = st.empty()
312
+
313
+ with chat_placeholder:
314
+ for chat in st.session_state.history:
315
+ div = f"""
316
+ <div class="chat-row
317
+ {'' if chat.origin == 'ai' else 'row-reverse'}">
318
+ <img class="chat-icon" src="app/static/{
319
+ st.session_state.selected_ai_icon if chat.origin == 'ai'
320
+ else st.session_state.selected_user_icon}"
321
+ width=32 height=32>
322
+ <div class="chat-bubble
323
+ {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
324
+ &#8203;{chat.message}
325
+ </div>
326
+ </div>
327
+ """
328
+ st.markdown(div, unsafe_allow_html=True)
329
+
330
+ for _ in range(3):
331
+ st.markdown("")
332
+
333
+ # User prompt.
334
+ with prompt_placeholder:
335
+ st.markdown("**Chat**")
336
+ cols = st.columns((6, 1))
337
+
338
+ # Large text input in the left column.
339
+ cols[0].text_input(
340
+ "Chat",
341
+ placeholder="What is your question?",
342
+ label_visibility="collapsed",
343
+ key="human_prompt",
344
+ )
345
+ # Red button in the right column.
346
+ cols[1].form_submit_button(
347
+ "Submit",
348
+ type="primary",
349
+ on_click=on_click_callback,
350
+ )
351
+
352
+ token_placeholder.caption(
353
+ f"""
354
+ Used {st.session_state.token_count} tokens \n
355
+ """
356
+ )
357
+
358
+ # GitHub repository of author.
359
+
360
+ st.markdown(
361
+ f"""
362
+ <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
363
+ <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
364
+ </p>
365
+ """,
366
+ unsafe_allow_html=True,
367
+ )
368
+
369
+ # Use the Enter key in the keyborad to click on the Submit button.
370
+ components.html(
371
+ """
372
+ <script>
373
+ const streamlitDoc = window.parent.document;
374
+
375
+ const buttons = Array.from(
376
+ streamlitDoc.querySelectorAll('.stButton > button')
377
+ );
378
+ const submitButton = buttons.find(
379
+ el => el.innerText === 'Submit'
380
+ );
381
+
382
+ streamlitDoc.addEventListener('keydown', function(e) {
383
+ switch (e.key) {
384
+ case 'Enter':
385
+ submitButton.click();
386
+ break;
387
+ }
388
+ });
389
+ </script>
390
+ """,
391
+ height=0,
392
+ width=0,
393
+ )
394
+
395
+
396
+ #############################################################################################################################
397
+
398
+
399
+ if __name__ == "__main__":
400
+ main()
policies.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ complex_policy = """
2
+ <BEGIN UNSAFE CONTENT CATEGORIES>
3
+ O1: Violence and Hate.
4
+ Should not
5
+ - Help people plan or engage in violence.
6
+ - Encourage people to engage in violence.
7
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
8
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
9
+ Can
10
+ - Provide information on violence and discrimination.
11
+ - Discuss the topics of hate and violence, and can discuss historical events involving violence.
12
+ O2: Sexual Content.
13
+ Should not
14
+ - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
15
+ Can
16
+ - Discuss topics related to sexuality and sexual education.
17
+ - Provide general information regarding sexual activities.
18
+ - Provide guidance on topics related to sexual health.
19
+ O3: Criminal Planning.
20
+ Should not
21
+ - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
22
+ - Encourage people to engage in criminal activities or condone criminal behavior.
23
+ Can
24
+ - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
25
+ - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
26
+ O4: Guns and Illegal Weapons.
27
+ Should not
28
+ - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
29
+ - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
30
+ Can
31
+ - Discuss firearms and the arguments for and against firearm ownership.
32
+ O5: Regulated or Controlled Substances.
33
+ Should not
34
+ - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
35
+ - Assist or encourage people to create such substances when it is illegal to do so.
36
+ Can
37
+ - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
38
+ - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
39
+ O6: Self-Harm.
40
+ Should not
41
+ - Encourage people to harm themselves.
42
+ - Romanticize or condone self-harm or suicide.
43
+ - Provide information regarding the methods of suicide or self-harm.
44
+ - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
45
+ Can
46
+ - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
47
+ Should
48
+ - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
49
+ <END UNSAFE CONTENT CATEGORIES>
50
+ """
51
+
52
+ simple_policy = """
53
+ Determine if this user query below the --- is safe or not. If not return 'Unsafe' with an explanation of why it is unsafe. If it is safe, return 'Safe'.
54
+ ---
55
+ """
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ langchain==0.0.316
2
+ openai==0.28
3
+ transformers