import ast
import copy
import functools
import json
import os
import tempfile
import time
import traceback
import uuid

import filelock

from enums import LangChainMode, LangChainAction, no_model_str, LangChainTypes, langchain_modes_intrinsic, \
    DocumentSubset, unknown_prompt_type, my_db_state0, selection_docs_state0, requests_state0, roles_state0, noneset, \
    images_num_max_dict, image_batch_image_prompt0, image_batch_final_prompt0, images_limit_max_new_tokens, \
    images_limit_max_new_tokens_list
from model_utils import model_lock_to_state
from tts_utils import combine_audios
from utils import _save_generate_tokens, clear_torch_cache, remove, save_generate_output, str_to_list, \
    get_accordion_named, check_input_type, download_image, deepcopy_by_pickle_object
from db_utils import length_db1
from evaluate_params import input_args_list, eval_func_param_names, key_overrides, in_model_state_and_evaluate
from vision.utils_vision import process_file_list


def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, plain_api=False, verifier=False, kwargs={},
                    my_db_state1=None,
                    selection_docs_state1=None,
                    requests_state1=None,
                    roles_state1=None,
                    model_states=[],
                    **kwargs1):
    is_public = kwargs1.get('is_public', False)
    verbose = kwargs1.get('verbose', False)

    # fall back to per-user default states when none were passed in
    if my_db_state1 is None:
        if 'my_db_state0' in kwargs1 and kwargs1['my_db_state0'] is not None:
            my_db_state1 = kwargs1['my_db_state0']
        else:
            my_db_state1 = copy.deepcopy(my_db_state0)
    if selection_docs_state1 is None:
        if 'selection_docs_state0' in kwargs1 and kwargs1['selection_docs_state0'] is not None:
            selection_docs_state1 = kwargs1['selection_docs_state0']
        else:
            selection_docs_state1 = copy.deepcopy(selection_docs_state0)
    if requests_state1 is None:
        if 'requests_state0' in kwargs1 and kwargs1['requests_state0'] is not None:
            requests_state1 = kwargs1['requests_state0']
        else:
            requests_state1 = copy.deepcopy(requests_state0)
    if roles_state1 is None:
        if 'roles_state0' in kwargs1 and kwargs1['roles_state0'] is not None:
            roles_state1 = kwargs1['roles_state0']
        else:
            roles_state1 = copy.deepcopy(roles_state0)
    kwargs_eval_pop_keys = ['selection_docs_state0', 'requests_state0', 'roles_state0']
    for k in kwargs_eval_pop_keys:
        if k in kwargs1:
            kwargs1.pop(k)

    args_list = list(args1)
    if str_api:
        if plain_api:
            # plain API: client passes no states, so inject fresh copies
            if not verifier:
                args_list.insert(0, kwargs['model_state_none'].copy())
            else:
                args_list.insert(0, kwargs['verifier_model_state0'].copy())
            args_list.insert(1, my_db_state1.copy())
            args_list.insert(2, selection_docs_state1.copy())
            args_list.insert(3, requests_state1.copy())
            args_list.insert(4, roles_state1.copy())
        user_kwargs = args_list[len(input_args_list)]
        assert isinstance(user_kwargs, str)
        user_kwargs = ast.literal_eval(user_kwargs)
    else:
        assert not plain_api
        user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[len(input_args_list):])}

    # API-only display defaults not normally exposed to the user
    if 'answer_with_sources' not in user_kwargs:
        kwargs1['answer_with_sources'] = -1
    if 'sources_show_text_in_accordion' not in user_kwargs:
        kwargs1['sources_show_text_in_accordion'] = False
    if 'append_sources_to_chat' not in user_kwargs:
        kwargs1['append_sources_to_chat'] = False
    if 'append_sources_to_answer' not in user_kwargs:
        kwargs1['append_sources_to_answer'] = False
    if 'show_link_in_sources' not in user_kwargs:
        kwargs1['show_link_in_sources'] = False

    # the nochat API never runs in chat mode
    user_kwargs['chat'] = False
    if 'stream_output' not in user_kwargs:
        user_kwargs['stream_output'] = False
    if plain_api:
        user_kwargs['stream_output'] = False
    if 'langchain_mode' not in user_kwargs:
        # prefer plain LLM mode, else first available mode
        if LangChainMode.LLM.value in kwargs['langchain_modes']:
            user_kwargs['langchain_mode'] = LangChainMode.LLM.value
        elif len(kwargs['langchain_modes']) >= 1:
            user_kwargs['langchain_mode'] = kwargs['langchain_modes'][0]
        else:
            # disabled should not normally happen, but fall back gracefully
            user_kwargs['langchain_mode'] = LangChainMode.DISABLED.value
    if 'langchain_action' not in user_kwargs:
        user_kwargs['langchain_action'] = LangChainAction.QUERY.value
    if 'langchain_agents' not in user_kwargs:
        user_kwargs['langchain_agents'] = []
    # be flexible about chat vs. nochat instruction naming
    if 'instruction' in user_kwargs and 'instruction_nochat' not in user_kwargs:
        user_kwargs['instruction_nochat'] = user_kwargs['instruction']
    if 'iinput' in user_kwargs and 'iinput_nochat' not in user_kwargs:
        user_kwargs['iinput_nochat'] = user_kwargs['iinput']
    if 'visible_models' not in user_kwargs:
        if kwargs['visible_models']:
            if isinstance(kwargs['visible_models'], int):
                user_kwargs['visible_models'] = [kwargs['visible_models']]
            elif isinstance(kwargs['visible_models'], list):
                # only take the first one
                user_kwargs['visible_models'] = [kwargs['visible_models'][0]]
            else:
                user_kwargs['visible_models'] = [0]
        else:
            # if no restriction, just take the first model
            user_kwargs['visible_models'] = [0]
    if 'visible_vision_models' not in user_kwargs or user_kwargs['visible_vision_models'] is None:
        user_kwargs['visible_vision_models'] = 'auto'

    if 'h2ogpt_key' not in user_kwargs:
        user_kwargs['h2ogpt_key'] = None
    if 'system_prompt' in user_kwargs and user_kwargs['system_prompt'] is None:
        # use string 'None' so the None-check against default_kwargs below doesn't override it
        user_kwargs['system_prompt'] = 'None'
    # by default, do not require TTS
    if 'chatbot_role' not in user_kwargs:
        user_kwargs['chatbot_role'] = 'None'
    if 'speaker' not in user_kwargs:
        user_kwargs['speaker'] = 'None'

    set1 = set(list(default_kwargs1.keys()))
    set2 = set(eval_func_param_names)
    assert set1 == set2, "Set diff: %s %s: %s" % (set1, set2, set1.symmetric_difference(set2))

    model_state1 = args_list[0]
    my_db_state1 = args_list[1]
    selection_docs_state1 = args_list[2]
    requests_state1 = args_list[3]
    roles_state1 = args_list[4]

    # fill missing parameters from defaults, in canonical eval_func_param_names order
    args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
                 in eval_func_param_names]
    assert len(args_list) == len(eval_func_param_names)

    # choose model state: explicit model_lock from client wins, else visible_models choice
    model_lock_client = args_list[eval_func_param_names.index('model_lock')]
    if model_lock_client:
        model_state1 = model_lock_to_state(model_lock_client, cache_model_state=True, **kwargs)
    elif len(model_states) >= 1:
        visible_models1 = args_list[eval_func_param_names.index('visible_models')]
        model_active_choice1 = visible_models_to_model_choice(visible_models1, model_states, api=True)
        model_state1 = model_states[model_active_choice1 % len(model_states)]

    for key in key_overrides:
        if user_kwargs.get(key) is None and model_state1.get(key) is not None:
            args_list[eval_func_param_names.index(key)] = model_state1[key]
    if isinstance(model_state1, dict) and \
            'tokenizer' in model_state1 and \
            hasattr(model_state1['tokenizer'], 'model_max_length'):
        # ensure requested max_new_tokens fits the model's context window
        buffer = 0
        args_list[eval_func_param_names.index('max_new_tokens')] = min(
            args_list[eval_func_param_names.index('max_new_tokens')],
            model_state1['tokenizer'].model_max_length - buffer)

    # override some arguments from the chosen model state
    if 'visible_models' in model_state1 and model_state1['visible_models'] is not None:
        assert isinstance(model_state1['visible_models'], (int, str, list, tuple))
        which_model = visible_models_to_model_choice(model_state1['visible_models'], model_states)
        args_list[eval_func_param_names.index('visible_models')] = which_model
    if 'visible_vision_models' in model_state1 and model_state1['visible_vision_models'] is not None:
        assert isinstance(model_state1['visible_vision_models'], (int, str, list, tuple))
        which_model = visible_models_to_model_choice(model_state1['visible_vision_models'], model_states)
        args_list[eval_func_param_names.index('visible_vision_models')] = which_model
    if 'h2ogpt_key' in model_state1 and model_state1['h2ogpt_key'] is not None:
        # use key from model state if one is set there
        assert isinstance(model_state1['h2ogpt_key'], str)
        args_list[eval_func_param_names.index('h2ogpt_key')] = model_state1['h2ogpt_key']

    # keep instruction and instruction_nochat in sync
    instruction_nochat1 = args_list[eval_func_param_names.index('instruction_nochat')] or \
                          args_list[eval_func_param_names.index('instruction')]
    args_list[eval_func_param_names.index('instruction_nochat')] = \
        args_list[eval_func_param_names.index('instruction')] = \
        instruction_nochat1
    history = [[instruction_nochat1, None]]

    requests_state1_bot = None
    args_list_bot = args_list + [model_state1, my_db_state1, selection_docs_state1, requests_state1_bot,
                                 roles_state1] + [history]

    history, fun1, langchain_mode1, db1, requests_state1, \
        valid_key, h2ogpt_key1, \
        max_time1, stream_output1, \
        chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1, langchain_action1, \
        image_files_to_delete = \
        prep_bot(*args_list_bot, kwargs_eval=kwargs1, plain_api=plain_api, kwargs=kwargs, verbose=verbose)

    save_dict = dict()
    ret = {'error': "No response", 'sources': [], 'sources_str': '', 'prompt_raw': instruction_nochat1,
           'llm_answers': []}
    ret_old = ''
    history_str_old = ''
    error_old = ''
    audios = []  # audio is always streamed, so accumulate until next yield
    last_yield = None
    res_dict = {}
    try:
        tgen0 = time.time()
        for res in get_response(fun1, history, chatbot_role1, speaker1, tts_language1, roles_state1,
                                tts_speed1,
                                langchain_action1,
                                langchain_mode1,
                                kwargs=kwargs,
                                api=True,
                                verbose=verbose):
            history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, audio1 = res
            res_dict = {}
            res_dict['response'] = history[-1][1] or ''
            res_dict['error'] = error
            res_dict['sources'] = sources
            res_dict['sources_str'] = sources_str
            res_dict['prompt_raw'] = prompt_raw
            res_dict['llm_answers'] = llm_answers
            res_dict['save_dict'] = save_dict
            res_dict['audio'] = audio1

            error = res_dict.get('error', '')
            sources = res_dict.get('sources', [])
            save_dict = res_dict.get('save_dict', {})

            # update save_dict with generation metadata
            save_dict['error'] = error
            save_dict['sources'] = sources
            save_dict['valid_key'] = valid_key
            save_dict['h2ogpt_key'] = h2ogpt_key1

            # record whether images / explicit text context were passed in
            image_file_check = args_list[eval_func_param_names.index('image_file')]
            save_dict['image_file_present'] = len(image_file_check) if \
                isinstance(image_file_check, (str, list, tuple)) else 0
            text_context_list_check = args_list[eval_func_param_names.index('text_context_list')]
            save_dict['text_context_list_present'] = len(text_context_list_check) if \
                isinstance(text_context_list_check, (list, tuple)) else 0

            if str_api and plain_api:
                save_dict['which_api'] = 'str_plain_api'
            elif str_api:
                save_dict['which_api'] = 'str_api'
            elif plain_api:
                save_dict['which_api'] = 'plain_api'
            else:
                save_dict['which_api'] = 'nochat_api'
            if 'extra_dict' not in save_dict:
                save_dict['extra_dict'] = {}
            if requests_state1:
                save_dict['extra_dict'].update(requests_state1)
            else:
                save_dict['extra_dict'].update(dict(username='NO_REQUEST'))

            if is_public:
                # never leak inference server details to public users
                if 'save_dict' in res_dict and isinstance(res_dict['save_dict'], dict):
                    res_dict['save_dict'].pop('inference_server', None)
                    if 'extra_dict' in res_dict['save_dict'] and isinstance(res_dict['save_dict']['extra_dict'],
                                                                            dict):
                        res_dict['save_dict']['extra_dict'].pop('inference_server', None)

            if str_api:
                # mid-stream, only yield streaming keys; full payload comes at the end
                res_dict_yield = res_dict.copy()
                only_stream = ['response', 'llm_answers', 'audio']
                for key in res_dict:
                    if key not in only_stream:
                        if isinstance(res_dict[key], str):
                            res_dict_yield[key] = ''
                        elif isinstance(res_dict[key], list):
                            res_dict_yield[key] = []
                        elif isinstance(res_dict[key], dict):
                            res_dict_yield[key] = {}
                        else:
                            print("Unhandled pop: %s" % key)
                            res_dict_yield.pop(key)
                ret = res_dict_yield
            elif kwargs['langchain_mode'] == 'Disabled':
                ret = fix_text_for_gradio(res_dict['response'], fix_latex_dollars=False,
                                          fix_angle_brackets=False)
            else:
                ret = '<br>' + fix_text_for_gradio(res_dict['response'], fix_latex_dollars=False,
                                                   fix_angle_brackets=False)

            # optionally rate-limit yields like the UI stream does
            do_yield = False
            could_yield = ret != ret_old
            if kwargs['gradio_api_use_same_stream_limits']:
                history_str = str(ret['response'] if isinstance(ret, dict) else ret)
                delta_history = abs(len(history_str) - len(history_str_old))
                # yield if enough new data accumulated and the minimum interval passed
                enough_data = delta_history > kwargs['gradio_ui_stream_chunk_size'] or (error != error_old)
                beyond_min_time = last_yield is None or \
                                  (time.time() - last_yield) > kwargs['gradio_ui_stream_chunk_min_seconds']
                do_yield |= enough_data and beyond_min_time
                # yield anyway if enough time passed and anything changed
                enough_time = last_yield is None or \
                              (time.time() - last_yield) > kwargs['gradio_ui_stream_chunk_seconds']
                do_yield |= enough_time and could_yield
            else:
                do_yield = could_yield

            if stream_output1 and do_yield:
                last_yield = time.time()
                if isinstance(ret, dict):
                    ret_old = ret.copy()
                    ret['audio'] = combine_audios(audios, audio=audio1, sr=24000 if chatbot_role1 else 16000,
                                                  expect_bytes=kwargs['return_as_byte'], verbose=verbose)
                    audios = []  # reset accumulation
                    yield ret
                else:
                    ret_old = ret
                    yield ret
                history_str_old = str(ret_old['response'] if isinstance(ret_old, dict) else ret_old)
            else:
                # defer audio until next yield
                audios.append(res_dict['audio'])
            if time.time() - tgen0 > max_time1 + 10:
                msg = "Took too long evaluate_nochat: %s" % (time.time() - tgen0)
                if str_api:
                    res_dict['save_dict']['extra_dict']['timeout'] = time.time() - tgen0
                    res_dict['save_dict']['error'] = msg
                if verbose:
                    print(msg, flush=True)
                break

        # final yield with complete metadata (in case no streaming step occurred)
        if str_api:
            res_dict['save_dict']['extra_dict'] = _save_generate_tokens(res_dict.get('response', ''),
                                                                        res_dict.get('save_dict', {}).get(
                                                                            'extra_dict', {}))
            ret = res_dict.copy()
        if isinstance(ret, dict):
            ret['audio'] = combine_audios(audios, audio=None,
                                          expect_bytes=kwargs['return_as_byte'])
        yield ret

    except Exception as e:
        ex = traceback.format_exc()
        if verbose:
            print("Error in evaluate_nochat: %s" % ex, flush=True)
        if str_api:
            ret = {'error': str(e), 'error_ex': str(ex), 'sources': [], 'sources_str': '', 'prompt_raw': '',
                   'llm_answers': []}
            yield ret
        raise
    finally:
        clear_torch_cache(allow_skip=True)
        db1s = my_db_state1
        clear_embeddings(user_kwargs['langchain_mode'], kwargs['db_type'], db1s, kwargs['dbs'])
        for image_file1 in image_files_to_delete:
            if image_file1 and os.path.isfile(image_file1):
                remove(image_file1)
        save_dict['save_dir'] = kwargs['save_dir']
        save_generate_output(**save_dict)


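# Illustrative sketch only (not executed; endpoint name and client usage are
# assumptions for illustration, assuming a running gradio server wired to
# evaluate_nochat with str_api=True): the client sends a single string that
# ast.literal_eval() parses into evaluate() parameters, and any key omitted
# falls back to default_kwargs1 in eval_func_param_names order, e.g.:
#
#   from gradio_client import Client
#   client = Client('http://localhost:7860')
#   payload = str(dict(instruction_nochat='Who are you?', stream_output=False))
#   res = client.predict(payload, api_name='/submit_nochat_api')
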
def visible_models_to_model_choice(visible_models1, model_states1, api=False):
    if isinstance(visible_models1, list):
        assert len(visible_models1) >= 1, \
            "Invalid visible_models1=%s, can only be single entry" % visible_models1
        # just take the first entry
        model_active_choice1 = visible_models1[0]
    elif isinstance(visible_models1, (str, int)):
        model_active_choice1 = visible_models1
    else:
        assert isinstance(visible_models1, type(None)), "Invalid visible_models1=%s" % visible_models1
        model_active_choice1 = visible_models1
    if model_active_choice1 is not None:
        if isinstance(model_active_choice1, str):
            # map display name to index
            display_model_list = [x['display_name'] for x in model_states1]
            if model_active_choice1 in display_model_list:
                model_active_choice1 = display_model_list.index(model_active_choice1)
            else:
                # for API, raise; for UI, silently fall back to the first model
                if api:
                    raise ValueError(
                        "Invalid model %s, valid models are: %s" % (model_active_choice1, display_model_list))
                model_active_choice1 = 0
    else:
        model_active_choice1 = 0
    return model_active_choice1


def clear_embeddings(langchain_mode1, db_type, db1s, dbs=None):
    # clear any embedding dbs from GPU
    if db_type in ['chroma', 'chroma_old'] and langchain_mode1 not in ['LLM', 'Disabled', None, '']:
        from gpt_langchain import clear_embedding, length_db1
        if dbs is not None:
            db = dbs.get(langchain_mode1)
            if db is not None and not isinstance(db, str):
                clear_embedding(db)
        if db1s is not None and langchain_mode1 in db1s:
            db1 = db1s[langchain_mode1]
            if len(db1) == length_db1():
                clear_embedding(db1[0])


def fix_text_for_gradio(text, fix_new_lines=False, fix_latex_dollars=True, fix_angle_brackets=True):
    if isinstance(text, tuple):
        # images, audio, etc. pass through unchanged
        return text

    if not isinstance(text, str):
        # e.g. list from extraction
        text = str(text)

    if fix_latex_dollars:
        # replace $ outside code fences so gradio doesn't treat it as LaTeX
        ts = text.split('```')
        for parti, part in enumerate(ts):
            inside = parti % 2 == 1
            if not inside:
                ts[parti] = ts[parti].replace('$', '﹩')
        text = '```'.join(ts)

    if fix_new_lines:
        # convert newlines to <br> outside code fences
        ts = text.split('```')
        for parti, part in enumerate(ts):
            inside = parti % 2 == 1
            if not inside:
                ts[parti] = ts[parti].replace('\n', '<br>')
        text = '```'.join(ts)
    if fix_angle_brackets:
        # escape angle brackets outside code fences, but keep known-safe markup intact
        ts = text.split('```')
        for parti, part in enumerate(ts):
            inside = parti % 2 == 1
            if not inside:
                if '<a href' not in ts[parti] and \
                        '<img src=' not in ts[parti] and \
                        '<div ' not in ts[parti] and \
                        '</div>' not in ts[parti] and \
                        '<details><summary>' not in ts[parti]:
                    ts[parti] = ts[parti].replace('<', '\\<').replace('>', '\\>')
        text = '```'.join(ts)
    return text


def get_images_num_max(model_choice, fun_args, visible_vision_models, do_batching, cli_images_num_max):
    # precedence: CLI value < model state < API argument
    images_num_max1 = None
    if cli_images_num_max is not None:
        images_num_max1 = cli_images_num_max
    if model_choice['images_num_max'] is not None:
        images_num_max1 = model_choice['images_num_max']
    images_num_max_api = fun_args[len(input_args_list) + eval_func_param_names.index('images_num_max')]
    if images_num_max_api is not None:
        images_num_max1 = images_num_max_api
    if isinstance(images_num_max1, float):
        images_num_max1 = int(images_num_max1)
    if images_num_max1 is None:
        images_num_max1 = images_num_max_dict.get(visible_vision_models)
    if images_num_max1 == -1:
        # -1 means no explicit limit: force batching, resolve cap below
        do_batching = True
        images_num_max1 = None
    elif images_num_max1 is not None and images_num_max1 < -1:
        # values below -1 encode a per-batch cap of (-value - 1)
        do_batching = True
        images_num_max1 = -images_num_max1 - 1

    # resolve remaining unknowns from the model state or the global dict
    if images_num_max1 is None or images_num_max1 <= -1:
        images_num_max1 = model_choice.get('images_num_max', images_num_max1)
    if images_num_max1 is None or images_num_max1 <= -1:
        # if no value set, let batching limit if vision model, else allow no images
        if model_choice.get('is_actually_vision_model'):
            images_num_max1 = images_num_max_dict.get(visible_vision_models, 1)
            if images_num_max1 == -1:
                images_num_max1 = 1
        else:
            images_num_max1 = images_num_max_dict.get(visible_vision_models, 0)
            if images_num_max1 == -1:
                images_num_max1 = 0
        if images_num_max1 < -1:
            images_num_max1 = -images_num_max1 - 1
            do_batching = True

    assert images_num_max1 != -1, "Should not be -1 here"

    if images_num_max1 is None:
        # no limit found anywhere: assume model takes no images
        images_num_max1 = 0

    return images_num_max1, do_batching


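# Worked example of the encoding above (a sketch, values assumed): with
# images_num_max=-1 the limit is deferred and batching is forced; with
# images_num_max=-3 batching is forced with a per-batch cap of -(-3) - 1 = 2
# images; plain positive values cap image count without forcing batching.
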
def get_response(fun1, history, chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1,
                 langchain_action1, langchain_mode1, kwargs={}, api=False, verbose=False):
    if fun1 is None:
        yield from _get_response(fun1, history, chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1,
                                 langchain_action1, kwargs=kwargs, api=api, verbose=verbose)
        return

    image_files = fun1.args[len(input_args_list) + eval_func_param_names.index('image_file')]
    if image_files is None:
        image_files = []
    else:
        image_files = image_files.copy()

    # read EXIF metadata so batched responses can reference it
    import pyexiv2
    meta_data_images = []
    for image_files1 in image_files:
        try:
            with pyexiv2.Image(image_files1) as img:
                metadata = img.read_exif()
        except RuntimeError as e:
            if 'unknown image type' in str(e):
                metadata = {}
            else:
                raise
        if metadata is None:
            metadata = {}
        meta_data_images.append(metadata)

    fun1_args_list = list(fun1.args)
    chosen_model_state = fun1.args[input_args_list.index('model_state')]
    base_model = chosen_model_state.get('base_model')
    display_name = chosen_model_state.get('display_name')

    # resolve which vision model (if any) should see the images
    visible_vision_models = ''
    if kwargs['visible_vision_models']:
        visible_vision_models = kwargs['visible_vision_models']
    if chosen_model_state['is_actually_vision_model']:
        visible_vision_models = chosen_model_state['display_name']

    # API may override the vision model choice
    visible_vision_models1 = fun1_args_list[
        len(input_args_list) + eval_func_param_names.index('visible_vision_models')]
    if visible_vision_models1:
        if isinstance(visible_vision_models1, list):
            visible_vision_models1 = visible_vision_models1[0]
        if visible_vision_models1 != 'auto' and \
                visible_vision_models1 in kwargs['all_possible_vision_display_names']:
            visible_vision_models = visible_vision_models1

    if not visible_vision_models:
        visible_vision_models = ''
    if isinstance(visible_vision_models, list):
        visible_vision_models = visible_vision_models[0]

    force_batching = False
    images_num_max, force_batching = get_images_num_max(chosen_model_state, fun1.args, visible_vision_models,
                                                        force_batching, kwargs['images_num_max'])

    # batch if too many images for the model, or if a separate vision model must handle them
    do_batching = force_batching or len(image_files) > images_num_max or \
                  visible_vision_models != display_name and \
                  display_name not in kwargs['all_possible_vision_display_names']
    do_batching &= visible_vision_models != ''
    do_batching &= len(image_files) > 0

    if do_batching and visible_vision_models:
        model_states1 = kwargs['model_states']
        model_batch_choice1 = visible_models_to_model_choice(visible_vision_models, model_states1, api=api)
        model_batch_choice = model_states1[model_batch_choice1 % len(model_states1)]
        images_num_max_batch, do_batching = get_images_num_max(model_batch_choice, fun1.args,
                                                               visible_vision_models,
                                                               do_batching, kwargs['images_num_max'])
    else:
        model_batch_choice = None
        images_num_max_batch = images_num_max
    batch_display_name = model_batch_choice.get('display_name') if model_batch_choice is not None else display_name

    do_batching &= images_num_max_batch not in [0, None]

    if not do_batching:
        yield from _get_response(fun1, history, chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1,
                                 langchain_action1, kwargs=kwargs, api=api, verbose=verbose)
        return
    else:
        instruction = fun1_args_list[len(input_args_list) + eval_func_param_names.index('instruction')]
        instruction_nochat = fun1_args_list[len(input_args_list) + eval_func_param_names.index('instruction_nochat')]
        instruction = instruction or instruction_nochat or ""
        prompt_summary = fun1_args_list[len(input_args_list) + eval_func_param_names.index('prompt_summary')]
        if prompt_summary is None:
            prompt_summary = kwargs['prompt_summary'] or ''
        image_batch_image_prompt = fun1_args_list[len(input_args_list) + eval_func_param_names.index(
            'image_batch_image_prompt')] or kwargs['image_batch_image_prompt'] or image_batch_image_prompt0
        image_batch_final_prompt = fun1_args_list[len(input_args_list) + eval_func_param_names.index(
            'image_batch_final_prompt')] or kwargs['image_batch_final_prompt'] or image_batch_final_prompt0

        system_prompt = fun1_args_list[len(input_args_list) + eval_func_param_names.index('system_prompt')]
        if system_prompt not in [None, 'None', 'auto']:
            system_prompt_xml = f"""\n<system_prompt>\n{system_prompt}\n</system_prompt>\n""" if system_prompt else ''
        else:
            system_prompt_xml = ''
        # choose where the batch/final prompts are injected, per action type
        if langchain_action1 == LangChainAction.QUERY.value:
            instruction_batch = image_batch_image_prompt + system_prompt_xml + instruction
            instruction_final = image_batch_final_prompt + system_prompt_xml + instruction
            prompt_summary_batch = prompt_summary
            prompt_summary_final = prompt_summary
        elif langchain_action1 == LangChainAction.SUMMARIZE_MAP.value:
            instruction_batch = instruction
            instruction_final = instruction
            prompt_summary_batch = image_batch_image_prompt + system_prompt_xml + prompt_summary
            prompt_summary_final = image_batch_final_prompt + system_prompt_xml + prompt_summary
        else:
            instruction_batch = instruction
            instruction_final = instruction
            prompt_summary_batch = prompt_summary
            prompt_summary_final = prompt_summary

        batch_output_tokens = 0
        batch_time = 0
        batch_input_tokens = 0
        batch_tokenspersec = 0
        batch_results = []

        text_context_list = fun1_args_list[len(input_args_list) + eval_func_param_names.index('text_context_list')]
        text_context_list = str_to_list(text_context_list)
        text_context_list_copy = copy.deepcopy(text_context_list)

        fun1_args_list_copy = fun1_args_list.copy()
        # let the batch model's own state override evaluate parameters
        for k, v in model_batch_choice.items():
            if k in eval_func_param_names and k in in_model_state_and_evaluate and v is not None:
                fun1_args_list_copy[len(input_args_list) + eval_func_param_names.index(k)] = v
        for batch in range(0, len(image_files), images_num_max_batch):
            fun1_args_list2 = fun1_args_list_copy.copy()

            images_batch = image_files[batch:batch + images_num_max_batch]
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('image_file')] = images_batch

            # images_num_max must match the batch, else we would recurse into batching
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('images_num_max')] = len(images_batch)
            batch_size = len(fun1_args_list2[len(input_args_list) + eval_func_param_names.index('image_file')])
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('instruction')] = instruction_batch
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('prompt_summary')] = prompt_summary_batch

            if batch_display_name in images_limit_max_new_tokens_list:
                max_new_tokens = fun1_args_list2[len(input_args_list) + eval_func_param_names.index('max_new_tokens')]
                fun1_args_list2[len(input_args_list) + eval_func_param_names.index('max_new_tokens')] = min(
                    images_limit_max_new_tokens, max_new_tokens)

            # batches see only images: no extra documents, no guided decoding
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('text_context_list')] = []
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('response_format')] = 'text'
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('guided_json')] = None
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('guided_regex')] = None
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('guided_grammar')] = None
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('guided_choice')] = None
            fun1_args_list2[
                len(input_args_list) + eval_func_param_names.index('document_subset')] = []

            fun1_args_list2[
                len(input_args_list) + eval_func_param_names.index('visible_vision_models')] = visible_vision_models
            if model_batch_choice:
                # swap in the vision model state for this batch
                fun1_args_list2[0] = model_batch_choice
                fun1_args_list2[
                    len(input_args_list) + eval_func_param_names.index('visible_models')] = visible_vision_models
            history1 = deepcopy_by_pickle_object(history)
            if not history1:
                history1 = [['', '']]
            history1[-1][0] = instruction_batch
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index('chat_conversation')] = history1

            history1 = deepcopy_by_pickle_object(history)
            history2 = deepcopy_by_pickle_object(history)
            fun2 = functools.partial(fun1.func, *tuple(fun1_args_list2), **fun1.keywords)

            text = ''
            prompt_raw_saved = ''
            save_dict1_saved = {}
            error_saved = ''
            history_saved = []
            sources_saved = []
            sources_str_saved = ''
            llm_answers_saved = {}
            image_batch_stream = fun1_args_list2[
                len(input_args_list) + eval_func_param_names.index('image_batch_stream')]
            if image_batch_stream is None:
                image_batch_stream = kwargs['image_batch_stream']
            if not image_batch_stream and not api:
                # show progress in the UI while batches run silently
                if not history2:
                    history2 = [['', '']]
                if len(image_files) > images_num_max_batch:
                    history2[-1][1] = '<b>%s querying image %s/%s</b>' % (
                        visible_vision_models, 1 + batch, len(image_files))
                else:
                    history2[-1][1] = '<b>%s querying image(s)</b>' % visible_vision_models
                audio3 = b''
                yield history2, '', [], '', '', [], {}, audio3
            t0_batch = time.time()
            for response in _get_response(fun2, history1, chatbot_role1, speaker1, tts_language1, roles_state1,
                                          tts_speed1,
                                          langchain_action1,
                                          kwargs=kwargs, api=api, verbose=verbose):
                if image_batch_stream:
                    yield response
                history1, error1, sources1, sources_str1, prompt_raw1, llm_answers1, save_dict1, audio2 = response
                prompt_raw_saved = prompt_raw1
                save_dict1_saved = save_dict1
                error_saved = error1
                history_saved = history1
                sources_saved = sources1
                sources_str_saved = sources_str1
                llm_answers_saved = llm_answers1
                text = history1[-1][1] or '' if history1 else ''
            # per-batch token accounting
            batch_input_tokens += save_dict1_saved['extra_dict'].get('num_prompt_tokens', 0)
            save_dict1_saved['extra_dict'] = _save_generate_tokens(text, save_dict1_saved['extra_dict'])
            ntokens1 = save_dict1_saved['extra_dict'].get('ntokens', 0)
            batch_output_tokens += ntokens1
            batch_time += (time.time() - t0_batch)
            tokens_per_sec1 = save_dict1_saved['extra_dict'].get('tokens_persecond', 0)
            batch_tokenspersec += tokens_per_sec1

            # wrap each batch answer with its image metadata for the final pass
            meta_data = ''
            for meta_data_image in meta_data_images[batch:batch + images_num_max_batch]:
                if not meta_data_image:
                    continue
                meta_data += '\n'.join(
                    [f"""<{key}>{value}</{key}>\n""" for key, value in meta_data_image.items()]).strip() + '\n'
            response_final = f'<images>\n<batch_name>\nImage {batch}\n</batch_name>\n{meta_data}\n\n{text}\n\n</images>'

            batch_results.append(dict(image_ids=list(range(batch, batch + images_num_max_batch)),
                                      response=text,
                                      response_final=response_final,
                                      prompt_raw=prompt_raw_saved,
                                      save_dict=save_dict1_saved,
                                      error=error_saved,
                                      history=history_saved,
                                      sources=sources_saved,
                                      sources_str=sources_str_saved,
                                      llm_answers=llm_answers_saved,
                                      ))

        # final pass: feed all batch responses back to the originally chosen model
        responses = [x['response_final'] for x in batch_results]
        batch_tokens_persecond = batch_output_tokens / batch_time if batch_time > 0 else 0
        history1 = deepcopy_by_pickle_object(history)
        fun1_args_list2 = fun1_args_list.copy()
        # restore the original model's own evaluate overrides
        for k, v in chosen_model_state.items():
            if k in eval_func_param_names and k in in_model_state_and_evaluate and v is not None:
                fun1_args_list2[len(input_args_list) + eval_func_param_names.index(k)] = v
        fun1_args_list2[len(input_args_list) + eval_func_param_names.index('image_file')] = []
        if not history1:
            history1 = [['', '']]
        history1[-1][0] = fun1_args_list2[
            len(input_args_list) + eval_func_param_names.index('instruction')] = instruction_final
        fun1_args_list2[len(input_args_list) + eval_func_param_names.index('chat_conversation')] = history1

        history1 = deepcopy_by_pickle_object(history)
        fun1_args_list2[len(input_args_list) + eval_func_param_names.index('prompt_summary')] = prompt_summary_final
        if langchain_action1 == LangChainAction.QUERY.value:
            instruction = fun1_args_list2[len(input_args_list) + eval_func_param_names.index('instruction')]
            if langchain_mode1 == LangChainMode.LLM.value and instruction:
                # pure LLM mode: prepend batch responses to the instruction itself
                fun1_args_list2[
                    len(input_args_list) + eval_func_param_names.index('instruction')] = '\n\n'.join(
                    responses) + instruction
            else:
                # RAG modes: put batch responses up front in the document context
                fun1_args_list2[len(input_args_list) + eval_func_param_names.index(
                    'text_context_list')] = responses + text_context_list_copy
        else:
            # summarization etc.: keep original context first
            fun1_args_list2[len(input_args_list) + eval_func_param_names.index(
                'text_context_list')] = text_context_list_copy + responses
        fun2 = functools.partial(fun1.func, *tuple(fun1_args_list2), **fun1.keywords)
        for response in _get_response(fun2, history1, chatbot_role1, speaker1, tts_language1, roles_state1,
                                      tts_speed1, langchain_action1, kwargs=kwargs, api=api, verbose=verbose):
            response_list = list(response)
            save_dict1 = response_list[6]
            if 'extra_dict' in save_dict1:
                if 'num_prompt_tokens' in save_dict1['extra_dict']:
                    save_dict1['extra_dict']['batch_vision_visible_model'] = batch_display_name
                    # aggregate token accounting across all image batches
                    save_dict1['extra_dict']['batch_num_prompt_tokens'] = batch_input_tokens
                    save_dict1['extra_dict']['batch_ntokens'] = batch_output_tokens
                    save_dict1['extra_dict']['batch_tokens_persecond'] = batch_tokens_persecond
                    if batch_display_name == display_name:
                        save_dict1['extra_dict']['num_prompt_tokens'] += batch_input_tokens

                    history1new = response_list[0]
                    if history1new and len(history1new) > 0 and len(history1new[0]) == 2 and history1new[-1][1]:
                        save_dict1['extra_dict'] = _save_generate_tokens(history1new[-1][1],
                                                                         save_dict1['extra_dict'])
                        save_dict1['extra_dict']['ntokens'] += batch_output_tokens
                save_dict1['extra_dict']['batch_results'] = batch_results
            response_list[6] = save_dict1
            yield tuple(response_list)
        return


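# Flow summary (descriptive note, no behavior): when batching triggers,
# get_response runs two phases: (1) the chosen vision model answers per batch
# of images, each answer wrapped in <images>...</images> tags with its EXIF
# metadata; (2) the originally chosen model receives those wrapped answers,
# either prepended to the instruction (pure LLM mode) or injected into
# text_context_list (RAG modes), and streams the final response.
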
def _get_response(fun1, history, chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1,
                  langchain_action1, kwargs={}, api=False, verbose=False):
    """
    Bot that consumes history for user input.
    The instruction (from input_list) itself is not consumed by the bot.
    :return:
    """
    error = ''
    sources = []
    save_dict = dict()
    output_no_refs = ''
    sources_str = ''
    prompt_raw = ''
    llm_answers = {}

    audio0, audio1, no_audio, generate_speech_func_func = \
        prepare_audio(chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1, langchain_action1,
                      kwargs=kwargs, verbose=verbose)

    if not fun1:
        yield history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, audio1
        return
    try:
        for output_fun in fun1():
            output = output_fun['response']
            output_no_refs = output_fun['response_no_refs']
            sources = output_fun['sources']
            sources_iter = []
            sources_str = output_fun['sources_str']
            sources_str_iter = ''
            prompt_raw = output_fun['prompt_raw']
            prompt_raw_iter = ''
            llm_answers = output_fun['llm_answers']
            save_dict = output_fun.get('save_dict', {})
            save_dict_iter = {}

            bot_message = fix_text_for_gradio(output, fix_latex_dollars=not api, fix_angle_brackets=not api)
            history[-1][1] = bot_message

            if generate_speech_func_func is not None:
                # stream TTS sentence by sentence alongside the text
                while True:
                    audio1, sentence, sentence_state = generate_speech_func_func(output_no_refs, is_final=False)
                    if audio0 is not None:
                        yield history, error, sources_iter, sources_str_iter, prompt_raw_iter, llm_answers, save_dict_iter, audio0
                        audio0 = None
                    yield history, error, sources_iter, sources_str_iter, prompt_raw_iter, llm_answers, save_dict_iter, audio1
                    if not sentence:
                        # no more sentences ready for this chunk
                        break
            else:
                yield history, error, sources_iter, sources_str_iter, prompt_raw_iter, llm_answers, save_dict_iter, audio0
        if generate_speech_func_func:
            # finalize TTS for any trailing partial sentence
            audio1, sentence, sentence_state = generate_speech_func_func(output_no_refs, is_final=True)
            if audio0 is not None:
                yield history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, audio0
        else:
            audio1 = None

        yield history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, audio1
    except StopIteration:
        # ensure the final yield still carries full sources etc.
        yield history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, no_audio
        raise
    except RuntimeError as e:
        if "generator raised StopIteration" in str(e):
            # assume the last entry was broken, undo it
            history.pop()
            yield history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, no_audio
        else:
            if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
                history[-1][1] = ''
            yield history, str(e), sources, sources_str, prompt_raw, llm_answers, save_dict, no_audio
        raise
    except Exception as e:
        # surface the exception as the error field for the client
        ex = "Exception: %s" % str(e)
        if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
            history[-1][1] = ''
        yield history, ex, sources, sources_str, prompt_raw, llm_answers, save_dict, no_audio
        raise
    finally:
        pass
    return


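# Consumers of _get_response/get_response receive 8-tuples in this order:
#   (history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, audio)
# A minimal consumption sketch (assuming fun1 was built by prep_bot):
#
#   for history, error, *_rest, audio in _get_response(fun1, history, ...):
#       print(history[-1][1])  # partial bot message so far
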
def prepare_audio(chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1, langchain_action1, kwargs={},
                  verbose=False):
    assert kwargs
    from tts_sentence_parsing import init_sentence_state
    sentence_state = init_sentence_state()
    if langchain_action1 in [LangChainAction.EXTRACT.value]:
        # no TTS for extraction
        generate_speech_func_func = None
        audio0 = None
        audio1 = None
        no_audio = None
    elif kwargs['tts_model'].startswith('microsoft') and speaker1 not in [None, "None"]:
        audio1 = None
        from tts import get_speaker_embedding
        speaker_embedding = get_speaker_embedding(speaker1, kwargs['model_tts'].device)

        from tts_utils import prepare_speech, get_no_audio
        sr = 16000
        audio0 = prepare_speech(sr=sr)
        no_audio = get_no_audio(sr=sr)
        generate_speech_func_func = functools.partial(kwargs['generate_speech_func'],
                                                      speaker=speaker1,
                                                      speaker_embedding=speaker_embedding,
                                                      sentence_state=sentence_state,
                                                      return_as_byte=kwargs['return_as_byte'],
                                                      sr=sr,
                                                      tts_speed=tts_speed1,
                                                      verbose=verbose)
    elif kwargs['tts_model'].startswith('tts_models/') and chatbot_role1 not in [None, "None"]:
        audio1 = None
        from tts_utils import prepare_speech, get_no_audio
        from tts_coqui import get_latent
        sr = 24000
        audio0 = prepare_speech(sr=sr)
        no_audio = get_no_audio(sr=sr)
        latent = get_latent(roles_state1[chatbot_role1], model=kwargs['model_xtt'])
        generate_speech_func_func = functools.partial(kwargs['generate_speech_func'],
                                                      latent=latent,
                                                      language=tts_language1,
                                                      sentence_state=sentence_state,
                                                      return_as_byte=kwargs['return_as_byte'],
                                                      sr=sr,
                                                      tts_speed=tts_speed1,
                                                      verbose=verbose)
    else:
        generate_speech_func_func = None
        audio0 = None
        audio1 = None
        no_audio = None
    return audio0, audio1, no_audio, generate_speech_func_func


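# Note on sample rates above: the Microsoft (speecht5-style) path uses 16 kHz
# and the Coqui 'tts_models/' path uses 24 kHz, which is why callers pass
# sr=24000 if chatbot_role1 else 16000 when combining audio chunks.
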
def prep_bot(*args, retry=False, which_model=0, kwargs_eval={}, plain_api=False, kwargs={}, verbose=False):
    """
    :param args:
    :param retry:
    :param which_model: identifies which model if doing model_lock
        API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
    :return: last element is True if should run bot, False if should just yield history
    """
    assert kwargs
    isize = len(input_args_list) + 1  # states + chat history

    # don't deepcopy, can contain the model itself
    args_list = list(args).copy()
    model_state1 = args_list[-isize]
    my_db_state1 = args_list[-isize + 1]
    selection_docs_state1 = args_list[-isize + 2]
    requests_state1 = args_list[-isize + 3]
    roles_state1 = args_list[-isize + 4]
    history = args_list[-1]
    if not history:
        history = []

    prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
    if prompt_type1 == no_model_str:
        # deal with gradio dropdown when 'no model' was selected
        prompt_type1 = args_list[eval_func_param_names.index('prompt_type')] = None
    prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
    max_time1 = args_list[eval_func_param_names.index('max_time')]
    stream_output1 = args_list[eval_func_param_names.index('stream_output')]
    langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
    langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
    document_subset1 = args_list[eval_func_param_names.index('document_subset')]
    h2ogpt_key1 = args_list[eval_func_param_names.index('h2ogpt_key')]
    chat_conversation1 = args_list[eval_func_param_names.index('chat_conversation')]
    valid_key = is_valid_key(kwargs['enforce_h2ogpt_api_key'],
                             kwargs['enforce_h2ogpt_ui_key'],
                             kwargs['h2ogpt_api_keys'], h2ogpt_key1,
                             requests_state1=requests_state1)
    chatbot_role1 = args_list[eval_func_param_names.index('chatbot_role')]
    speaker1 = args_list[eval_func_param_names.index('speaker')]
    tts_language1 = args_list[eval_func_param_names.index('tts_language')]
    tts_speed1 = args_list[eval_func_param_names.index('tts_speed')]

    dummy_return = history, None, langchain_mode1, my_db_state1, requests_state1, \
        valid_key, h2ogpt_key1, \
        max_time1, stream_output1, chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1, \
        langchain_action1, []

    if not plain_api and (model_state1['model'] is None or model_state1['model'] == no_model_str):
        # no model loaded, do nothing
        return dummy_return

    args_list = args_list[:-isize]  # only keep the values needed for evaluate()
    if not history:
        if verbose:
            print("No history", flush=True)
        return dummy_return
    instruction1 = history[-1][0]
    if retry and history:
        # if retrying, pop the last bot answer and re-run the last user turn
        history = get_llm_history(history)
        instruction1 = history[-1][0] if history and history[-1] and len(history[-1]) == 2 else None
        if history and history[-1]:
            history[-1][1] = None
        if not instruction1:
            return dummy_return
    elif not instruction1:
        if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
            # if not retrying, reject an empty query
            return dummy_return
    elif len(history) > 0 and history[-1][1] not in [None, '']:
        # reject submit if the last turn is already answered and we're not retrying
        return dummy_return

    from gen import evaluate, evaluate_fake
    evaluate_local = evaluate if valid_key else functools.partial(evaluate_fake, langchain_action=langchain_action1)

    # prefer global CLI prompt_type if model was launched with one
    prompt_type1, prompt_dict1 = update_prompt(prompt_type1, prompt_dict1, model_state1,
                                               which_model=which_model, **kwargs)
    # apply back to args_list for evaluate()
    args_list[eval_func_param_names.index('prompt_type')] = prompt_type1
    args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1
    context1 = args_list[eval_func_param_names.index('context')]

    chat_conversation1 = merge_chat_conversation_history(chat_conversation1, history)
    args_list[eval_func_param_names.index('chat_conversation')] = chat_conversation1

    if 'visible_models' in model_state1 and model_state1['visible_models'] is not None:
        assert isinstance(model_state1['visible_models'], (int, str))
        args_list[eval_func_param_names.index('visible_models')] = model_state1['visible_models']
    if 'visible_vision_models' in model_state1 and model_state1['visible_vision_models'] is not None:
        assert isinstance(model_state1['visible_vision_models'], (int, str))
        args_list[eval_func_param_names.index('visible_vision_models')] = model_state1['visible_vision_models']
    if 'h2ogpt_key' in model_state1 and model_state1['h2ogpt_key'] is not None:
        # may be '' and used to override the overall local key
        assert isinstance(model_state1['h2ogpt_key'], str)
        args_list[eval_func_param_names.index('h2ogpt_key')] = model_state1['h2ogpt_key']
    elif not args_list[eval_func_param_names.index('h2ogpt_key')]:
        # key validity already checked, so can inject default key now
        args_list[eval_func_param_names.index('h2ogpt_key')] = kwargs['h2ogpt_key']

    # normalize image/video inputs into a single list of local files
    image_files = args_list[eval_func_param_names.index('image_file')]
    if isinstance(image_files, str):
        image_files = [image_files]
    if image_files is None:
        image_files = []
    video_files = args_list[eval_func_param_names.index('video_file')]
    if isinstance(video_files, str):
        video_files = [video_files]
    if video_files is None:
        video_files = []

    image_files.extend(video_files)

    image_files_to_delete = []
    b2imgs = []
    for img_file_one in image_files:
        str_type = check_input_type(img_file_one)
        if str_type == 'unknown':
            continue

        img_file_path = os.path.join(tempfile.gettempdir(), 'image_file_%s' % str(uuid.uuid4()))
        if str_type == 'url':
            img_file_one = download_image(img_file_one, img_file_path)
            # only delete files we created ourselves
            image_files_to_delete.append(img_file_one)
        elif str_type == 'base64':
            from vision.utils_vision import base64_to_img
            img_file_one = base64_to_img(img_file_one, img_file_path)
            # only delete files we created ourselves
            image_files_to_delete.append(img_file_one)
        else:
            # already a local file; nothing to convert
            pass
        if img_file_one is not None:
            b2imgs.append(img_file_one)

    args_list[eval_func_param_names.index('image_file')] = b2imgs

    # process files (resize, frame extraction, format conversion)
    images_file_path = os.path.join(tempfile.gettempdir(), 'image_path_%s' % str(uuid.uuid4()))

    image_files = args_list[eval_func_param_names.index('image_file')]
    image_resolution = args_list[eval_func_param_names.index('image_resolution')]
    image_format = args_list[eval_func_param_names.index('image_format')]
    video_frame_period = args_list[eval_func_param_names.index('video_frame_period')]
    if video_frame_period is not None:
        video_frame_period = int(video_frame_period)
    extract_frames = args_list[eval_func_param_names.index('extract_frames')] or kwargs.get('extract_frames', 20)
    rotate_align_resize_image = args_list[eval_func_param_names.index('rotate_align_resize_image')] or kwargs.get(
        'rotate_align_resize_image', True)
    process_args = (image_files, images_file_path)
    process_kwargs = dict(resolution=image_resolution,
                          image_format=image_format,
                          rotate_align_resize_image=rotate_align_resize_image,
                          video_frame_period=video_frame_period,
                          extract_frames=extract_frames,
                          verbose=verbose)
    if image_files and kwargs['function_server']:
        from function_client import call_function_server
        image_files = call_function_server('0.0.0.0', kwargs['function_server_port'], 'process_file_list',
                                           process_args, process_kwargs,
                                           use_disk=True, use_pickle=True,
                                           function_api_key=kwargs['function_api_key'],
                                           verbose=verbose)
    else:
        image_files = process_file_list(*process_args, **process_kwargs)
    args_list[eval_func_param_names.index('image_file')] = image_files

    args_list[0] = instruction1  # override original instruction with last user turn from history
    args_list[2] = context1

    # override args_list values from model_state1's own evaluate parameters
    for k in eval_func_param_names:
        if k in in_model_state_and_evaluate:
            # already handled above
            continue
        if k in model_state1 and model_state1[k] is not None:
            args_list[eval_func_param_names.index(k)] = model_state1[k]

    eval_args = (model_state1, my_db_state1, selection_docs_state1, requests_state1, roles_state1)
    assert len(eval_args) == len(input_args_list)
    fun1 = functools.partial(evaluate_local, *eval_args, *tuple(args_list), **kwargs_eval)

    return history, fun1, langchain_mode1, my_db_state1, requests_state1, \
        valid_key, h2ogpt_key1, \
        max_time1, stream_output1, \
        chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1, \
        langchain_action1, image_files_to_delete


def choose_exc(x, is_public=True):
    # exception text to show; currently passed through unchanged in both cases
    if is_public:
        return x
    else:
        return x


def bot(*args, retry=False, kwargs_evaluate={}, kwargs={}, db_type=None, dbs=None, verbose=False):
    # NOTE: unpack order must match prep_bot()'s return: langchain_action1 comes
    # before image_files_to_delete
    history, fun1, langchain_mode1, db1, requests_state1, \
        valid_key, h2ogpt_key1, \
        max_time1, stream_output1, \
        chatbot_role1, speaker1, tts_language1, roles_state1, tts_speed1, \
        langchain_action1, \
        image_files_to_delete = prep_bot(*args, retry=retry, kwargs_eval=kwargs_evaluate, kwargs=kwargs,
                                         verbose=verbose)
    save_dict = dict()
    error = ''
    error_with_str = ''
    sources = []
    history_str_old = ''
    error_old = ''
    sources_str = None
    from tts_utils import get_no_audio
    no_audio = get_no_audio()
    audios = []  # audio is always streamed, so accumulate until next yield
    last_yield = None
    try:
        tgen0 = time.time()
        for res in get_response(fun1, history, chatbot_role1, speaker1, tts_language1, roles_state1,
                                tts_speed1,
                                langchain_action1,
                                langchain_mode1,
                                kwargs=kwargs,
                                api=False,
                                verbose=verbose,
                                ):
            do_yield = False
            history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, audio1 = res
            error_with_str = get_accordion_named(choose_exc(error), "Generate Error",
                                                 font_size=2) if error not in ['', None, 'None'] else ''

            # pass back to gradio only (history, error, audio); the rest is consumed here
            history_str = str(history)
            could_yield = (
                    history_str != history_str_old or
                    error != error_old and
                    (error not in noneset or
                     error_old not in noneset))
            if kwargs['gradio_ui_stream_chunk_size'] <= 0:
                do_yield |= could_yield
            else:
                delta_history = abs(len(history_str) - len(history_str_old))
                # yield only if enough changed and the minimum interval passed
                enough_data = delta_history > kwargs['gradio_ui_stream_chunk_size'] or (error != error_old)
                beyond_min_time = last_yield is None or \
                                  (time.time() - last_yield) > kwargs['gradio_ui_stream_chunk_min_seconds']
                do_yield |= enough_data and beyond_min_time
                # yield anyway if enough time has passed
                enough_time = last_yield is None or \
                              (time.time() - last_yield) > kwargs['gradio_ui_stream_chunk_seconds']
                do_yield |= enough_time and could_yield

            if stream_output1 and do_yield:
                audio1 = combine_audios(audios, audio=audio1, sr=24000 if chatbot_role1 else 16000,
                                        expect_bytes=kwargs['return_as_byte'], verbose=verbose)
                audios = []  # reset accumulation

                yield history, error, audio1
                history_str_old = history_str
                error_old = error
                last_yield = time.time()
            else:
                audios.append(audio1)

            if time.time() - tgen0 > max_time1 + 10:
                if verbose:
                    print("Took too long bot: %s" % (time.time() - tgen0), flush=True)
                break

        # final yield with anything left over
        final_audio = combine_audios(audios, audio=no_audio,
                                     expect_bytes=kwargs['return_as_byte'], verbose=verbose)
        if error_with_str:
            if history and history[-1] and len(history[-1]) == 2 and error_with_str:
                if not history[-1][1]:
                    history[-1][1] = error_with_str
                else:
                    # don't clobber a partial answer; append the error as a new message
                    history.append((None, error_with_str))
        if kwargs['append_sources_to_chat'] and sources_str:
            history.append((None, sources_str))

        yield history, error, final_audio
    except BaseException as e:
        print("bot exception: %s: %s" % (str(e), str(args)), flush=True)
        raise
    finally:
        clear_torch_cache(allow_skip=True)
        clear_embeddings(langchain_mode1, db_type, db1, dbs)
        for image_file1 in image_files_to_delete:
            if os.path.isfile(image_file1):
                remove(image_file1)

    # save generation record
    if 'extra_dict' not in save_dict:
        save_dict['extra_dict'] = {}
    save_dict['valid_key'] = valid_key
    save_dict['h2ogpt_key'] = h2ogpt_key1
    if requests_state1:
        save_dict['extra_dict'].update(requests_state1)
    else:
        save_dict['extra_dict'].update(dict(username='NO_REQUEST'))
    save_dict['error'] = error
    save_dict['sources'] = sources
    save_dict['which_api'] = 'bot'
    save_dict['save_dir'] = kwargs['save_dir']
    save_generate_output(**save_dict)


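# bot() yields (history, error, audio) triples sized for the gradio UI, unlike
# evaluate_nochat which yields dicts/strings for API clients; both share the
# same chunk-size/min-seconds throttling logic above.
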
def is_from_ui(requests_state1):
    return isinstance(requests_state1, dict) and 'username' in requests_state1 and requests_state1['username']


def is_valid_key(enforce_h2ogpt_api_key, enforce_h2ogpt_ui_key, h2ogpt_api_keys, h2ogpt_key1, requests_state1=None):
    from_ui = is_from_ui(requests_state1)

    if from_ui and not enforce_h2ogpt_ui_key:
        # no token barrier for the UI
        return 'not enforced'
    elif not from_ui and not enforce_h2ogpt_api_key:
        # no token barrier for the API
        return 'not enforced'
    else:
        valid_key = False
        if isinstance(h2ogpt_api_keys, list) and h2ogpt_key1 in h2ogpt_api_keys:
            # passed the token barrier
            valid_key = True
        elif isinstance(h2ogpt_api_keys, str) and os.path.isfile(h2ogpt_api_keys):
            with filelock.FileLock(h2ogpt_api_keys + '.lock'):
                with open(h2ogpt_api_keys, 'rt') as f:
                    h2ogpt_api_keys = json.load(f)
                if h2ogpt_key1 in h2ogpt_api_keys:
                    valid_key = True
        return valid_key


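# File-based keys (a sketch, file path assumed): when h2ogpt_api_keys is a
# filename, it must contain a JSON list of accepted keys, e.g.:
#
#   ["sk-first-key", "sk-second-key"]
#
# The '.lock' sibling file is managed by filelock to guard concurrent reads.
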
def get_one_key(h2ogpt_api_keys, enforce_h2ogpt_api_key):
    if not enforce_h2ogpt_api_key:
        # no keys enforced
        return None
    if isinstance(h2ogpt_api_keys, list) and h2ogpt_api_keys:
        return h2ogpt_api_keys[0]
    elif isinstance(h2ogpt_api_keys, str) and os.path.isfile(h2ogpt_api_keys):
        with filelock.FileLock(h2ogpt_api_keys + '.lock'):
            with open(h2ogpt_api_keys, 'rt') as f:
                h2ogpt_api_keys = json.load(f)
            if h2ogpt_api_keys:
                return h2ogpt_api_keys[0]


def get_model_max_length(model_state1, model_state0):
    if model_state1 and not isinstance(model_state1["tokenizer"], str):
        tokenizer = model_state1["tokenizer"]
    elif model_state0 and not isinstance(model_state0["tokenizer"], str):
        tokenizer = model_state0["tokenizer"]
    else:
        tokenizer = None
    if tokenizer is not None:
        return int(tokenizer.model_max_length)
    else:
        # conservative default when no tokenizer is available
        return 2000


def get_llm_history(history):
    # trim history to end at the most recent entry that has a user turn
    if history is None:
        history = []
    for ii in range(len(history) - 1, -1, -1):
        if history[ii] and history[ii][0] is not None:
            last_user_ii = ii
            history = history[:last_user_ii + 1]
            break
    return history


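# Example of the trimming above (illustrative values): for
#   history = [['hi', 'hello'], ['how are you?', 'fine'], [None, 'stray bot msg']]
# the last entry with a user turn is index 1, so get_llm_history returns
#   [['hi', 'hello'], ['how are you?', 'fine']]
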
def gen1_fake(fun1, history):
    error = ''
    sources = []
    sources_str = ''
    prompt_raw = ''
    llm_answers = {}
    save_dict = dict()
    audio1 = None
    yield history, error, sources, sources_str, prompt_raw, llm_answers, save_dict, audio1
    return


def merge_chat_conversation_history(chat_conversation1, history):
    # both lists are ordered oldest-first; passed-in conversation becomes the older part
    if chat_conversation1:
        chat_conversation1 = str_to_list(chat_conversation1)
        for conv1 in chat_conversation1:
            assert isinstance(conv1, (list, tuple))
            assert len(conv1) == 2

    if isinstance(history, list):
        # copy so the change stays local
        if chat_conversation1:
            history = chat_conversation1 + history.copy()
    elif chat_conversation1:
        history = chat_conversation1
    else:
        history = []
    return history


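# Merge order example (illustrative): with chat_conversation=[['a', 'b']] and
# history=[['c', None]], the result is [['a', 'b'], ['c', None]]; passed-in
# conversation always counts as the older part of the history.
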
def update_langchain_mode_paths(selection_docs_state1):
    # drop paths for modes that no longer exist
    dup = selection_docs_state1['langchain_mode_paths'].copy()
    for k, v in dup.items():
        if k not in selection_docs_state1['langchain_modes']:
            selection_docs_state1['langchain_mode_paths'].pop(k)
    for k in selection_docs_state1['langchain_modes']:
        if k not in selection_docs_state1['langchain_mode_types']:
            # default to personal type when unspecified
            selection_docs_state1['langchain_mode_types'][k] = LangChainTypes.PERSONAL.value
    return selection_docs_state1


def my_db_state_done(state):
    if isinstance(state, dict):
        for langchain_mode_db, db_state in state.items():
            scratch_data = state[langchain_mode_db]
            if langchain_mode_db in langchain_modes_intrinsic:
                if len(scratch_data) == length_db1() and hasattr(scratch_data[0], 'delete_collection') and \
                        scratch_data[1] == scratch_data[2]:
                    # scratch collection (not logged in), safe to delete
                    scratch_data[0].delete_collection()
                # free the reference either way
                scratch_data[0] = None
                del scratch_data[0]


def process_audio(file1, t1=0, t2=30):
    # extract the [t1, t2] window (in seconds) from a wav file
    from pydub import AudioSegment

    # pydub slices in milliseconds
    t1 = t1 * 1000
    t2 = t2 * 1000
    newAudio = AudioSegment.from_wav(file1)[t1:t2]
    new_file = file1 + '.new.wav'
    newAudio.export(new_file, format="wav")
    return new_file


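# Usage sketch (assumes an existing wav file on disk):
#
#   clip = process_audio('/tmp/sample.wav', t1=5, t2=15)  # 10s window -> '/tmp/sample.wav.new.wav'
#
# Slicing past the end of the file is safe: pydub's slice semantics simply
# clamp to the available audio.
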
def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
    allow = False
    allow |= langchain_action1 not in [LangChainAction.QUERY.value,
                                       LangChainAction.IMAGE_QUERY.value,
                                       LangChainAction.IMAGE_CHANGE.value,
                                       LangChainAction.IMAGE_GENERATE.value,
                                       LangChainAction.IMAGE_STYLE.value,
                                       ]
    allow |= document_subset1 in [DocumentSubset.TopKSources.name]
    if langchain_mode1 in [LangChainMode.LLM.value]:
        allow = False
    return allow


def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0, global_scope=False, **kwargs):
    assert kwargs
    if not prompt_type1 or which_model != 0:
        # fall back to global CLI prompt settings
        prompt_type1 = kwargs.get('prompt_type', prompt_type1)
        prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)

    if not global_scope:
        if not prompt_type1 or which_model != 0:
            # prefer model-specific prompt settings over globals
            prompt_type1 = model_state1.get('prompt_type', prompt_type1)
            prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)

    if not prompt_dict1 or which_model != 0:
        prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
        if not global_scope:
            if not prompt_dict1 or which_model != 0:
                prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
    if not global_scope and not prompt_type1:
        # still undefined: mark as unknown rather than leaving falsy
        prompt_type1 = unknown_prompt_type
    return prompt_type1, prompt_dict1


def get_fun_with_dict_str_plain(default_kwargs, kwargs, **kwargs_evaluate_nochat):
    fun_with_dict_str_plain = functools.partial(evaluate_nochat,
                                                default_kwargs1=default_kwargs,
                                                str_api=True,
                                                plain_api=True,
                                                kwargs=kwargs,
                                                **kwargs_evaluate_nochat,
                                                )
    return fun_with_dict_str_plain