from flask import Flask, request, Response
import logging
import threading
from huggingface_hub import snapshot_download#, Repository
import huggingface_hub
import gc
import os.path
import xml.etree.ElementTree as ET
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime, timedelta
from llm_backend import LlmBackend
import json
import sys
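# Single LlmBackend instance shared across all requests.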
llm = LlmBackend()
_lock = threading.Lock()
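# Configuration comes from environment variables; defaults target the saiga2 7B GGUF model.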
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT') or "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."
CONTEXT_SIZE = int(os.environ.get('CONTEXT_SIZE', '500'))
HF_CACHE_DIR = os.environ.get('HF_CACHE_DIR') or '/home/user/app/.cache'
USE_SYSTEM_PROMPT = os.environ.get('USE_SYSTEM_PROMPT', '').lower() == "true"
ENABLE_GPU = os.environ.get('ENABLE_GPU', '').lower() == "true"
GPU_LAYERS = int(os.environ.get('GPU_LAYERS', '0'))
CHAT_FORMAT = os.environ.get('CHAT_FORMAT') or 'llama-2'
REPO_NAME = os.environ.get('REPO_NAME') or 'IlyaGusev/saiga2_7b_gguf'
MODEL_NAME = os.environ.get('MODEL_NAME') or 'model-q4_K.gguf'
DATASET_REPO_URL = os.environ.get('DATASET_REPO_URL') or "https://huggingface.co/datasets/muryshev/saiga-chat"
DATA_FILENAME = os.environ.get('DATA_FILENAME') or "data-saiga-cuda-release.xml"
HF_TOKEN = os.environ.get("HF_TOKEN")
# Create a lock object
lock = threading.Lock()
app = Flask('llm_api')
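# Route application logs to stdout with a timestamped format.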
app.logger.handlers.clear()
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
app.logger.addHandler(handler)
app.logger.setLevel(logging.DEBUG)
# Variable to store the last request time
last_request_time = datetime.now()
# Initialize the model when the application starts
#model_path = "../models/model-q4_K.gguf" # Replace with the actual model path
#MODEL_NAME = "model/ggml-model-q4_K.gguf"
#REPO_NAME = "IlyaGusev/saiga2_13b_gguf"
#MODEL_NAME = "model-q4_K.gguf"
#REPO_NAME = "IlyaGusev/saiga2_70b_gguf"
#MODEL_NAME = "ggml-model-q4_1.gguf"
local_dir = '.'
if os.path.isdir('/data'):
    app.logger.info('Persistent storage enabled')
model = None
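# Fetch the GGUF weights from the Hugging Face Hub (only files matching MODEL_NAME are downloaded)
# and build the local path to the model file.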
MODEL_PATH = snapshot_download(repo_id=REPO_NAME, allow_patterns=MODEL_NAME, cache_dir=HF_CACHE_DIR) + '/' + MODEL_NAME
app.logger.info('Model path: ' + MODEL_PATH)
DATA_FILE = os.path.join("dataset", DATA_FILENAME)
app.logger.info("hfh: "+huggingface_hub.__version__)
# repo = Repository(
#     local_dir="dataset", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
# )

# def log(req: str = '', resp: str = ''):
#     if req or resp:
#         element = ET.Element("row", {"time": str(datetime.now()) })
#         req_element = ET.SubElement(element, "request")
#         req_element.text = req
#         resp_element = ET.SubElement(element, "response")
#         resp_element.text = resp
#         with open(DATA_FILE, "ab+") as xml_file:
#             xml_file.write(ET.tostring(element, encoding="utf-8"))
#         commit_url = repo.push_to_hub()
#         app.logger.info(commit_url)
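# GET /change_context_size?size=N - reload the model with a new context window size.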
@app.route('/change_context_size', methods=['GET'])
def handler_change_context_size():
    global stop_generation, model
    stop_generation = True
    new_size = int(request.args.get('size', CONTEXT_SIZE))
    init_model(new_size, ENABLE_GPU, GPU_LAYERS)
    return Response('Size changed', content_type='text/plain')
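# GET /stop_generation - set a module-level flag asking for generation to stop.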
@app.route('/stop_generation', methods=['GET'])
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')
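# Catch-all for non-POST requests to the root endpoint: log the payload, if any, and reply with a short notice.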
@app.route('/', methods=['GET', 'PUT', 'DELETE', 'PATCH'])
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)
    try:
        request_payload = request.get_json()
        app.logger.info('payload: ' + json.dumps(request_payload, ensure_ascii=False))
    except Exception:
        app.logger.info('payload empty')
    return Response('What do you want?', content_type='text/plain')
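# Module-level buffer for the tokens of the current response; generate_and_log_tokens
# streams tokens from the backend and resets the buffer when the end-of-stream marker arrives.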
response_tokens = bytearray()
def generate_and_log_tokens(user_request, generator):
    global response_tokens, last_request_time
    for token in llm.generate_tokens(generator):
        if token == b'':  # or (max_new_tokens is not None and i >= max_new_tokens):
            last_request_time = datetime.now()
            # log(json.dumps(user_request), response_tokens.decode("utf-8", errors="ignore"))
            response_tokens = bytearray()
            break
        response_tokens.extend(token)
        yield token
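# Illustrative request body for the POST route below (field names follow the parsing
# code in generate_response; the exact message structure expected by
# LlmBackend.create_chat_generator_for_saiga is an assumption here):
# {
#   "messages": [{"role": "user", "content": "Привет!"}],
#   "preprompt": "",
#   "parameters": {"temperature": 0.01, "max_new_tokens": 1024, "top_p": 0.85}
# }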
@app.route('/', methods=['POST'])
def generate_response():
    app.logger.info('generate_response called')
    data = request.get_json()
    app.logger.info(data)
    messages = data.get("messages", [])
    preprompt = data.get("preprompt", "")
    parameters = data.get("parameters", {})

    # Extract parameters from the request
    p = {
        'temperature': parameters.get("temperature", 0.01),
        'truncate': parameters.get("truncate", 1000),
        'max_new_tokens': parameters.get("max_new_tokens", 1024),
        'top_p': parameters.get("top_p", 0.85),
        'repetition_penalty': parameters.get("repetition_penalty", 1.2),
        'top_k': parameters.get("top_k", 30),
        'return_full_text': parameters.get("return_full_text", False)
    }

    generator = llm.create_chat_generator_for_saiga(messages=messages, parameters=p, use_system_prompt=USE_SYSTEM_PROMPT)
    app.logger.info('Generator created')

    # Use Response to stream tokens
    return Response(generate_and_log_tokens(user_request='1', generator=generator), content_type='text/plain', status=200, direct_passthrough=True)
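# Load the model into the backend; arguments default to the values configured above.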
def init_model(context_size=CONTEXT_SIZE, enable_gpu=ENABLE_GPU, gpu_layers=GPU_LAYERS):
    llm.load_model(model_path=MODEL_PATH, context_size=context_size, enable_gpu=enable_gpu, gpu_layer_number=gpu_layers)
# Function to check if no requests were made in the last 5 minutes
def check_last_request_time():
    global last_request_time
    current_time = datetime.now()
    if (current_time - last_request_time).total_seconds() > 300:  # 5 minutes in seconds
        llm.unload_model()
        app.logger.info(f"Model unloaded at {current_time}")
    else:
        app.logger.info(f"No action needed at {current_time}")
if __name__ == "__main__":
    init_model()
    # scheduler = BackgroundScheduler()
    # scheduler.add_job(check_last_request_time, trigger='interval', minutes=1)
    # scheduler.start()
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)