Spaces:
Runtime error
Runtime error
Yew Chong
commited on
Commit
•
bae2e43
1
Parent(s):
ca422fb
Update streamlit app with new LLM and prompts
Browse files- streamlit/app8.py +30 -101
streamlit/app8.py
CHANGED
@@ -17,7 +17,7 @@ import db_firestore as db
|
|
17 |
## ----------------------------------------------------------------
|
18 |
## LLM Part
|
19 |
import openai
|
20 |
-
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
21 |
import tiktoken
|
22 |
from langchain.prompts.few_shot import FewShotPromptTemplate
|
23 |
from langchain.prompts.prompt import PromptTemplate
|
@@ -31,7 +31,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceBgeEmbeddings
|
|
31 |
from langchain_community.vectorstores import FAISS
|
32 |
|
33 |
from langchain.chains import LLMChain
|
34 |
-
from langchain.chains.conversation.memory import ConversationBufferMemory,
|
35 |
|
36 |
import os, dotenv
|
37 |
from dotenv import load_dotenv
|
@@ -114,8 +114,11 @@ if "embeddings" not in st.session_state:
|
|
114 |
encode_kwargs = encode_kwargs)
|
115 |
embeddings = st.session_state.embeddings
|
116 |
if "llm" not in st.session_state:
|
117 |
-
st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
|
118 |
llm = st.session_state.llm
|
|
|
|
|
|
|
119 |
if "llm_gpt4" not in st.session_state:
|
120 |
st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
|
121 |
llm_gpt4 = st.session_state.llm_gpt4
|
@@ -129,40 +132,13 @@ if "store" not in st.session_state:
|
|
129 |
st.session_state.store = db.get_store(index_name, embeddings=embeddings)
|
130 |
store = st.session_state.store
|
131 |
|
132 |
-
TEMPLATE = """You are a patient undergoing a medical check-up. You will be given the following:
|
133 |
-
1. A context to answer the doctor, for your possible symptoms.
|
134 |
-
2. A question about your current symptoms.
|
135 |
-
|
136 |
-
Your task is to answer the doctor's questions as simple as possible, acting like a patient.
|
137 |
-
Do not include other symptoms that are not included in the context, which provides your symptoms.
|
138 |
-
|
139 |
-
Answer the question to the point, without any elaboration if you're not prodded with it.
|
140 |
-
|
141 |
-
As you are a patient, you do not know any medical jargon or lingo. Do not include specific medical terms in your reply.
|
142 |
-
You only know colloquial words for medical terms.
|
143 |
-
For example, you should not reply with "dysarthria", but instead with "cannot speak properly".
|
144 |
-
For example, you should not reply with "syncope", but instead with "fainting".
|
145 |
-
|
146 |
-
Here is the context:
|
147 |
-
{context}
|
148 |
-
|
149 |
-
----------------------------------------------------------------
|
150 |
-
You are to reply the doctor's following question, with reference to the above context.
|
151 |
-
Question:
|
152 |
-
{question}
|
153 |
-
----------------------------------------------------------------
|
154 |
-
Remember, answer in a short and sweet manner, don't talk too much.
|
155 |
-
Your reply:
|
156 |
-
"""
|
157 |
-
|
158 |
-
with open('templates/patient.txt', 'r') as file:
|
159 |
-
TEMPLATE = file.read()
|
160 |
-
|
161 |
if "TEMPLATE" not in st.session_state:
|
|
|
|
|
162 |
st.session_state.TEMPLATE = TEMPLATE
|
163 |
|
164 |
with st.expander("Patient Prompt"):
|
165 |
-
TEMPLATE = st.text_area("Patient Prompt", value=TEMPLATE)
|
166 |
|
167 |
prompt = PromptTemplate(
|
168 |
input_variables = ["question", "context"],
|
@@ -177,7 +153,9 @@ def format_docs(docs):
|
|
177 |
|
178 |
|
179 |
if "memory" not in st.session_state:
|
180 |
-
st.session_state.memory =
|
|
|
|
|
181 |
memory = st.session_state.memory
|
182 |
|
183 |
|
@@ -200,74 +178,17 @@ sp_mapper = {"human":"student","ai":"patient"}
|
|
200 |
## Grader part
|
201 |
index_name = f"indexes/{st.session_state.index_selectbox}/Rubric"
|
202 |
|
203 |
-
# store = FAISS.load_local(index_name, embeddings)
|
204 |
-
|
205 |
if "store2" not in st.session_state:
|
206 |
st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
|
207 |
store2 = st.session_state.store2
|
208 |
|
209 |
-
TEMPLATE2 = """You are a teacher for medical students. You are grading a medical student on their OSCE, the Object Structured Clinical Examination.
|
210 |
-
|
211 |
-
Your task is to provide an overall assessment of a student's diagnosis, based on the rubrics provided.
|
212 |
-
You will be provided with the following information:
|
213 |
-
1. The rubrics that the student should be judged based upon.
|
214 |
-
2. The conversation history between the medical student and the patient.
|
215 |
-
3. The final diagnosis that the student will make.
|
216 |
-
|
217 |
-
=================================================================
|
218 |
-
|
219 |
-
Your task is as follows:
|
220 |
-
1. Your grading should touch on every part of the rubrics, and grade the student holistically.
|
221 |
-
Finally, provide an overall grade for the student.
|
222 |
-
|
223 |
-
Some additional information that is useful to understand the rubrics:
|
224 |
-
- The rubrics are segmented, with each area separated by dashes, such as "----------"
|
225 |
-
- There will be multiple segments on History Taking. For each segment, the rubrics and corresponding grades will be provided below the required history taking.
|
226 |
-
- For History Taking, you are to grade the student based on the rubrics, by checking the chat history between the patients and the medical student.
|
227 |
-
- There is an additional segment on Presentation, differentials, and diagnosis. The
|
228 |
-
|
229 |
-
|
230 |
-
=================================================================
|
231 |
-
|
232 |
-
e
|
233 |
-
Here are the rubrics for grading the student:
|
234 |
-
<rubrics>
|
235 |
-
|
236 |
-
{context}
|
237 |
-
|
238 |
-
</rubrics>
|
239 |
-
|
240 |
-
=================================================================
|
241 |
-
You are to give a comprehensive judgement based on the student's diagnosis, with reference to the above rubrics.
|
242 |
-
|
243 |
-
Here is the chat history between the medical student and the patient:
|
244 |
-
|
245 |
-
<history>
|
246 |
-
|
247 |
-
{history}
|
248 |
-
|
249 |
-
</history>
|
250 |
-
=================================================================
|
251 |
-
|
252 |
-
|
253 |
-
Student's final diagnosis:
|
254 |
-
<diagnosis>
|
255 |
-
{question}
|
256 |
-
</diagnosis>
|
257 |
-
|
258 |
-
=================================================================
|
259 |
-
|
260 |
-
Your grade:
|
261 |
-
"""
|
262 |
-
|
263 |
-
with open('templates/grader.txt', 'r') as file:
|
264 |
-
TEMPLATE2 = file.read()
|
265 |
-
|
266 |
if "TEMPLATE2" not in st.session_state:
|
|
|
|
|
267 |
st.session_state.TEMPLATE2 = TEMPLATE2
|
268 |
|
269 |
with st.expander("Grader Prompt"):
|
270 |
-
TEMPLATE2 = st.text_area("Grader Prompt", value=TEMPLATE2)
|
271 |
|
272 |
prompt2 = PromptTemplate(
|
273 |
input_variables = ["question", "context", "history"],
|
@@ -283,10 +204,6 @@ def format_docs(docs):
|
|
283 |
|
284 |
fake_history = '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in memory.chat_memory.messages])
|
285 |
|
286 |
-
if "memory2" not in st.session_state:
|
287 |
-
st.session_state.memory2 = ConversationSummaryBufferMemory(llm=llm, memory_key="chat_history", input_key="question" )
|
288 |
-
memory2 = st.session_state.memory2
|
289 |
-
|
290 |
def x(_):
|
291 |
return fake_history
|
292 |
|
@@ -300,7 +217,19 @@ if ("chain2" not in st.session_state
|
|
300 |
"question": RunnablePassthrough(),
|
301 |
} |
|
302 |
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
)
|
305 |
chain2 = st.session_state.chain2
|
306 |
|
@@ -318,7 +247,7 @@ chain2 = st.session_state.chain2
|
|
318 |
if st.button("Clear History and Memory", type="primary"):
|
319 |
st.session_state.messages_1 = []
|
320 |
st.session_state.messages_2 = []
|
321 |
-
st.session_state.memory =
|
322 |
memory = st.session_state.memory
|
323 |
|
324 |
## Testing HTML
|
@@ -417,7 +346,7 @@ if text_prompt:
|
|
417 |
if st.session_state.active_chat==1:
|
418 |
full_response = chain.invoke(text_prompt).get("text")
|
419 |
else:
|
420 |
-
full_response = chain2.invoke(text_prompt).get("text")
|
421 |
message_placeholder.markdown(full_response)
|
422 |
messages.append({"role": "assistant", "content": full_response})
|
423 |
|
|
|
17 |
## ----------------------------------------------------------------
|
18 |
## LLM Part
|
19 |
import openai
|
20 |
+
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
|
21 |
import tiktoken
|
22 |
from langchain.prompts.few_shot import FewShotPromptTemplate
|
23 |
from langchain.prompts.prompt import PromptTemplate
|
|
|
31 |
from langchain_community.vectorstores import FAISS
|
32 |
|
33 |
from langchain.chains import LLMChain
|
34 |
+
from langchain.chains.conversation.memory import ConversationBufferWindowMemory #, ConversationBufferMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory
|
35 |
|
36 |
import os, dotenv
|
37 |
from dotenv import load_dotenv
|
|
|
114 |
encode_kwargs = encode_kwargs)
|
115 |
embeddings = st.session_state.embeddings
|
116 |
if "llm" not in st.session_state:
|
117 |
+
st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
|
118 |
llm = st.session_state.llm
|
119 |
+
if "llm_i" not in st.session_state:
|
120 |
+
st.session_state.llm_i = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
|
121 |
+
llm_i = st.session_state.llm_i
|
122 |
if "llm_gpt4" not in st.session_state:
|
123 |
st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
|
124 |
llm_gpt4 = st.session_state.llm_gpt4
|
|
|
132 |
st.session_state.store = db.get_store(index_name, embeddings=embeddings)
|
133 |
store = st.session_state.store
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
if "TEMPLATE" not in st.session_state:
|
136 |
+
with open('templates/patient.txt', 'r') as file:
|
137 |
+
TEMPLATE = file.read()
|
138 |
st.session_state.TEMPLATE = TEMPLATE
|
139 |
|
140 |
with st.expander("Patient Prompt"):
|
141 |
+
TEMPLATE = st.text_area("Patient Prompt", value=st.session_state.TEMPLATE)
|
142 |
|
143 |
prompt = PromptTemplate(
|
144 |
input_variables = ["question", "context"],
|
|
|
153 |
|
154 |
|
155 |
if "memory" not in st.session_state:
|
156 |
+
st.session_state.memory = ConversationBufferWindowMemory(
|
157 |
+
llm=llm, memory_key="chat_history", input_key="question",
|
158 |
+
k=5, human_prefix="student", ai_prefix="patient",)
|
159 |
memory = st.session_state.memory
|
160 |
|
161 |
|
|
|
178 |
## Grader part
|
179 |
index_name = f"indexes/{st.session_state.index_selectbox}/Rubric"
|
180 |
|
|
|
|
|
181 |
if "store2" not in st.session_state:
|
182 |
st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
|
183 |
store2 = st.session_state.store2
|
184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
if "TEMPLATE2" not in st.session_state:
|
186 |
+
with open('templates/grader.txt', 'r') as file:
|
187 |
+
TEMPLATE2 = file.read()
|
188 |
st.session_state.TEMPLATE2 = TEMPLATE2
|
189 |
|
190 |
with st.expander("Grader Prompt"):
|
191 |
+
TEMPLATE2 = st.text_area("Grader Prompt", value=st.session_state.TEMPLATE2)
|
192 |
|
193 |
prompt2 = PromptTemplate(
|
194 |
input_variables = ["question", "context", "history"],
|
|
|
204 |
|
205 |
fake_history = '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in memory.chat_memory.messages])
|
206 |
|
|
|
|
|
|
|
|
|
207 |
def x(_):
|
208 |
return fake_history
|
209 |
|
|
|
217 |
"question": RunnablePassthrough(),
|
218 |
} |
|
219 |
|
220 |
+
# LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
|
221 |
+
LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
|
222 |
+
| {
|
223 |
+
"json": itemgetter("text"),
|
224 |
+
"text": (
|
225 |
+
LLMChain(
|
226 |
+
llm=llm,
|
227 |
+
prompt=PromptTemplate(
|
228 |
+
input_variables=["text"],
|
229 |
+
template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"),
|
230 |
+
verbose=False)
|
231 |
+
)
|
232 |
+
}
|
233 |
)
|
234 |
chain2 = st.session_state.chain2
|
235 |
|
|
|
247 |
if st.button("Clear History and Memory", type="primary"):
|
248 |
st.session_state.messages_1 = []
|
249 |
st.session_state.messages_2 = []
|
250 |
+
st.session_state.memory = ConversationBufferWindowMemory(llm=llm, memory_key="chat_history", input_key="question" )
|
251 |
memory = st.session_state.memory
|
252 |
|
253 |
## Testing HTML
|
|
|
346 |
if st.session_state.active_chat==1:
|
347 |
full_response = chain.invoke(text_prompt).get("text")
|
348 |
else:
|
349 |
+
full_response = chain2.invoke(text_prompt).get("text").get("text")
|
350 |
message_placeholder.markdown(full_response)
|
351 |
messages.append({"role": "assistant", "content": full_response})
|
352 |
|