import logging
from re import split, sub
from threading import Lock
from time import sleep
from typing import Tuple, Dict

try:
    from extensions.telegram_bot.source.generators.abstract_generator import AbstractGenerator
except ImportError:
    from source.generators.abstract_generator import AbstractGenerator
try:
    import extensions.telegram_bot.source.const as const
    import extensions.telegram_bot.source.utils as utils
    import extensions.telegram_bot.source.generator as generator
    from extensions.telegram_bot.source.user import User as User
    from extensions.telegram_bot.source.conf import cfg
except ImportError:
    import source.const as const
    import source.utils as utils
    import source.generator as generator
    from source.user import User as User
    from source.conf import cfg

# Define generator lock to prevent GPU overloading
generator_lock = Lock()

# Print prompt/answer to stdout for debugging
debug_flag = True
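
# This module turns raw user input into a prompt, runs the generator and
# post-processes the result: preprocess command prefixes -> build the prompt
# (context + example + greeting + truncated history) -> generate -> strip
# stopping strings and prompt markers -> update the user's history.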
# ====================================================================================
# TEXT LOGIC
async def aget_answer(text_in: str, user: User, bot_mode: str, generation_params: Dict, name_in="") -> Tuple[str, str]:
    # async wrapper: get_answer is synchronous, so it must not be awaited
    return get_answer(text_in, user, bot_mode, generation_params, name_in)
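

# Returns a (text, msg_action) pair, where msg_action is one of the const.MSG_*
# values (MSG_SEND, MSG_SYSTEM, MSG_DEL_LAST, MSG_SD_API, MSG_NOTHING_TO_DO)
# telling the caller how to deliver the text.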
def get_answer(text_in: str, user: User, bot_mode: str, generation_params: Dict, name_in="") -> Tuple[str, str]:
    # additional delay option
    if cfg.answer_delay > 0:
        sleep(cfg.answer_delay)
    # if generation fails, the "fail" answer is returned
    answer = const.GENERATOR_FAIL
    # default result action - send a message
    return_msg_action = const.MSG_SEND
    # if name_in is empty, default to user.name1
    name_in = user.name1 if name_in == "" else name_in
    # for regeneration result checking (currently unused)
    previous_result = ""
    # acquire generator lock if we can
    generator_lock.acquire(timeout=cfg.generation_timeout)
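    # Note: acquire() returns False if the timeout expires, but the result is
    # not checked here, so generation proceeds even without holding the lock.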
    # user_input preprocessing
    try:
        # Preprocessing: actions which return a result immediately:
        if text_in.startswith(tuple(cfg.permanent_change_name2_prefixes)):
            # If user_in starts with perm_prefix - just replace name2
            user.name2 = text_in[2:]
            return_msg_action = const.MSG_SYSTEM
            generator_lock.release()
            return "New bot name: " + user.name2, return_msg_action
        if text_in.startswith(tuple(cfg.permanent_change_name1_prefixes)):
            # If user_in starts with perm_prefix - just replace name1
            user.name1 = text_in[2:]
            return_msg_action = const.MSG_SYSTEM
            generator_lock.release()
            return "New user name: " + user.name1, return_msg_action
        if text_in.startswith(tuple(cfg.permanent_add_context_prefixes)):
            # If user_in starts with perm_prefix - append the text to the context
            user.context += "\n" + text_in[2:]
            return_msg_action = const.MSG_SYSTEM
            generator_lock.release()
            return "Added to context: " + text_in[2:], return_msg_action
        if text_in.startswith(tuple(cfg.replace_prefixes)):
            # If user_in starts with replace_prefix - fully replace last message
            user.change_last_message(history_out=text_in[1:])
            return_msg_action = const.MSG_DEL_LAST
            generator_lock.release()
            return user.history_last_out, return_msg_action
        if text_in == const.GENERATOR_MODE_DEL_WORD:
            # Delete the tail of the last bot message:
            # split on sentence boundaries and cut the last fragment
            # (e.g. "Hi there. How are" -> "Hi there.")
            last_message = user.history_last_out
            last_word = split(r"\n|\.+ +|: +|! +|\? +|\' +|\" +|; +|\) +|\* +", last_message)[-1]
            if len(last_word) == 0 and len(last_message) > 1:
                last_word = " "
            new_last_message = last_message[: -(len(last_word))]
            new_last_message = new_last_message.strip()
            if len(new_last_message) == 0:
                return_msg_action = const.MSG_NOTHING_TO_DO
            else:
                user.change_last_message(history_out=new_last_message)
            generator_lock.release()
            return user.history_last_out, return_msg_action
        # Preprocessing: actions which do not depend on user input:
        if bot_mode in [const.MODE_QUERY]:
            user.history = []
        # if regenerate - msg_id, text and name stay the same, but the last history entry is dropped:
        if text_in == const.GENERATOR_MODE_REGENERATE:
            if str(user.msg_id[-1]) not in user.previous_history:
                user.previous_history.update({str(user.msg_id[-1]): []})
            user.previous_history[str(user.msg_id[-1])].append(user.history_last_out)
            text_in = user.text_in[-1]
            name_in = user.name_in[-1]
            last_msg_id = user.msg_id[-1]
            user.truncate_last_message()
            user.msg_id.append(last_msg_id)
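        # previous_history keeps every answer generated for this msg_id; it is
        # checked after generation to detect an unchanged regeneration result.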
        # Preprocessing: add user_in/names/whitespaces to history in the right order, depending on mode:
        if bot_mode in [const.MODE_NOTEBOOK]:
            # If notebook mode - append only user_in to history, no additional preparation
            user.text_in.append(text_in)
            user.history_append("", text_in)
        elif text_in == const.GENERATOR_MODE_IMPERSONATE:
            # if impersonate - append an empty "" history line to prevent bugs in the
            # history sequence, and add a "name1:" prefix for generation
            user.text_in.append(text_in)
            user.name_in.append(name_in)
            user.history_append("", name_in + ":")
        elif text_in == const.GENERATOR_MODE_NEXT:
            # if "next" - no user text, continue generation: append an empty "" history
            # line to prevent bugs in the history sequence, and a "name2:" prefix for generation
            user.text_in.append(text_in)
            user.name_in.append(name_in)
            user.history_append("", user.name2 + ":")
        elif text_in == const.GENERATOR_MODE_CONTINUE:
            # if continue - keep history as-is and let the model continue the last message
            pass
        elif text_in.startswith(tuple(cfg.sd_api_prefixes)):
            # If user_in starts with an SD API prefix - request an image prompt:
            # append an empty "" line to prevent bugs in the history sequence; the text
            # after the prefix becomes the OBJECT of the picture
            user.msg_id.append(0)
            user.text_in.append(text_in)
            user.name_in.append(name_in)
            if len(text_in) == 1:
                user.history_append("", cfg.sd_api_prompt_self)
            else:
                user.history_append("", cfg.sd_api_prompt_of.replace("OBJECT", text_in[1:].strip()))
            return_msg_action = const.MSG_SD_API
        elif text_in.startswith(tuple(cfg.impersonate_prefixes)):
            # If user_in starts with an impersonate prefix - one-time impersonation:
            # append an empty "" line to prevent bugs in the history sequence; the text
            # after the prefix is the name to speak as
            user.text_in.append(text_in)
            user.name_in.append(text_in[1:])
            user.history_append("", text_in[1:] + ":")
        else:
            # If not notebook/impersonate/continue mode - ordinary chat preparation:
            # add "name1:"/"name2:" to the user and bot messages (generation from name2's point of view)
            user.text_in.append(text_in)
            user.name_in.append(name_in)
            user.history_append(name_in + ": " + text_in, user.name2 + ":")
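        # At this point each history entry is a dict like
        # {"in": "John: hello", "out": "Alice:"} (chat mode; names illustrative),
        # and the model will be asked to continue the "out" side.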
    except Exception as exception:
        generator_lock.release()
        logging.error("get_answer (prepare text part) " + str(exception) + str(exception.args))
    # Text processing with LLM
    try:
        # Set eos_token and stopping_strings.
        stopping_strings = generation_params["stopping_strings"].copy()
        eos_token = generation_params["eos_token"]
        if bot_mode in [const.MODE_CHAT, const.MODE_CHAT_R, const.MODE_ADMIN]:
            stopping_strings += [
                name_in + ":",
                user.name1 + ":",
                user.name2 + ":",
            ]
        if cfg.bot_prompt_end != "":
            stopping_strings.append(cfg.bot_prompt_end)
        # adjust context/greeting/example; the context always ends with exactly one newline
        context = f"{user.context.strip()}\n"
        context = cfg.context_prompt_begin + context + cfg.context_prompt_end
        example = user.example if len(user.example) > 0 else ""
        if len(user.greeting) > 0:
            greeting = "\n" + user.name2 + ": " + user.greeting
        else:
            greeting = ""
        # Make prompt: context + example + conversation history
        available_len = generation_params["truncation_length"]
        context_len = generator.get_tokens_count(context)
        available_len -= context_len
        if available_len < 0:
            available_len = 0
            logging.info("telegram_bot - CONTEXT IS TOO LONG!!!")
        conversation = [example, greeting]
        for i in user.history:
            if len(i["in"]) > 0:
                conversation.append("".join([cfg.user_prompt_begin, i["in"], cfg.user_prompt_end]))
            if len(i["out"]) > 0:
                conversation.append("".join([cfg.bot_prompt_begin, i["out"], cfg.bot_prompt_end]))
        # the model must continue the last line, so cut its bot_prompt_end marker
        if len(cfg.bot_prompt_end) and conversation[-1].endswith(cfg.bot_prompt_end):
            conversation[-1] = conversation[-1][: -len(cfg.bot_prompt_end)]
prompt = "" | |
for s in reversed(conversation): | |
s = "\n" + s if len(s) > 0 else s | |
s_len = generator.get_tokens_count(s) | |
if available_len >= s_len: | |
prompt = s + prompt | |
available_len -= s_len | |
else: | |
break | |
prompt = context + prompt | |
prompt = sub( | |
r": +", | |
": ", | |
prompt, | |
) | |
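        # The prompt now looks like (names illustrative):
        #   <context_prompt_begin><context><context_prompt_end>\n<example>\n<greeting>
        #   \n<user_prompt_begin>John: hi<user_prompt_end>\n<bot_prompt_begin>Alice:
        # with the oldest history dropped first when the token budget runs out.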
        # Generate!
        if debug_flag:
            print(prompt)
        answer = generator.generate_answer(
            prompt=prompt,
            generation_params=generation_params,
            eos_token=eos_token,
            stopping_strings=stopping_strings,
            default_answer=answer,
            turn_template=user.turn_template,
        )
        if debug_flag:
            print(answer)
        # Truncate prompt prefix/postfix
        if len(cfg.bot_prompt_end) > 0 and answer.endswith(cfg.bot_prompt_end):
            answer = answer[: -len(cfg.bot_prompt_end)]
        if len(cfg.bot_prompt_end) > 2 and answer.endswith(cfg.bot_prompt_end[:-1]):
            # the generator may emit the end marker with its last character cut off
            answer = answer[: -len(cfg.bot_prompt_end[:-1])]
        if len(cfg.bot_prompt_begin) > 0 and answer.startswith(cfg.bot_prompt_begin):
            # strip the begin marker from the start (not the end) of the answer
            answer = answer[len(cfg.bot_prompt_begin):]
        # If the generation result has zero length - return "Empty answer."
        if len(answer) < 1:
            answer = const.GENERATOR_EMPTY_ANSWER
        # Final return
        if answer not in [const.GENERATOR_EMPTY_ANSWER, const.GENERATOR_FAIL]:
            # if everything is ok - cut any trailing stopping string,
            # add the generated answer to history and return it
            for end in stopping_strings:
                if answer.endswith(end):
                    answer = answer[: -len(end)]
            user.change_last_message(history_out=user.history_last_out + " " + answer)
            generator_lock.release()
            # if regeneration produced the same text as before - nothing to update
            if len(user.msg_id) > 0:
                if str(user.msg_id[-1]) in user.previous_history:
                    if user.previous_history[str(user.msg_id[-1])][-1] == user.history_last_out:
                        return_msg_action = const.MSG_NOTHING_TO_DO
            return user.history_last_out, return_msg_action
    except Exception as exception:
        logging.error("get_answer (generator part) " + str(exception) + str(exception.args))
    # on failure or empty answer, release the generator lock (if still held) and return
    if generator_lock.locked():
        generator_lock.release()
    return_msg_action = const.MSG_SYSTEM
    return user.history_last_out, return_msg_action
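

# Usage sketch (illustrative; `my_user` is a prepared User, and the params dict
# shows only the keys this function itself reads - the real generation_params
# may carry more for the generator backend):
#   params = {"stopping_strings": [], "eos_token": None, "truncation_length": 2048}
#   answer, action = get_answer("hello", my_user, const.MODE_CHAT, params)
#   if action == const.MSG_SEND:
#       ...  # deliver `answer` to the chat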