microhum commited on
Commit
a6eee8d
·
1 Parent(s): 6fed246

update: model choosing / more parameter data

Browse files
.gitignore CHANGED
@@ -1,3 +1,3 @@
1
  .venv
2
  .env
3
- __pycache__
 
1
  .venv
2
  .env
3
+ **/__pycache__/
DockerFile CHANGED
@@ -10,4 +10,4 @@ COPY --chown=user ./requirements.txt requirements.txt
10
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
 
12
  COPY --chown=user . /app
13
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
10
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
 
12
  COPY --chown=user . /app
13
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -4,4 +4,5 @@ emoji: 🐳
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: docker
 
7
  ---
 
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: docker
7
+ app_port: 7860
8
  ---
__pycache__/main.cpython-311.pyc DELETED
Binary file (2.97 kB)
 
cli.py CHANGED
@@ -1,21 +1,27 @@
1
  from llm.client import NurseCLI
2
  from llm.llm import VirtualNurseLLM
 
 
3
 
4
- # model: typhoon-v1.5x-70b-instruct
5
- # nurse_llm = VirtualNurseLLM(
6
- # base_url="https://api.opentyphoon.ai/v1",
7
- # model="typhoon-v1.5x-70b-instruct",
8
- # api_key=os.getenv("TYPHOON_API_KEY")
9
- # )
10
-
11
- # model: OpenThaiGPT
12
 
13
  if __name__ == "__main__":
14
- nurse_llm = VirtualNurseLLM(
15
- base_url="https://api.aieat.or.th/v1",
16
- model=".",
17
- api_key="dummy"
18
- )
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  cli = NurseCLI(nurse_llm)
21
  cli.start()
 
1
  from llm.client import NurseCLI
2
  from llm.llm import VirtualNurseLLM
3
+ from dotenv import load_dotenv
4
+ import os
5
 
6
+ load_dotenv()
 
 
 
 
 
 
 
7
 
8
  if __name__ == "__main__":
9
+ model_choice = input("Choose the model to use (1 for typhoon-v1.5x-70b-instruct, 2 for OpenThaiGPT): ")
10
+ if model_choice == "1":
11
+ nurse_llm = VirtualNurseLLM(
12
+ base_url="https://api.opentyphoon.ai/v1",
13
+ model="typhoon-v1.5x-70b-instruct",
14
+ api_key=os.environ.get("TYPHOON_CHAT_KEY")
15
+ )
16
+ elif model_choice == "2":
17
+ nurse_llm = VirtualNurseLLM(
18
+ base_url="https://api.aieat.or.th/v1",
19
+ model="OpenThaiGPT",
20
+ api_key="dummy"
21
+ )
22
+ else:
23
+ print("Invalid choice. Exiting.")
24
+ exit(1)
25
 
26
  cli = NurseCLI(nurse_llm)
27
  cli.start()
llm/__pycache__/llm.cpython-311.pyc CHANGED
Binary files a/llm/__pycache__/llm.cpython-311.pyc and b/llm/__pycache__/llm.cpython-311.pyc differ
 
llm/llm.py CHANGED
@@ -19,8 +19,11 @@ class VirtualNurseLLM:
19
  self.JSON_EXAMPLE = JSON_EXAMPLE
20
  self.ehr_data = {}
21
  self.chat_history = []
 
 
22
  self.debug = False
23
  self.current_prompt = None
 
24
  self.current_question = None
25
 
26
  def create_prompt(self, task_type):
@@ -41,7 +44,7 @@ class VirtualNurseLLM:
41
  def gather_ehr(self, patient_response, max_retries=3):
42
  prompt = self.create_prompt("extract_ehr")
43
  messages = prompt.format_messages(ehr_data=self.ehr_data, patient_response=patient_response, example=self.JSON_EXAMPLE)
44
- self.current_prompt = messages
45
  response = self.client(messages=messages)
46
  if self.debug:
47
  pprint(f"gather ehr llm response: \n{response.content}\n")
@@ -52,10 +55,10 @@ class VirtualNurseLLM:
52
  json_content = self.extract_json_content(response.content)
53
  if self.debug:
54
  pprint(f"JSON after dumps:\n{json_content}\n")
55
- ehr_data = EHRModel.parse_raw(json_content)
56
 
57
  # Update only missing parameters
58
- for key, value in ehr_data.dict().items():
59
  if value not in [None, [], {}]: # Checks for None and empty lists or dicts
60
  print(f"Updating {key} with value {value}")
61
  self.ehr_data[key] = value
@@ -79,7 +82,7 @@ class VirtualNurseLLM:
79
  retry_prompt=retry_prompt,
80
  json_problem=json_content
81
  )
82
- self.current_prompt = messages
83
  print(f"กำลังลองใหม่ด้วย prompt ที่ปรับแล้ว: {retry_prompt}")
84
  response = self.client(messages=messages)
85
 
@@ -109,7 +112,7 @@ class VirtualNurseLLM:
109
  messages = ChatPromptTemplate.from_messages([question_prompt, history_context])
110
  messages = messages.format_messages(description=f'"{field}":"{description}"', context=context, patient_response=patient_response, field_descriptions=self.field_descriptions)
111
  self.current_context = context
112
- self.current_prompt = messages
113
 
114
  # format print
115
  # pprint(pformat(messages.messages[0].prompt.template, indent=4, width=80))
@@ -127,6 +130,7 @@ class VirtualNurseLLM:
127
  if patient_response:
128
  self.chat_history.append({"role": "user", "content": patient_response})
129
  question = self.get_question(patient_response)
 
130
  self.chat_history.append({"role": "assistant", "content": question})
131
  return question
132
 
 
19
  self.JSON_EXAMPLE = JSON_EXAMPLE
20
  self.ehr_data = {}
21
  self.chat_history = []
22
+ self.current_patient_response = None
23
+ self.current_context = None
24
  self.debug = False
25
  self.current_prompt = None
26
+ self.current_prompt_ehr = None
27
  self.current_question = None
28
 
29
  def create_prompt(self, task_type):
 
44
  def gather_ehr(self, patient_response, max_retries=3):
45
  prompt = self.create_prompt("extract_ehr")
46
  messages = prompt.format_messages(ehr_data=self.ehr_data, patient_response=patient_response, example=self.JSON_EXAMPLE)
47
+ self.current_prompt_ehr = messages[0].content
48
  response = self.client(messages=messages)
49
  if self.debug:
50
  pprint(f"gather ehr llm response: \n{response.content}\n")
 
55
  json_content = self.extract_json_content(response.content)
56
  if self.debug:
57
  pprint(f"JSON after dumps:\n{json_content}\n")
58
+ ehr_data = EHRModel.model_validate_json(json_content)
59
 
60
  # Update only missing parameters
61
+ for key, value in ehr_data.model_dump().items():
62
  if value not in [None, [], {}]: # Checks for None and empty lists or dicts
63
  print(f"Updating {key} with value {value}")
64
  self.ehr_data[key] = value
 
82
  retry_prompt=retry_prompt,
83
  json_problem=json_content
84
  )
85
+ self.current_prompt_ehr = messages[0].content
86
  print(f"กำลังลองใหม่ด้วย prompt ที่ปรับแล้ว: {retry_prompt}")
87
  response = self.client(messages=messages)
88
 
 
112
  messages = ChatPromptTemplate.from_messages([question_prompt, history_context])
113
  messages = messages.format_messages(description=f'"{field}":"{description}"', context=context, patient_response=patient_response, field_descriptions=self.field_descriptions)
114
  self.current_context = context
115
+ self.current_prompt = messages[0].content
116
 
117
  # format print
118
  # pprint(pformat(messages.messages[0].prompt.template, indent=4, width=80))
 
130
  if patient_response:
131
  self.chat_history.append({"role": "user", "content": patient_response})
132
  question = self.get_question(patient_response)
133
+ self.current_patient_response = patient_response
134
  self.chat_history.append({"role": "assistant", "content": question})
135
  return question
136
 
app.py → main.py RENAMED
@@ -5,23 +5,23 @@ from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.responses import HTMLResponse
6
  from pydantic import BaseModel
7
  import os
8
- import dotenv
9
- dotenv.load_dotenv()
10
 
11
  # model: typhoon-v1.5x-70b-instruct
12
- # nurse_llm = VirtualNurseLLM(
13
- # base_url="https://api.opentyphoon.ai/v1",
14
- # model="typhoon-v1.5x-70b-instruct",
15
- # api_key=os.getenv("TYPHOON_API_KEY")
16
- # )
17
-
18
- # model: OpenThaiGPT
19
  nurse_llm = VirtualNurseLLM(
20
- base_url="https://api.aieat.or.th/v1",
21
- model=".",
22
- api_key="dummy"
23
  )
24
 
 
 
 
 
 
 
 
25
  app = FastAPI()
26
 
27
  app.add_middleware(
@@ -34,7 +34,22 @@ app.add_middleware(
34
 
35
  class UserInput(BaseModel):
36
  user_input: str
 
 
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
38
  @app.get("/", response_class=HTMLResponse)
39
  def read_index():
40
  return """
@@ -53,21 +68,24 @@ def read_index():
53
 
54
  @app.get("/history")
55
  def get_chat_history():
56
- return {"chat_history": nurse_llm.chat_history}
57
 
58
- @app.get("/ehr")
59
  def get_ehr_data():
60
- return {"ehr_data": nurse_llm.ehr_data}
61
-
62
- @app.get("/status")
63
- def get_status():
64
- return {"current_prompt": nurse_llm.current_prompt}
 
 
 
65
 
66
- @app.post("/debug")
67
  def toggle_debug():
68
  nurse_llm.debug = not nurse_llm.debug
69
  return {"debug_mode": "on" if nurse_llm.debug else "off"}
70
 
 
71
  @app.post("/reset")
72
  def data_reset():
73
  nurse_llm.reset()
@@ -75,8 +93,17 @@ def data_reset():
75
 
76
  @app.post("/nurse_response")
77
  def nurse_response(user_input: UserInput):
 
 
 
 
 
 
 
 
 
78
  response = nurse_llm.invoke(user_input.user_input)
79
- return {"nurse_response": response}
80
 
81
  if __name__ == "__main__":
82
  uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 
5
  from fastapi.responses import HTMLResponse
6
  from pydantic import BaseModel
7
  import os
8
+ from dotenv import load_dotenv
9
+ load_dotenv()
10
 
11
  # model: typhoon-v1.5x-70b-instruct
 
 
 
 
 
 
 
12
  nurse_llm = VirtualNurseLLM(
13
+ base_url="https://api.opentyphoon.ai/v1",
14
+ model="typhoon-v1.5x-70b-instruct",
15
+ api_key=os.getenv("TYPHOON_CHAT_KEY")
16
  )
17
 
18
+ # model: OpenThaiGPT
19
+ # nurse_llm = VirtualNurseLLM(
20
+ # base_url="https://api.aieat.or.th/v1",
21
+ # model=".",
22
+ # api_key="dummy"
23
+ # )
24
+
25
  app = FastAPI()
26
 
27
  app.add_middleware(
 
34
 
35
  class UserInput(BaseModel):
36
  user_input: str
37
+ model_name: str = "typhoon-v1.5x-70b-instruct"
38
+
39
+ class NurseResponse(BaseModel):
40
+ nurse_response: str
41
 
42
+ class EHRData(BaseModel):
43
+ ehr_data: dict
44
+ current_context: str
45
+ current_prompt: str
46
+ current_prompt_ehr: str
47
+ current_patient_response: str
48
+ current_question: str
49
+
50
+ class ChatHistory(BaseModel):
51
+ chat_history: list
52
+
53
  @app.get("/", response_class=HTMLResponse)
54
  def read_index():
55
  return """
 
68
 
69
  @app.get("/history")
70
  def get_chat_history():
71
+ return ChatHistory(chat_history = nurse_llm.chat_history)
72
 
73
+ @app.get("/details")
74
  def get_ehr_data():
75
+ return EHRData(
76
+ ehr_data=nurse_llm.ehr_data,
77
+ current_context=nurse_llm.current_context,
78
+ current_prompt=nurse_llm.current_prompt,
79
+ current_prompt_ehr=nurse_llm.current_prompt_ehr,
80
+ current_patient_response=nurse_llm.current_patient_response,
81
+ current_question=nurse_llm.current_question
82
+ )
83
 
 
84
  def toggle_debug():
85
  nurse_llm.debug = not nurse_llm.debug
86
  return {"debug_mode": "on" if nurse_llm.debug else "off"}
87
 
88
+
89
  @app.post("/reset")
90
  def data_reset():
91
  nurse_llm.reset()
 
93
 
94
  @app.post("/nurse_response")
95
  def nurse_response(user_input: UserInput):
96
+ """
97
+ Models: "typhoon-v1.5x-70b-instruct (default)", "openthaigpt"
98
+ """
99
+ if user_input.model_name == "typhoon-v1.5x-70b-instruct":
100
+ nurse_llm.model = "typhoon-v1.5x-70b-instruct"
101
+ elif user_input.model_name == "openthaigpt":
102
+ nurse_llm.model = "openthaigpt"
103
+ else:
104
+ return {"error": "Invalid model name"}
105
  response = nurse_llm.invoke(user_input.user_input)
106
+ return NurseResponse(nurse_response = response)
107
 
108
  if __name__ == "__main__":
109
  uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)