Spaces:
Sleeping
Sleeping
File size: 4,000 Bytes
f5a4d36 7677cff f5a4d36 e458941 7677cff e458941 f5a4d36 b4525ce f5a4d36 e458941 7677cff e458941 7677cff e458941 7677cff e458941 a6eee8d f5a4d36 e458941 b4525ce e458941 a6eee8d e458941 a6eee8d 7677cff a6eee8d 7677cff f5a4d36 b4525ce e458941 a6eee8d e458941 a6eee8d e458941 a6eee8d e458941 a6eee8d e458941 7677cff e458941 a6eee8d 7677cff a6eee8d 7677cff e458941 7677cff e458941 f5a4d36 e458941 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
from typing import Optional
import requests
import uvicorn
from llm.basemodel import EHRModel
from llm.llm import VirtualNurseLLM
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from pythainlp.tokenize import sent_tokenize
from pydantic import BaseModel
from llm.models import model_list, get_model
import time
initial_model = "typhoon-v1.5x-70b-instruct"

# Single shared VirtualNurseLLM instance used by every endpoint in this module.
# Only model_name is taken from model_list; the base_url / api_key entries of
# model_list[initial_model] are currently not passed to the constructor.
nurse_llm = VirtualNurseLLM(model_name=model_list[initial_model]["model_name"])

# Alternative backend (OpenThaiGPT), kept for reference:
#   VirtualNurseLLM(base_url="https://api.aieat.or.th/v1", model=".", api_key="dummy")
app = FastAPI()

# Wide-open CORS: any origin, method and header is accepted, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class UserInput(BaseModel):
    """Request body for POST /nurse_response."""

    # The patient's message to the virtual nurse.
    user_input: str
    # Key of the model to answer with; passed to get_model_cached() when it
    # differs from the currently active model.
    model_name: str = "typhoon-v1.5x-70b-instruct"
class NurseResponse(BaseModel):
    """Response body for POST /nurse_response."""

    # Text returned by nurse_llm.invoke() for the user's message.
    nurse_response: str
class EHRData(BaseModel):
    """Snapshot of nurse_llm's EHR data and latest prompt state (GET /details)."""

    # Structured EHR collected so far; may be None.
    ehr_data: Optional[EHRModel]
    # The remaining fields mirror the same-named attributes on nurse_llm.
    # NOTE(review): no defaults are given, so under pydantic v2 these are
    # required (nullable) fields — get_ehr_data always supplies all of them.
    current_context: Optional[str]
    current_prompt: Optional[str]
    current_prompt_ehr: Optional[str]
    current_patient_response: Optional[str]
    current_question: Optional[str]
class ChatHistory(BaseModel):
    """Response body for GET /history."""

    # nurse_llm.chat_history as-is; element shape is defined by VirtualNurseLLM
    # (not visible here).
    chat_history: list
# @app.get("/", response_class=HTMLResponse)
# def read_index():
# return """
# <!DOCTYPE html>
# <html>
# <head>
# <title>MALI_NURSE API</title>
# </head>
# <body>
# <h1>Welcome to MALI_NURSE API</h1>
# <p>This is the index page. Use the link below to access the API docs:</p>
# <a href="/docs">Go to Swagger Docs UI</a>
# </body>
# </html>
# """
@app.get("/history")
def get_chat_history():
    """Return the conversation history stored on the shared nurse_llm."""
    history = nurse_llm.chat_history
    return ChatHistory(chat_history=history)
@app.get("/details")
def get_ehr_data():
    """Return nurse_llm's EHR data together with its latest prompt/context state."""
    snapshot = {
        "ehr_data": nurse_llm.ehr_data,
        "current_context": nurse_llm.current_context,
        "current_prompt": nurse_llm.current_prompt,
        "current_prompt_ehr": nurse_llm.current_prompt_ehr,
        "current_patient_response": nurse_llm.current_patient_response,
        "current_question": nurse_llm.current_question,
    }
    return EHRData(**snapshot)
def toggle_debug():
    """Flip nurse_llm.debug and report the resulting state.

    Returns:
        ``{"debug_mode": "on"}`` or ``{"debug_mode": "off"}``.

    NOTE(review): this function has no route decorator, so it is not exposed
    over HTTP — confirm whether that is intentional.
    """
    nurse_llm.debug = not nurse_llm.debug
    if nurse_llm.debug:
        state = "on"
    else:
        state = "off"
    return {"debug_mode": state}
@app.post("/reset")
def data_reset():
    """Reset the shared nurse_llm (via its reset() method) and log that fact.

    Returns nothing, so FastAPI answers with a JSON ``null`` body.
    """
    nurse_llm.reset()
    message = "Chat history and EHR data have been reset."
    print(message)
# Process-wide memo of model_name -> object returned by get_model().
model_cache = {}


def get_model_cached(model_name):
    """Return the model client for ``model_name``, creating it on first use.

    Exceptions from get_model() (e.g. ValueError) propagate and nothing is
    cached for that name.
    """
    if model_name in model_cache:
        return model_cache[model_name]
    client = get_model(model_name=model_name)
    model_cache[model_name] = client
    return client
# Key of the model whose client is currently installed on nurse_llm. Seeded
# with nurse_llm.model_name so the first request compares exactly as before;
# updated after every successful switch so later requests (including a switch
# back to the initial model) see the real active model.
_active_model_name = nurse_llm.model_name


def _append_runtime_log(model_name, user_text, response, duration, path="runtime_log.csv"):
    """Append one (model, input, response, seconds) row to the runtime CSV log.

    csv.writer quotes fields containing commas, quotes or newlines, which
    free-form user text and LLM output routinely contain; UTF-8 is explicit
    so Thai text is written correctly regardless of the platform default.
    """
    import csv

    with open(path, "a", encoding="utf-8", newline="") as log_file:
        csv.writer(log_file).writerow([model_name, user_text, response, duration])


@app.post("/nurse_response")
def nurse_response(user_input: UserInput):
    """
    Models: "typhoon-v1.5x-70b-instruct (default)", "openthaigpt", "llama-3.3-70b-versatile"

    Generates the virtual nurse's reply to one user message. If the requested
    model differs from the active one, nurse_llm's client is swapped for a
    cached client of that model; if loading it raises ValueError, the body
    {"error": "Invalid model name"} is returned instead. Each answered request
    appends a row to runtime_log.csv.
    """
    global _active_model_name

    start_time = time.perf_counter()
    requested_model = user_input.model_name

    if requested_model != _active_model_name:
        print(f"Changing model to {requested_model}")
        try:
            nurse_llm.client = get_model_cached(model_name=requested_model)
        except ValueError:
            return {"error": "Invalid model name"}
        # Record the switch; without this the comparison above keeps using the
        # stale name and a switch back to the initial model never happens.
        _active_model_name = requested_model
        print(nurse_llm.client)

    response = nurse_llm.invoke(user_input.user_input)

    duration = time.perf_counter() - start_time
    print(f"Function running time: {duration} seconds")
    _append_runtime_log(requested_model, user_input.user_input, response, duration)
    return NurseResponse(nurse_response=response)
# TTS
# Mount the TTS sub-application (the `app` object from tts.tts) under the
# /tts path prefix. The import executes at module load, after every route
# above has been registered on `app`.
from tts.tts import app as tts_app
app.mount("/tts", tts_app)
if __name__ == "__main__":
    # Development entry point. uvicorn needs the "module:attribute" import
    # string (not the app object) when reload is enabled.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )