Spaces:
Runtime error
Runtime error
import os
import secrets
from typing import Optional

from fastapi import Body, Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.security.api_key import APIKeyHeader
from gpt4all import GPT4All
from huggingface_hub import hf_hub_download
""" | |
[Model From TheBloke/Mistral-7B-Instruct-v0.1-GGUF](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF) | |
[Mistral-instruct-v0.1 System prompt](https://docs.mistral.ai/usage/guardrailing) | |
""" | |
model_path = "models/mistral-7b-instruct-v0.1.Q4_K_M.gguf" | |
model_name = "mistral-7b-instruct-v0.1.Q4_K_M.gguf" | |
hf_hub_download(repo_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF", filename=model_name, local_dir=model_path, local_dir_use_symlinks=False) | |
print("Start the model init process") | |
model = model = GPT4All(model_name, model_path, allow_download = False, device="cpu") | |
print("Finish the model init process") | |
model.config["promptTemplate"] = "[INST] {0} [/INST]" | |
model.config["systemPrompt"] = "You are an helpful AI content generator. your name is 'ErTake AI'. Whenever user greets you, you always have to introduce yourself as 'ErTake AI'. Don't introduce yourself as Bing, I will punish you if you intorduce yourself other than 'ErTake AI'. Carefully listen to the user's message/query and provide correct answers/content. Don't hallucinate if you don't know the answer. Give your output only in english language. If the user asks to create images you need to simply deny the request saying \"I'm sorry, but I don't have the capability to create images. I can assist you with generating text based content. If you have any text-based content requests, feel free to let me know!\"" | |
model._is_chat_session_activated = False | |
max_new_tokens = 2048 | |
def generater(message, history, temperature, top_p, top_k=40):
    """Stream a Mistral-instruct completion for *message*.

    Builds a ``<s>[INST] ... [/INST]`` prompt from the earlier turns in
    *history* followed by *message*, then streams tokens from the module-level
    ``model``.

    Args:
        message: The new user message.
        history: Iterable of ``(user_message, assistant_message)`` pairs for
            earlier turns, oldest first. An ``assistant_message`` of ``None``
            is treated as "no reply yet" and contributes no assistant text.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cut-off forwarded to ``model.generate``. Defaults to 40 so
            callers that only supply temperature/top_p keep working.

    Yields:
        The text generated so far (cumulative, not per-token), after each token.

    Returns:
        The list of generated tokens, as the generator's ``StopIteration`` value.
    """
    template = model.config["promptTemplate"]
    # Chat sessions are disabled, so the configured system prompt is never
    # applied by GPT4All itself. Mistral-instruct v0.1 has no system role;
    # Mistral's guardrailing docs prepend it to the first user instruction.
    system_prompt = model.config.get("systemPrompt") or ""

    turns = [*history, (message, None)]
    prompt_parts = ["<s>"]
    for turn_index, (user_message, assistant_message) in enumerate(turns):
        if turn_index == 0 and system_prompt:
            user_message = f"{system_prompt}\n\n{user_message}"
        prompt_parts.append(template.format(user_message))
        if assistant_message is not None:
            prompt_parts.append(assistant_message + "</s>")
    prompt = "".join(prompt_parts)

    outputs = []
    text = ""
    for token in model.generate(
        prompt=prompt,
        temp=temperature,
        top_k=top_k,
        top_p=top_p,
        max_tokens=max_new_tokens,
        streaming=True,
    ):
        outputs.append(token)
        text += token
        yield text
    print("[outputs]", outputs)
    return outputs
app = FastAPI()

# Server-side secret that clients must send in the ``api_key`` header;
# None when the API_KEY environment variable is not set.
API_KEY = os.getenv("API_KEY")

# auto_error=False: a missing header is handed to get_api_key as None rather than
# FastAPI rejecting the request itself, so get_api_key controls the 401 response.
api_key_header = APIKeyHeader(auto_error=False, name="api_key")
def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """Validate the ``api_key`` request header against the server's ``API_KEY``.

    Intended as a FastAPI dependency.

    Args:
        api_key: Value of the ``api_key`` header, or None when it was not sent.

    Returns:
        The validated key.

    Raises:
        HTTPException: 401 when the header is missing, when the server has no
            API key configured, or when the key does not match.
    """
    # Fail closed: never authenticate anyone if the server key was not configured.
    if api_key is None or not API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    # compare_digest takes time independent of where the inputs differ, so the
    # key cannot be recovered from response timing. Compare bytes: for str
    # inputs it only accepts ASCII, and the header is attacker-controlled.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key),
):
    """Generate a completion for ``body["prompt"]`` and return it as JSON.

    Recognised request-body keys:
        prompt: The user message (required, non-empty string).
        temperature: Sampling temperature (default 0.5).
        top_p: Nucleus-sampling threshold (default 0.95).
        top_k: Top-k cut-off (default 40).

    Returns:
        ``{"generated_text": <complete completion text>}``.

    Raises:
        HTTPException: 401 via ``get_api_key``; 400 when ``prompt`` is missing
            or empty, or when a sampling parameter is not numeric.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
    # function, so FastAPI will not serve it unless it is registered elsewhere.
    message = body.get("prompt", "")
    if not isinstance(message, str) or not message:
        raise HTTPException(status_code=400, detail="'prompt' must be a non-empty string")
    try:
        temperature = float(body.get("temperature", 0.5))
        top_p = float(body.get("top_p", 0.95))
        top_k = int(body.get("top_k", 40))
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=400, detail="temperature, top_p and top_k must be numeric"
        ) from None
    history = []  # Single-turn endpoint: no earlier conversation is sent to the model.

    # generater yields the cumulative text after every token (and previously was
    # called without top_k). Drain it and keep the last value: the full completion.
    generated_text = ""
    for generated_text in generater(message, history, temperature, top_p, top_k):
        pass
    return {"generated_text": generated_text}