|
import functools |
|
import os |
|
import math |
|
import csv |
|
import datetime |
|
|
|
import filelock |
|
import gradio as gr |
|
|
|
from utils import is_gradio_version4 |
|
|
|
|
|
def get_chatbot_name(base_model, display_name, model_path_llama, inference_server='', prompt_type='',
                     model_label_prefix='', debug=False):
    """Build the label text for a model's chatbot component.

    :param base_model: model identifier; the special value 'llama' means the
        label is derived from ``model_path_llama`` instead.
    :param display_name: if truthy, shown in place of ``base_model``
        (ignored when ``base_model == 'llama'``).
    :param model_path_llama: path or URL of the llama model file; only its
        basename is shown, with a trailing '?download=true' removed.
    :param inference_server: appended to the label as ' : <server>' only when
        ``debug`` is true.  ``None`` is treated like an empty string.
    :param prompt_type: accepted for call compatibility; not used here.
    :param model_label_prefix: text placed before ' [Model: ...]'.
    :param debug: whether to include the inference server in the label.
    :return: the label string, e.g. ' [Model: my-model]'.
    """
    if debug:
        # Tolerate None so that debug labels never fail on string concatenation.
        server_suffix = ' : ' + ('' if inference_server is None else inference_server)
    else:
        server_suffix = ''

    if base_model == 'llama':
        shown_name = os.path.basename(model_path_llama)
        if shown_name.endswith('?download=true'):
            # Strip the download query so the label shows just the file name.
            shown_name = shown_name.replace('?download=true', '')
    else:
        shown_name = base_model
        if shown_name == 'mixtral-8x7b-32768':
            shown_name = 'groq:mixtral-8x7b-32768'
        if display_name:
            # An explicit display name wins over any derived model name.
            shown_name = display_name

    return f'{model_label_prefix} [Model: {shown_name}{server_suffix}]'
|
|
|
|
|
def get_avatars(base_model, model_path_llama, inference_server=''):
    """Pick avatar image paths for the human and the bot.

    The bot avatar is chosen by case-insensitive substring matching on the
    model name (``model_path_llama`` when ``base_model == 'llama'``); the
    first matching rule wins.  Image files are looked up in the directory
    named by the ``H2OGPT_MODEL_BASE`` environment variable (default
    'models/').

    :param base_model: model identifier, or 'llama' to match on the llama path.
    :param model_path_llama: llama model path, used only when base_model is 'llama'.
    :param inference_server: inference server string (``None`` allowed); only
        consulted by the 'openai' rule.
    :return: ``(human_avatar, bot_avatar)``; each is a file path, or ``None``
        when that file does not exist.
    """
    # Ordered rules: (substrings to search for, avatar file, also search the inference server?)
    avatar_rules = (
        (('h2ogpt-gm',), "h2oai.png", False),
        (('llava-',), "llava.png", False),
        (('mistralai', 'mistral', 'mixtral'), "mistralai.png", False),
        (('01-ai/yi-',), "yi.svg", False),
        (('wizard',), "wizard.jpg", False),
        (('openchat',), "openchat.png", False),
        (('vicuna',), "vicuna.jpeg", False),
        (('longalpaca',), "longalpaca.png", False),
        (('llama2-70b-chat', 'llama2-13b-chat', 'llama2-7b-chat'), "meta.png", False),
        (('llama2', 'llama-2'), "lama2.jpeg", False),
        (('llama',), "lama.jpeg", False),
        (('openai',), "openai.png", True),
        (('hugging',), "hf-logo.png", False),
        (('claude',), "anthropic.jpeg", False),
        (('gemini',), "google.png", False),
    )

    model_name = model_path_llama if base_model == 'llama' else base_model
    server = '' if inference_server is None else inference_server
    avatar_dir = os.getenv('H2OGPT_MODEL_BASE', 'models/')
    name_lc = model_name.lower()

    bot_file = "h2oai.png"  # fallback when no rule matches
    for needles, avatar_file, check_server in avatar_rules:
        matched = any(needle in name_lc for needle in needles)
        if not matched and check_server:
            # The server string is only lower-cased once this rule is actually reached.
            server_lc = server.lower()
            matched = any(needle in server_lc for needle in needles)
        if matched:
            bot_file = avatar_file
            break

    def existing_path(file_name):
        # Only hand back paths that point at a real file.
        candidate = os.path.join(avatar_dir, file_name)
        return candidate if os.path.isfile(candidate) else None

    return existing_path("human.jpg"), existing_path(bot_file)
|
|
|
|
|
def ratingfn1():
    """Return the constant rating value 1 (click callback for the first star button)."""
    return 1
|
|
|
|
|
def ratingfn2():
    """Return the constant rating value 2 (click callback for the second star button)."""
    return 2
|
|
|
|
|
def ratingfn3():
    """Return the constant rating value 3 (click callback for the third star button)."""
    return 3
|
|
|
|
|
def ratingfn4():
    """Return the constant rating value 4 (click callback for the fourth star button)."""
    return 4
|
|
|
|
|
def ratingfn5():
    """Return the constant rating value 5 (click callback for the fifth star button)."""
    return 5
|
|
|
|
|
def submit_review(review_text, text_output, text_output2, *text_outputs1, reviews_file=None, num_model_lock=None,
                  do_info=True):
    """Append one review row to the reviews CSV file.

    The row is ``[review_text, <last item of each non-empty positional
    value>, <current local datetime>]``.  The file is appended to while
    holding a ``filelock.FileLock`` on ``reviews_file + '.lock'``.

    :param review_text: free-form review text, written as the first column.
    :param text_output: first value whose last item is recorded (skipped if falsy).
    :param text_output2: second such value.
    :param text_outputs1: any further such values.
    :param reviews_file: path of the CSV file; when ``None`` nothing is written.
    :param num_model_lock: accepted for call compatibility; not used here.
    :param do_info: whether to show a ``gr.Info`` notification.
    :return: always the empty string.
    """
    if reviews_file is None:
        if do_info:
            gr.Info('No review file')
        return ''

    histories = (text_output, text_output2, *text_outputs1)
    # Empty/None values contribute no column at all.
    last_entries = [history[-1] for history in histories if history]
    timestamp = datetime.datetime.now()

    with filelock.FileLock(reviews_file + '.lock'):
        # Explicit UTF-8: review and chat text routinely contain non-ASCII characters,
        # which a platform-default (non-UTF-8) encoding may fail to encode.
        with open(reviews_file, 'a', newline='', encoding='utf-8') as csvfile:
            csv.writer(csvfile).writerow([review_text, *last_entries, timestamp])
            if do_info:
                gr.Info('Review submitted!')
    return ''
|
|
|
|
|
def _rating_highlight_js(num_selected, num_buttons=5):
    """Return the JS click handler that colors the first ``num_selected`` star buttons."""
    selected_color = "#ffa41c"
    unselected_color = "rgba(173, 181, 189, 0.5)"
    lines = ["", "function highlightButtons() {"]
    for idx in range(1, num_buttons + 1):
        color = selected_color if idx <= num_selected else unselected_color
        if idx > 1:
            lines.append("")
        lines.extend([
            f'    var element = document.getElementById("rating{idx}");',
            f'    // element.style.backgroundColor = "{color}";',
            f'    element.style.color = "{color}";',
        ])
    lines.extend(["}", ""])
    return "\n".join(lines)


def make_chatbots(output_label0, output_label0_model2, **kwargs):
    """Create the chatbot components plus the review/rating controls.

    Must be called inside an active Gradio layout context, since components
    are instantiated directly.

    One ``gr.Chatbot`` is built per entry of ``kwargs['model_states']`` and
    laid out in rows of ``kwargs['model_lock_columns']`` columns (-1 means one
    column per visible model, ``None`` means 3).  Two further chatbots are
    always built for the non-model-lock case; the second one is hidden.

    Keys read from ``kwargs``: ``visible_models``, ``all_possible_display_names``,
    ``visible_ratings``, ``reviews_file`` (falsy -> 'reviews.csv'),
    ``gradio_size``, ``model_states``, ``model_label_prefix``, ``avatars``,
    ``height`` (falsy -> 400), ``show_copy_button``, ``model_lock``,
    ``model_lock_layout_based_upon_initial_visible``, ``model_lock_columns``,
    optionally ``render_markdown`` and ``visible_chatbot_label`` (both default
    True), and, only when ``avatars`` is truthy, ``base_model``,
    ``llamacpp_dict`` and ``inference_server``.

    :param output_label0: label of the first non-model-lock chatbot.
    :param output_label0_model2: label of the second (hidden) chatbot.
    :return: ``(text_output, text_output2, text_outputs)`` where
        ``text_outputs`` is the list of per-model chatbots.
    """
    visible_models = kwargs['visible_models']
    all_models = kwargs['all_possible_display_names']
    visible_ratings = kwargs['visible_ratings']
    reviews_file = kwargs['reviews_file'] or 'reviews.csv'
    model_states = kwargs['model_states']

    text_outputs = []
    chat_kwargs = []
    min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
    # Any non-empty value of DEBUG_MODEL_LOCK (even '0') enables debug labels.
    debug_labels = bool(os.environ.get('DEBUG_MODEL_LOCK', 0))

    # Collect constructor arguments first; the components themselves are created
    # later, inside the row layout they belong to.
    for model_state_locki, model_state_lock in enumerate(model_states):
        llama_path = model_state_lock['llamacpp_dict']["model_path_llama"]
        output_label = get_chatbot_name(model_state_lock["base_model"],
                                        model_state_lock["display_name"],
                                        llama_path,
                                        model_state_lock["inference_server"],
                                        model_state_lock["prompt_type"],
                                        model_label_prefix=kwargs['model_label_prefix'],
                                        debug=debug_labels)
        if kwargs['avatars']:
            avatar_images = get_avatars(model_state_lock["base_model"],
                                        llama_path,
                                        model_state_lock["inference_server"])
        else:
            avatar_images = None
        chat_kwargs.append(dict(render_markdown=kwargs.get('render_markdown', True),
                                label=output_label,
                                show_label=kwargs.get('visible_chatbot_label', True),
                                elem_classes='chatsmall',
                                height=kwargs['height'] or 400,
                                min_width=min_width,
                                avatar_images=avatar_images,
                                likeable=True,
                                latex_delimiters=[],
                                show_copy_button=kwargs['show_copy_button'],
                                # visible_models may hold either indices or display names
                                visible=kwargs['model_lock'] and (visible_models is None or
                                                                  model_state_locki in visible_models or
                                                                  all_models[model_state_locki] in visible_models
                                                                  )))

    # Size the grid from the initially visible models when requested, else from all models.
    if visible_models and kwargs['model_lock_layout_based_upon_initial_visible']:
        len_visible = len(visible_models)
    else:
        len_visible = len(model_states)
    ncols = kwargs['model_lock_columns']
    if ncols == -1:
        ncols = len_visible
    if ncols is None:
        ncols = 3

    if len(model_states) == 0:
        # Nothing to lay out; also avoids dividing by a zero column count.
        nrows = 0
    else:
        nrows = math.ceil(len_visible / ncols)

    if ncols == 0:
        # Zero columns: no per-model chatbots are created.
        pass
    elif nrows <= 1:
        with gr.Row():
            for chat_kwargs1 in chat_kwargs:
                text_outputs.append(gr.Chatbot(**chat_kwargs1))
    else:
        # Several rows: recompute the row count from all models and give each
        # row its contiguous share of the chatbots.
        len_chatbots = len(model_states)
        nrows = math.ceil(len_chatbots / ncols)
        for nrowi in range(nrows):
            with gr.Row():
                for mii, chat_kwargs1 in enumerate(chat_kwargs):
                    if mii < nrowi * len_chatbots / nrows or mii >= (1 + nrowi) * len_chatbots / nrows:
                        continue
                    text_outputs.append(gr.Chatbot(**chat_kwargs1))
    if len(model_states) > 0:
        assert len(text_outputs) == len(model_states)

    if kwargs['avatars']:
        avatar_images = get_avatars(kwargs["base_model"], kwargs['llamacpp_dict']["model_path_llama"],
                                    kwargs["inference_server"])
    else:
        avatar_images = None
    no_model_lock_chat_kwargs = dict(render_markdown=kwargs.get('render_markdown', True),
                                     show_label=kwargs.get('visible_chatbot_label', True),
                                     elem_classes='chatsmall',
                                     height=kwargs['height'] or 400,
                                     min_width=min_width,
                                     show_copy_button=kwargs['show_copy_button'],
                                     avatar_images=avatar_images,
                                     latex_delimiters=[],
                                     )
    with gr.Row():
        text_output = gr.Chatbot(label=output_label0,
                                 visible=not kwargs['model_lock'],
                                 **no_model_lock_chat_kwargs,
                                 likeable=True,
                                 )
        # The second chatbot always starts hidden.
        text_output2 = gr.Chatbot(label=output_label0_model2,
                                  visible=False,
                                  **no_model_lock_chat_kwargs,
                                  likeable=True,
                                  )

    chatbots = [text_output, text_output2] + text_outputs

    rating_fns = (ratingfn1, ratingfn2, ratingfn3, ratingfn4, ratingfn5)
    with gr.Row(visible=visible_ratings):
        review_textbox = gr.Textbox(visible=True, label="Review", placeholder="Type your review...", scale=4)
        # Hidden textbox that holds the most recently clicked star rating.
        rating_text_output = gr.Textbox(elem_id="text_output", visible=False)
        with gr.Column():
            with gr.Row():
                rating_buttons = [gr.Button(value='⭑', variant='outline-primary', scale=1,
                                            elem_id="rating%d" % star, size="sm")
                                  for star in range(1, len(rating_fns) + 1)]

            # The keyword for attaching client-side JS differs by gradio major version.
            js_kwarg = 'js' if is_gradio_version4 else '_js'
            for star, (rating_button, rating_fn) in enumerate(zip(rating_buttons, rating_fns), start=1):
                rating_button.click(rating_fn, outputs=rating_text_output,
                                    **{js_kwarg: _rating_highlight_js(star, len(rating_fns))})

            submit_review_btn = gr.Button("Submit Review", scale=1)
            submit_review_func = functools.partial(submit_review,
                                                   reviews_file=reviews_file if reviews_file else None,
                                                   num_model_lock=len(chatbots))
            # The handler returns '' into review_textbox, clearing it after submission.
            submit_review_btn.click(submit_review_func,
                                    inputs=[review_textbox, rating_text_output,
                                            text_output, text_output2] + text_outputs,
                                    outputs=review_textbox)

    def on_like(like_data: gr.LikeData):
        # Record a like/dislike as a review row: "<liked>,<chatbot label>" plus the rated message.
        submit_review(str(like_data.liked) + "," + str(like_data.target.label), *tuple([['', like_data.value], []]),
                      reviews_file=reviews_file, num_model_lock=len(chatbots), do_info=False)

    for chatbot in chatbots:
        chatbot.like(on_like)

    return text_output, text_output2, text_outputs
|
|