Update app.py
Browse files
app.py
CHANGED
@@ -2,9 +2,12 @@ from flask import Flask, request, Response
|
|
2 |
import logging
|
3 |
from llama_cpp import Llama
|
4 |
import threading
|
5 |
-
from huggingface_hub import snapshot_download
|
|
|
6 |
import gc
|
7 |
import os.path
|
|
|
|
|
8 |
|
9 |
SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно и отвечаешь на запросы пользователя, используя русский язык."
|
10 |
SYSTEM_TOKEN = 1788
|
@@ -51,6 +54,29 @@ model = None
|
|
51 |
model_path = snapshot_download(repo_id=repo_name, allow_patterns=model_name) + '/' + model_name
|
52 |
app.logger.info('Model path: ' + model_path)
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def init_model(context_size, enable_gpu=False, gpu_layer_number=35):
|
55 |
global model
|
56 |
|
@@ -221,18 +247,8 @@ def generate_response():
|
|
221 |
top_k = parameters.get("top_k", 30)
|
222 |
return_full_text = parameters.get("return_full_text", False)
|
223 |
|
224 |
-
|
225 |
-
# Generate the response
|
226 |
-
#system_tokens = get_system_tokens(model)
|
227 |
-
#tokens = system_tokens
|
228 |
-
|
229 |
-
#if preprompt != "":
|
230 |
-
# tokens = get_system_tokens_for_preprompt(model, preprompt)
|
231 |
-
#else:
|
232 |
tokens = get_system_tokens(model)
|
233 |
-
tokens.append(LINEBREAK_TOKEN)
|
234 |
-
#model.eval(tokens)
|
235 |
-
|
236 |
|
237 |
tokens = []
|
238 |
|
@@ -243,22 +259,13 @@ def generate_response():
|
|
243 |
message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))
|
244 |
|
245 |
tokens.extend(message_tokens)
|
246 |
-
|
247 |
-
#app.logger.info('model.eval start')
|
248 |
-
#model.eval(tokens)
|
249 |
-
#app.logger.info('model.eval end')
|
250 |
-
|
251 |
-
#last_message = messages[-1]
|
252 |
-
#if last_message.get("from") == "assistant":
|
253 |
-
# last_message_tokens = get_message_tokens(model=model, role="bot", content=last_message.get("content", ""))
|
254 |
-
#else:
|
255 |
-
# last_message_tokens = get_message_tokens(model=model, role="user", content=last_message.get("content", ""))
|
256 |
|
257 |
tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
|
258 |
|
259 |
|
260 |
app.logger.info('Prompt:')
|
261 |
-
|
|
|
262 |
|
263 |
stop_generation = False
|
264 |
app.logger.info('Generate started')
|
@@ -271,8 +278,20 @@ def generate_response():
|
|
271 |
)
|
272 |
app.logger.info('Generator created')
|
273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
# Use Response to stream tokens
|
275 |
-
return Response(
|
|
|
|
|
276 |
|
277 |
if __name__ == "__main__":
|
278 |
app.run(host="0.0.0.0", port=7860, debug=False, threaded=False)
|
|
|
2 |
import logging
|
3 |
from llama_cpp import Llama
|
4 |
import threading
|
5 |
+
from huggingface_hub import snapshot_download, Repository
|
6 |
+
import huggingface_hub
|
7 |
import gc
|
8 |
import os.path
|
9 |
+
import csv
|
10 |
+
from datetime import datetime
|
11 |
|
12 |
SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно и отвечаешь на запросы пользователя, используя русский язык."
|
13 |
SYSTEM_TOKEN = 1788
|
|
|
54 |
model_path = snapshot_download(repo_id=repo_name, allow_patterns=model_name) + '/' + model_name
|
55 |
app.logger.info('Model path: ' + model_path)
|
56 |
|
57 |
+
DATASET_REPO_URL = "https://huggingface.co/datasets/muryshev/saiga-chat"
|
58 |
+
DATA_FILENAME = "data.csv"
|
59 |
+
DATA_FILE = os.path.join("data", DATA_FILENAME)
|
60 |
+
|
61 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
62 |
+
app.logger.info("HF_TOKEN is None?", HF_TOKEN is None)
|
63 |
+
app.logger.info("hfh", huggingface_hub.__version__)
|
64 |
+
|
65 |
+
repo = Repository(
|
66 |
+
local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
|
67 |
+
)
|
68 |
+
|
69 |
+
def log(request: str = '', response: str = ''):
|
70 |
+
if request or response:
|
71 |
+
with open(DATA_FILE, "a") as csvfile:
|
72 |
+
writer = csv.DictWriter(csvfile, fieldnames=["request", "response", "time"])
|
73 |
+
writer.writerow(
|
74 |
+
{"request": request, "response": response, "time": str(datetime.now())}
|
75 |
+
)
|
76 |
+
commit_url = repo.push_to_hub()
|
77 |
+
app.logger.info(commit_url)
|
78 |
+
|
79 |
+
|
80 |
def init_model(context_size, enable_gpu=False, gpu_layer_number=35):
|
81 |
global model
|
82 |
|
|
|
247 |
top_k = parameters.get("top_k", 30)
|
248 |
return_full_text = parameters.get("return_full_text", False)
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
tokens = get_system_tokens(model)
|
251 |
+
tokens.append(LINEBREAK_TOKEN)
|
|
|
|
|
252 |
|
253 |
tokens = []
|
254 |
|
|
|
259 |
message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))
|
260 |
|
261 |
tokens.extend(message_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
|
264 |
|
265 |
|
266 |
app.logger.info('Prompt:')
|
267 |
+
request = model.detokenize(tokens[:CONTEXT_SIZE]).decode("utf-8", errors="ignore")
|
268 |
+
app.logger.info(request)
|
269 |
|
270 |
stop_generation = False
|
271 |
app.logger.info('Generate started')
|
|
|
278 |
)
|
279 |
app.logger.info('Generator created')
|
280 |
|
281 |
+
|
282 |
+
response_tokens = []
|
283 |
+
def generate_and_log_tokens(model, generator):
|
284 |
+
for token in generate_tokens(model, generator):
|
285 |
+
if token == model.token_eos(): # or (max_new_tokens is not None and i >= max_new_tokens):
|
286 |
+
log(request=request, response=model.detokenize(response_tokens).decode("utf-8", errors="ignore"))
|
287 |
+
break
|
288 |
+
response_tokens.append(token)
|
289 |
+
yield token
|
290 |
+
|
291 |
# Use Response to stream tokens
|
292 |
+
return Response(generate_and_log_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
|
293 |
+
|
294 |
+
|
295 |
|
296 |
if __name__ == "__main__":
|
297 |
app.run(host="0.0.0.0", port=7860, debug=False, threaded=False)
|