#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from enum import Enum
from typing import Optional, Literal, Dict, List
# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
PHllm = Llama(model_path="/models/final-Physics_llama3.gguf", use_mmap=False, use_mlock=True)
# MIllm = Llama(model_path="/models/final-LlamaTuna_Q8_0.gguf", use_mmap=False, use_mlock=True,
#     n_gpu_layers=28,  # Uncomment to use GPU acceleration
#     seed=1337,        # Uncomment to set a specific seed
#     n_ctx=2048,       # Uncomment to increase the context window
# )
print("Loading Translators.")
from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator
t = EnThTranslator()
e = ThEnTranslator()
def extract_restext(response, is_chat=False):
    # Chat completions nest the text under 'message' -> 'content'; plain completions use 'text'.
    return (response['choices'][0]['message']['content'] if is_chat else response['choices'][0]['text']).strip()
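# Sketch of the llama-cpp-python completion dict that extract_restext parses
# (fields abbreviated; values illustrative):
#   resp = {"choices": [{"text": " 42 ", "index": 0, "finish_reason": "stop"}], ...}
#   extract_restext(resp)               -> "42"
#   extract_restext(resp, is_chat=True) -> reads choices[0]["message"]["content"] instead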
def ask_llama(llm: Llama, question: str, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
    prompt = f"""<|begin_of_text|>
<|start_header_id|> user <|end_header_id|> {question} <|eot_id|>
<|start_header_id|> assistant <|end_header_id|>"""
    result = extract_restext(
        llm(prompt, max_tokens=max_new_tokens, temperature=temperature,
            repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"])
    ).replace("<|eot_id|>", "").replace("<|end_of_text|>", "")
    return result
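# Minimal usage sketch for ask_llama (question and settings here are illustrative, not from the original):
#   answer = ask_llama(PHllm, "Why is the sky blue?", max_new_tokens=100, temperature=0.5)
#   print(answer)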
# def chat_llama(llm: Llama, chat_history: dict, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
#     result = extract_restext(llm.create_chat_completion(chat_history, max_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"]), is_chat=True)
#     return result
# TESTING THE MODEL
print("Testing model...")
assert ask_llama(PHllm, "Hello! How are you today?", max_new_tokens=5)  # Just checking that it can run
| print("Checking Translators.") | |
| assert t.translate("Hello!") == "สวัสดี!" | |
| assert e.translate("สวัสดี!") == "Hello!" | |
| print("Ready.") | |
# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Thai Open-ended question answering.",
    version="1.0.0",
)
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
# API DATA CLASSES
class QuestionResponse(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None

class ChatHistoryResponse(BaseModel):
    code: int = 200
    chat_history: Optional[Dict[str, str]] = None
    answer: Optional[str] = None
    config: Optional[dict] = None

class LlamaChatMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: str
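# Example of the JSON a QuestionResponse serializes to (illustrative values only):
#   {"code": 200, "question": "Why does ice cream melt so fast?",
#    "answer": "...", "config": {"temperature": 0.5, "max_new_tokens": 200, "repeat_penalty": 1.0}}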
# API ROUTES
@app.get('/')
def docs():
    "Redirects the user from the main page to the docs."
    return responses.RedirectResponse('./docs')
@app.post('/ask/physics')  # assumed route path; adjust to the deployed path
async def ask_gemmaPhysics(
    prompt: str = Body(..., embed=True, example="Why does ice cream melt so fast?"),
    temperature: float = Body(0.5, embed=True),
    repeat_penalty: float = Body(1.0, embed=True),
    max_new_tokens: int = Body(200, embed=True),
    translate_from_thai: bool = Body(False, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma a physics question.
    NOTICE: Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
    """
    if prompt:
        try:
            print(f'Asking LlamaPhysics with the question "{prompt}", translation is {"enabled" if translate_from_thai else "disabled"}')
            if translate_from_thai:
                print("Translating content to EN.")
                prompt = e.translate(prompt)
                print(f"Asking the model with the question {prompt}")
            result = ask_llama(PHllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
            print(f"Got Model Response: {result}")
            if translate_from_thai:
                result = t.translate(result)
                print(f"Translation Result: {result}")
            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
        except Exception as err:  # named 'err' to avoid shadowing the module-level translator 'e'
            raise HTTPException(500, QuestionResponse(code=500, answer=str(err), question=prompt).dict())
    else:
        raise HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided.").dict())
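# Example client call for the physics endpoint (a sketch: '/ask/physics' is the assumed
# route path from the decorator above; host, port, and prompt are illustrative):
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/ask/physics",
#       json={"prompt": "Why does ice cream melt so fast?", "temperature": 0.5,
#             "max_new_tokens": 200, "repeat_penalty": 1.0, "translate_from_thai": False},
#   )
#   print(resp.json()["answer"])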
# @app.post('/chat/multiturn')
# async def ask_llama3_Tuna(
#     chat_history: List[LlamaChatMessage] = Body(..., embed=True),
#     temperature: float = Body(0.5, embed=True),
#     repeat_penalty: float = Body(2.0, embed=True),
#     max_new_tokens: int = Body(200, embed=True)
# ) -> ChatHistoryResponse:
#     """
#     Chat with a finetuned Llama-3 model (in Thai).
#     Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
#     NOTICE: YOU MUST APPLY THE LLAMA3 PROMPT YOURSELF!
#     """
#     if chat_history:
#         try:
#             print(f'Asking Llama3Tuna with the question "{chat_history}"')
#             result = chat_llama(MIllm, chat_history, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
#             print(f"Result: {result}")
#             return ChatHistoryResponse(answer=result, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
#         except Exception as err:
#             raise HTTPException(500, QuestionResponse(code=500, answer=str(err), question=chat_history).dict())
#     else:
#         raise HTTPException(400, QuestionResponse(code=400, answer="Request argument 'chat_history' not provided.").dict())
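# Optional local-run sketch (assumes uvicorn is installed; the module name "app" and port 7860
# are assumptions about the deployment, not taken from the original):
#   if __name__ == "__main__":
#       import uvicorn
#       uvicorn.run(app, host="0.0.0.0", port=7860)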