# MedAI — app/routers/llm.py
# Author: moriire · commit 014870a ("auth added") · 4.15 kB
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.db.database import get_db
from app.models.user import User
from app.schemas.user import UserCreate, UserOut
from app.auth import create_access_token, get_current_user
import fastapi
from fastapi.responses import JSONResponse
from time import time
#from fastapi.middleware.cors import CORSMiddleware
import logging
import llama_cpp
import llama_cpp.llama_tokenizer
from pydantic import BaseModel
# All endpoints in this module are mounted under /llm and grouped under the
# "llm" tag in the generated OpenAPI docs.
router = APIRouter(tags=["llm"], prefix="/llm")
# Default system prompt shared by both request models below.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful medical AI chat assistant. Help as much as you can. "
    "Also continuously ask for possible symptoms in order to arrive at a "
    "conclusive ailment or sickness and possible solutions. "
    "Remember, respond in English."
)


class GenModel(BaseModel):
    """Request body for the single-question ``/generate`` endpoint."""

    question: str  # the user's question text
    system: str = DEFAULT_SYSTEM_PROMPT
    temperature: float = 0.8
    seed: int = 101
    # Mirostat sampling options, named after llama_cpp's
    # ``mirostat_mode`` / ``mirostat_tau`` / ``mirostat_eta`` call arguments.
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1


class ChatModel(BaseModel):
    """Request body for the multi-turn ``/chat/`` endpoint."""

    # Chat history. It is handed to llama_cpp as ``messages``, so presumably a
    # list of {"role": ..., "content": ...} dicts.
    # NOTE(review): no element-level validation happens here — confirm clients.
    question: list
    # NOTE(review): /chat builds its prompt from ``question`` only; this field
    # appears unused by the chat handler — confirm before relying on it.
    system: str = DEFAULT_SYSTEM_PROMPT
    temperature: float = 0.8
    seed: int = 101
    # Mirostat sampling options, mirroring GenModel.
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
# Model instance serving /chat: Qwen1.5-0.5B-Chat, q4_0 GGUF quantization,
# with a 1024-token context window and no GPU layer offload.
# NOTE(review): the weights are presumably fetched from the Hugging Face Hub
# at import time, so importing this module can be slow / need network.
llm_chat = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B"
    ),
    n_ctx=1024,
    n_gpu_layers=0,
    verbose=False,
)
# Model instance serving /generate: the same Qwen1.5-0.5B-Chat q4_0 GGUF as
# llm_chat, but with a larger 4096-token context window.
#
# Mirostat settings (mirostat_mode / mirostat_tau / mirostat_eta) are per-call
# sampling options of llama_cpp's create_completion / create_chat_completion,
# not Llama() constructor parameters. Passed here, they were swallowed by the
# constructor's **kwargs and had no effect, so they are not passed anymore;
# supply them on each create_chat_completion call instead.
# NOTE(review): this loads a second copy of the same weights as llm_chat;
# consider sharing one instance if memory is tight.
llm_generate = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B"
    ),
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,
)
# Logger setup: a module-level logger plus INFO-level logging configuration.
# NOTE(review): basicConfig configures the *root* logger for the whole
# process (and is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@router.get("/")
def index():
    """Redirect the router root to the interactive API documentation."""
    docs_location = "/docs"
    return fastapi.responses.RedirectResponse(url=docs_location)
@router.get("/health")
def health():
    """Liveness check: always reports that the service is up."""
    payload = dict(status="ok")
    return payload
# Chat Completion API
@router.post("/chat/")
async def chat(chatm: ChatModel):
    """Run a multi-turn chat completion with ``llm_chat``.

    ``chatm.question`` is forwarded unchanged as the ``messages`` argument of
    ``create_chat_completion``; ``temperature`` and ``seed`` come from the
    request.

    Returns:
        The llama_cpp response dict, with an extra ``"time"`` key holding the
        elapsed generation time in seconds.
        On any exception, the error is logged with its traceback and a
        generic 500 ``{"message": "Internal Server Error"}`` is returned.
    """
    # NOTE(review): the llama_cpp call is synchronous inside an ``async def``
    # handler, so it blocks the event loop for the whole generation. Moving it
    # to a worker thread would likely also require serializing access to the
    # shared Llama instance.
    # NOTE(review): ``chatm.system`` and the ``mirostat_*`` request fields are
    # not forwarded to the model here.
    from time import perf_counter  # monotonic clock for elapsed time

    try:
        started = perf_counter()
        output = llm_chat.create_chat_completion(
            messages=chatm.question,
            temperature=chatm.temperature,
            seed=chatm.seed,
        )
        output["time"] = perf_counter() - started
        return output
    except Exception:
        # Top-level request boundary: keep the traceback in the logs, but do
        # not leak internal details to the client.
        logger.exception("Error in /chat endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
# Single-question generation API
@router.post("/generate")
async def generate(gen: GenModel):
    """Answer a single question with ``llm_generate``.

    The system prompt, temperature and seed are fixed server-side; the
    request's ``system``, ``temperature`` and ``seed`` values are ignored.
    ``question`` and the ``mirostat_*`` sampling fields come from the request.

    Returns:
        The llama_cpp response dict, with an extra ``"time"`` key holding the
        elapsed generation time in seconds.
        On any exception, the error is logged with its traceback and a
        generic 500 ``{"message": "Internal Server Error"}`` is returned.
    """
    # NOTE(review): synchronous llama_cpp call inside ``async def`` — it
    # blocks the event loop while generating.
    from time import perf_counter  # monotonic clock for elapsed time

    # Fixed server-side settings (client-supplied values are not used).
    system_prompt = "You are a helpful medical AI assistant."
    temperature = 0.5
    seed = 42

    try:
        started = perf_counter()
        output = llm_generate.create_chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": gen.question},
            ],
            temperature=temperature,
            seed=seed,
            # llama_cpp takes mirostat settings as per-call sampling options,
            # not as Llama() constructor arguments.
            mirostat_mode=gen.mirostat_mode,
            mirostat_tau=gen.mirostat_tau,
            mirostat_eta=gen.mirostat_eta,
        )
        output["time"] = perf_counter() - started
        return output
    except Exception:
        # Top-level request boundary: keep the traceback in the logs, but do
        # not leak internal details to the client.
        logger.exception("Error in /generate endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )