# MedAI / app.py
# moriire's picture
# Update app.py
# 6a34b4c verified
# raw
# history blame
# 3.67 kB
import fastapi
from fastapi.responses import JSONResponse
from time import time
#MODEL_PATH = "./qwen1_5-0_5b-chat-q4_0.gguf" #"./qwen1_5-0_5b-chat-q4_0.gguf"
import logging
import llama_cpp
import llama_cpp.llama_tokenizer
from pydantic import BaseModel
class GenModel(BaseModel):
    """Parameters describing one chat-completion request.

    Only ``question`` is mandatory; the remaining fields fall back to the
    defaults declared below.
    """

    # Text of the user's prompt (required, no default).
    question: str
    # System prompt used when the caller does not provide one.
    system: str = "You are a story writing assistant."
    # Sampling temperature handed to the model.
    temperature: float = 0.7
    # Seed handed to the model.
    seed: int = 42
# Load the quantised Qwen1.5-0.5B chat model once, at import time, so that the
# request handlers in this module all reuse the same ``llama`` instance.
llama = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B"
    ),
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,
    # Qwen1.5 chat checkpoints use the ChatML template
    # (<|im_start|>role ... <|im_end|>). The previous "llama-2" value wrapped
    # every prompt in [INST]/<<SYS>> markers, which is not this model's format.
    chat_format="chatml",
)
# Logger setup: INFO-level root configuration plus a module-level logger that
# the request handlers use to report failures.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object; the route decorators in this module attach
# their handlers to it.
app = fastapi.FastAPI()
@app.get("/")
def index():
    """Redirect requests for the root path to the API docs at ``/docs``."""
    docs_redirect = fastapi.responses.RedirectResponse(url="/docs")
    return docs_redirect
@app.get("/health")
def health():
    """Health-check endpoint: always answers with a constant "ok" payload."""
    return dict(status="ok")
# Chat Completion API (GET variant, parameters supplied as a GenModel)
@app.get("/generate_stream/")
async def complete(gen: GenModel):
    """Run one chat completion described by ``gen`` and return the result.

    Despite the route name, nothing is streamed: the whole completion is
    generated first and returned as a single response.

    Args:
        gen: Request parameters (question, system prompt, temperature, seed).

    Returns:
        The object returned by ``llama.create_chat_completion`` with an added
        ``"time"`` key holding the elapsed wall-clock seconds, or a
        ``JSONResponse`` with status 500 if any exception is raised.
    """
    # NOTE(review): FastAPI reads a pydantic-model parameter from the request
    # body, and many HTTP clients cannot send a body with GET. Confirm
    # whether this route should be POST or take query parameters instead.
    try:
        started = time()
        # GenModel is a pydantic model, so its fields are attributes; the
        # earlier dict-style access (gen['question']) and the bare ``system``
        # name raised on every request and always produced a 500.
        output = llama.create_chat_completion(
            messages=[
                {"role": "system", "content": gen.system},
                {"role": "user", "content": gen.question},
            ],
            temperature=gen.temperature,
            seed=gen.seed,
        )
        output["time"] = time() - started
        return output
    except Exception:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in /generate_stream/ endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
# Chat Completion API (POST variant, parameters supplied individually)
@app.post("/generate")
async def complete(
    question: str,
    system: str = "You are a story writing assistant.",
    temperature: float = 0.7,
    seed: int = 42,
) -> dict:
    """Run one chat completion for ``question`` and return the result.

    Args:
        question: The user's prompt.
        system: System prompt placed before the user message.
        temperature: Sampling temperature passed to the model.
        seed: Seed passed to the model.

    Returns:
        The object returned by ``llama.create_chat_completion`` with an added
        ``"time"`` key holding the elapsed wall-clock seconds, or a
        ``JSONResponse`` with status 500 if any exception is raised.
    """
    try:
        started = time()
        output = llama.create_chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": question},
            ],
            temperature=temperature,
            seed=seed,
        )
        output["time"] = time() - started
        return output
    except Exception:
        # Name the real route (the old message said "/complete") and keep the
        # traceback so failures can be diagnosed from the logs.
        logger.exception("Error in /generate endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script.
    import uvicorn

    # Bind address and port for the development server.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)