File size: 4,000 Bytes
f5a4d36
7677cff
f5a4d36
e458941
7677cff
e458941
f5a4d36
b4525ce
f5a4d36
 
e458941
7677cff
 
e458941
7677cff
e458941
7677cff
 
 
e458941
 
a6eee8d
 
 
 
 
 
 
f5a4d36
e458941
 
b4525ce
 
 
 
 
 
 
 
e458941
 
a6eee8d
 
 
 
e458941
a6eee8d
7677cff
 
 
 
 
 
a6eee8d
 
 
7677cff
f5a4d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4525ce
e458941
 
a6eee8d
e458941
a6eee8d
e458941
a6eee8d
 
 
 
 
 
 
 
e458941
 
 
 
 
a6eee8d
e458941
 
 
 
 
7677cff
 
 
 
 
 
e458941
 
a6eee8d
7677cff
a6eee8d
7677cff
 
 
 
 
 
 
 
 
 
 
e458941
7677cff
 
 
 
 
 
 
 
e458941
f5a4d36
 
 
 
 
e458941
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from typing import Optional
import requests
import uvicorn
from llm.basemodel import EHRModel
from llm.llm import VirtualNurseLLM
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from pythainlp.tokenize import sent_tokenize
from pydantic import BaseModel
from llm.models import model_list, get_model
import time

# Key into `model_list` for the model the service starts with.
initial_model = "typhoon-v1.5x-70b-instruct"
# Single module-level VirtualNurseLLM instance; every endpoint in this file
# reads or mutates this same object.
nurse_llm = VirtualNurseLLM(
    # base_url=model_list[initial_model]["base_url"],
    model_name=model_list[initial_model]["model_name"],
    # api_key=model_list[initial_model]["api_key"]
    # NOTE(review): base_url / api_key overrides are disabled above, so
    # presumably VirtualNurseLLM falls back to its own defaults -- confirm.
)

# model: OpenThaiGPT
# nurse_llm = VirtualNurseLLM(
#     base_url="https://api.aieat.or.th/v1",
#     model=".",
#     api_key="dummy"
# )


# The ASGI application; route handlers below register themselves on it.
app = FastAPI()

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation says origins must be listed
# explicitly when allow_credentials=True; confirm whether credentials are
# actually needed or whether the origin list should be narrowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request body for POST /nurse_response.
class UserInput(BaseModel):
    # Text that is forwarded to nurse_llm.invoke().
    user_input: str
    # Model to answer with; compared against nurse_llm.model_name and, when
    # different, passed to get_model_cached() to obtain a client.
    model_name: str = "typhoon-v1.5x-70b-instruct"

# Response body returned by POST /nurse_response on success.
class NurseResponse(BaseModel):
    # The value returned by nurse_llm.invoke() for the request.
    nurse_response: str

# Response body for GET /details. Each field mirrors the nurse_llm attribute
# of the same name, and each may be None.
# NOTE(review): none of the Optional fields declares a default; whether they
# are therefore required depends on the installed pydantic major version --
# confirm. The only constructor call in this file passes all six explicitly.
class EHRData(BaseModel):
    ehr_data: Optional[EHRModel]
    current_context: Optional[str]
    current_prompt: Optional[str]
    current_prompt_ehr: Optional[str]
    current_patient_response: Optional[str]
    current_question: Optional[str]

# Response body for GET /history.
class ChatHistory(BaseModel):
    # Filled from nurse_llm.chat_history; the element type is not constrained.
    chat_history: list
        
# @app.get("/", response_class=HTMLResponse)
# def read_index():
#     return """
#     <!DOCTYPE html>
#     <html>
#     <head>
#         <title>MALI_NURSE API</title>
#     </head>
#     <body>
#         <h1>Welcome to MALI_NURSE API</h1>
#         <p>This is the index page. Use the link below to access the API docs:</p>
#         <a href="/docs">Go to Swagger Docs UI</a>
#     </body>
#     </html>
#     """

# GET /history: return the chat history currently stored on the shared
# nurse_llm instance, wrapped in a ChatHistory model.
@app.get("/history")
def get_chat_history():
    history = nurse_llm.chat_history
    return ChatHistory(chat_history=history)

# GET /details: return the EHR data and the current prompt/question state
# held by the shared nurse_llm instance as an EHRData model.
@app.get("/details")
def get_ehr_data():
    # Each EHRData field is read from the nurse_llm attribute of the same
    # name, in the order listed here.
    mirrored_fields = (
        "ehr_data",
        "current_context",
        "current_prompt",
        "current_prompt_ehr",
        "current_patient_response",
        "current_question",
    )
    snapshot = {field: getattr(nurse_llm, field) for field in mirrored_fields}
    return EHRData(**snapshot)

# NOTE(review): unlike the neighbouring handlers, this function carries no
# @app route decorator, so nothing in this file exposes it over HTTP --
# confirm whether a route was intended.
def toggle_debug():
    """Flip ``nurse_llm.debug`` and report the resulting state.

    Returns:
        ``{"debug_mode": "on"}`` when debug is enabled after the flip,
        ``{"debug_mode": "off"}`` otherwise.
    """
    nurse_llm.debug = not nurse_llm.debug
    if nurse_llm.debug:
        state = "on"
    else:
        state = "off"
    return {"debug_mode": state}


# POST /reset: call nurse_llm.reset() and print a confirmation line on the
# server console. The handler returns None, so the client gets no payload.
@app.post("/reset")
def data_reset():
    nurse_llm.reset()
    confirmation = "Chat history and EHR data have been reset."
    print(confirmation)
    return None

# Clients already built by get_model(), keyed by the requested model name.
model_cache = {}


def get_model_cached(model_name):
    """Return the client for ``model_name``, building it at most once.

    On the first request for a name the client is obtained from
    ``get_model`` and stored in ``model_cache``; later requests return the
    stored object. Any exception raised by ``get_model`` propagates and
    nothing is cached for that name.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    client = get_model(model_name=model_name)
    model_cache[model_name] = client
    return client

def _append_runtime_log(model_name, user_text, response, duration,
                        path="runtime_log.csv"):
    """Append one request record to the runtime CSV log.

    The row is written with the ``csv`` module, so fields containing commas,
    quotes or newlines are quoted instead of spilling into extra columns or
    rows.

    Args:
        model_name: Model name sent with the request.
        user_text: The user's message.
        response: The reply produced for that message.
        duration: Elapsed handling time in seconds.
        path: Log file to append to.
    """
    import csv  # local import keeps this block self-contained

    # newline="" leaves line endings to the csv writer; utf-8 is set
    # explicitly so non-ASCII text is written the same way on every platform.
    with open(path, "a", encoding="utf-8", newline="") as log_file:
        writer = csv.writer(log_file, lineterminator="\n")
        writer.writerow([model_name, user_text, response, duration])


# POST /nurse_response: answer one user message with the shared nurse_llm.
# If the requested model differs from nurse_llm.model_name, the client is
# swapped first via get_model_cached(). The call is timed and appended to
# runtime_log.csv. Returns NurseResponse on success, or the plain dict
# {"error": "Invalid model name"} when get_model_cached() raises ValueError.
@app.post("/nurse_response")
def nurse_response(user_input: UserInput):
    """
    Models: "typhoon-v1.5x-70b-instruct (default)", "openthaigpt", "llama-3.3-70b-versatile"
    """

    start_time = time.time()
    if user_input.model_name != nurse_llm.model_name:
        print(f"Changing model to {user_input.model_name}")
        try:
            nurse_llm.client = get_model_cached(model_name=user_input.model_name)
        except ValueError:
            # NOTE(review): this is sent with the default 200 status;
            # HTTPException is imported in this file but not used here.
            return {"error": "Invalid model name"}
        # NOTE(review): only `client` is replaced; nurse_llm.model_name is
        # not updated in this file, so the comparison above may not reflect
        # the client actually in use after a switch -- confirm against
        # VirtualNurseLLM.
    print(nurse_llm.client)

    # response = nurse_llm.slim_invoke(user_input.user_input)
    response = nurse_llm.invoke(user_input.user_input)
    duration = time.time() - start_time
    print(f"Function running time: {duration} seconds")

    # Log the model name, user input, response, and execution time as one
    # properly quoted CSV row.
    _append_runtime_log(
        user_input.model_name, user_input.user_input, response, duration
    )

    return NurseResponse(nurse_response=response)

# TTS: the text-to-speech application defined in tts/tts.py is mounted as a
# sub-application, so its routes are served under the /tts prefix.
from tts.tts import app as tts_app

app.mount("/tts", tts_app)

# Script entry point: start uvicorn on all interfaces, port 8000, with
# auto-reload enabled.
# NOTE(review): the "main:app" import string assumes this file is importable
# as module `main`; reload=True is normally a development setting -- confirm
# before deploying.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)