|
import logging
import threading
from time import perf_counter, time

import fastapi
import llama_cpp
import llama_cpp.llama_tokenizer
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.auth import create_access_token, get_current_user
from app.db.database import get_db
from app.models.user import User
from app.schemas.user import UserCreate, UserOut
|
|
|
|
|
|
|
|
|
|
|
# All endpoints in this module are mounted under /llm.
router = APIRouter(prefix="/llm", tags=["llm"])

# Default system prompt shared by both request models. It is defined once so the
# two models cannot drift apart.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful medical AI chat assistant. Help as much as you can. "
    "Also continuously ask for possible symptoms in order to arrive at a "
    "conclusive diagnosis of the ailment or sickness and suggest possible "
    "solutions. Remember, respond in English."
)


class GenModel(BaseModel):
    """Request body for the single-question ``/generate`` endpoint.

    NOTE(review): the ``/generate`` handler in this module replaces
    ``system``, ``temperature`` and ``seed`` with fixed values, so clients
    cannot currently change them there.
    """

    question: str  # the user's prompt, sent as the "user" message
    system: str = DEFAULT_SYSTEM_PROMPT
    temperature: float = 0.8
    seed: int = 101
    # Mirostat sampling settings (llama.cpp mirostat mode / target / learning rate).
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1


class ChatModel(BaseModel):
    """Request body for the multi-turn ``/chat/`` endpoint."""

    # Chat history forwarded to create_chat_completion as ``messages``;
    # presumably a list of {"role": ..., "content": ...} dicts — confirm with clients.
    question: list
    system: str = DEFAULT_SYSTEM_PROMPT
    temperature: float = 0.8
    seed: int = 101
    # Mirostat sampling settings (llama.cpp mirostat mode / target / learning rate).
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
|
# Both endpoints use the same quantized Qwen1.5-0.5B-Chat GGUF weights; only the
# context window differs. Models are loaded at import time.
_QWEN_GGUF_REPO_ID = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
_QWEN_GGUF_FILENAME = "*q4_0.gguf"  # glob matched against the repo's files
_QWEN_HF_TOKENIZER_ID = "Qwen/Qwen1.5-0.5B"


def _load_qwen_model(n_ctx: int) -> llama_cpp.Llama:
    """Load the Qwen GGUF model on CPU (``n_gpu_layers=0``) with context size *n_ctx*.

    Each call builds its own Hugging Face tokenizer, so separate model
    instances never share tokenizer state.

    Mirostat settings are sampling options of ``create_chat_completion``,
    not ``llama_cpp.Llama`` constructor parameters. They are therefore not
    passed here and must be supplied per completion call.
    """
    return llama_cpp.Llama.from_pretrained(
        repo_id=_QWEN_GGUF_REPO_ID,
        filename=_QWEN_GGUF_FILENAME,
        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
            _QWEN_HF_TOKENIZER_ID
        ),
        verbose=False,
        n_ctx=n_ctx,
        n_gpu_layers=0,
    )


# Model used by the multi-turn /chat/ endpoint.
llm_chat = _load_qwen_model(n_ctx=1024)

# Model used by the single-question /generate endpoint (larger context window).
llm_generate = _load_qwen_model(n_ctx=4096)

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@router.get("/")
def index():
    """Redirect requests for the router root to the interactive API docs."""
    docs_redirect = fastapi.responses.RedirectResponse(url="/docs")
    return docs_redirect
|
|
|
@router.get("/health")
def health():
    """Liveness probe: report that the service is up."""
    payload = dict(status="ok")
    return payload
|
|
|
|
|
# llama_cpp.Llama instances must not be driven from several threads at once, so
# all inference on ``llm_chat`` is serialized behind this lock. Requests waiting
# for it each occupy a threadpool worker.
_chat_lock = threading.Lock()


def _run_chat_completion(**kwargs):
    """Run ``llm_chat.create_chat_completion(**kwargs)`` under ``_chat_lock``.

    Adds a ``"time"`` key holding the seconds spent in inference. Lock-wait
    time is excluded.
    """
    with _chat_lock:
        start = perf_counter()
        output = llm_chat.create_chat_completion(**kwargs)
        output["time"] = perf_counter() - start
    return output


def _with_system_prompt(messages, system):
    """Return a copy of *messages*, prefixed with a system message built from
    *system* unless the conversation already contains one."""
    if any(isinstance(m, dict) and m.get("role") == "system" for m in messages):
        return list(messages)
    return [{"role": "system", "content": system}, *messages]


@router.post("/chat/")
async def chat(chatm: ChatModel):
    """Run a multi-turn chat completion with ``llm_chat``.

    ``chatm.question`` is used as the message list. ``chatm.system`` is
    prepended as a system message when the client did not send one.

    Inference is blocking, CPU-bound work, so it runs in a worker thread.
    This keeps the event loop (and other endpoints) responsive.

    Returns the dict produced by ``create_chat_completion`` plus a ``"time"``
    key (inference seconds), or a 500 JSONResponse if anything raises.

    NOTE(review): ChatModel's mirostat_* fields are accepted but not
    forwarded here; confirm whether chat sampling should use them.
    """
    try:
        messages = _with_system_prompt(chatm.question, chatm.system)
        return await run_in_threadpool(
            _run_chat_completion,
            messages=messages,
            temperature=chatm.temperature,
            seed=chatm.seed,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error in /chat endpoint: %s", e)
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
|
|
|
|
|
# llama_cpp.Llama instances must not be driven from several threads at once, so
# all inference on ``llm_generate`` is serialized behind this lock. Requests
# waiting for it each occupy a threadpool worker.
_generate_lock = threading.Lock()

# Fixed prompt/sampling settings for /generate. They take precedence over the
# corresponding GenModel fields sent by the client.
_GENERATE_SYSTEM_PROMPT = "You are a helpful medical AI assistant."
_GENERATE_TEMPERATURE = 0.5
_GENERATE_SEED = 42


def _run_generate_completion(**kwargs):
    """Run ``llm_generate.create_chat_completion(**kwargs)`` under ``_generate_lock``.

    Adds a ``"time"`` key holding the seconds spent in inference. Lock-wait
    time is excluded.
    """
    with _generate_lock:
        start = perf_counter()
        output = llm_generate.create_chat_completion(**kwargs)
        output["time"] = perf_counter() - start
    return output


@router.post("/generate")
async def generate(gen: GenModel):
    """Answer a single question with ``llm_generate``.

    The request's ``question`` becomes the user message. Mirostat settings
    come from the request; they are sampling options of
    ``create_chat_completion`` and have no effect as ``Llama`` constructor
    arguments.

    NOTE(review): the request's ``system``, ``temperature`` and ``seed`` are
    ignored in favour of the fixed ``_GENERATE_*`` values above (this mirrors
    the previous behaviour, which overwrote them); confirm this is intended.

    Inference runs in a worker thread so the event loop stays responsive.

    Returns the dict produced by ``create_chat_completion`` plus a ``"time"``
    key (inference seconds), or a 500 JSONResponse if anything raises.
    """
    try:
        return await run_in_threadpool(
            _run_generate_completion,
            messages=[
                {"role": "system", "content": _GENERATE_SYSTEM_PROMPT},
                {"role": "user", "content": gen.question},
            ],
            temperature=_GENERATE_TEMPERATURE,
            seed=_GENERATE_SEED,
            mirostat_mode=gen.mirostat_mode,
            mirostat_tau=gen.mirostat_tau,
            mirostat_eta=gen.mirostat_eta,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error in /generate endpoint: %s", e)
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
|
|