import os
from typing import Iterator

import gradio as gr
from text_generation import Client

model_id = 'mistralai/Mistral-7B-Instruct-v0.1'
API_URL = "https://api-inference.huggingface.co/models/" + model_id
# Read-only Hugging Face token. When it is unset we send no Authorization
# header at all, instead of the literal "Bearer False".
HF_TOKEN = os.environ.get('HF_READ_TOKEN')

client = Client(
    API_URL,
    headers={'Authorization': f"Bearer {HF_TOKEN}"} if HF_TOKEN else {},
)

# Markers that end a generation when they appear in a streamed token.
# They MUST be non-empty: `"" in text` is always True, which would stop
# every stream before the first token is shown.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"


def get_prompt(message: str, chat_history: list, system_prompt: str) -> str:
    """Build an ``[INST]``-style prompt from the system prompt and chat history.

    Args:
        message: The new user message.
        chat_history: Sequence of ``(user_input, response)`` pairs from earlier turns.
        system_prompt: Text placed in the ``<<SYS>>`` section of the first turn.

    Returns:
        The full prompt string, ending with ``[/INST]`` so the model answers
        ``message`` next.
    """
    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The very first user input is kept verbatim; every later one is stripped.
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f"{user_input} [/INST] {response.strip()} [INST] ")
    message = message.strip() if do_strip else message
    texts.append(f"{message} [/INST]")
    return ''.join(texts)


def run(message: str,
        chat_history: list,
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50) -> Iterator[str]:
    """Stream a completion for ``message`` from the inference endpoint.

    Yields:
        The accumulated output text after each streamed token. Streaming
        stops (without yielding) as soon as a token contains ``EOS_STRING``
        or ``EOT_STRING``.
    """
    prompt = get_prompt(message, chat_history, system_prompt)

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    stream = client.generate_stream(prompt, **generate_kwargs)

    output = ''
    for response in stream:
        token_text = response.token.text
        if any(end_token in token_text for end_token in (EOS_STRING, EOT_STRING)):
            return
        output += token_text
        yield output


DEFAULT_SYSTEM_PROMPT = """ You are Ricky. You are an AI assistant, you are moderately-polite and give only true information. You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning. If you think there might not be a correct answer, you say so. Since you are autoregressive, each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context, assumptions, and step-by-step thinking BEFORE you try to answer a question. You are an AI developed by MCES10 Software the website is www.mces10-software.com. The CEO is MCES10. 
You are based on the Mistral-7B-Instruct-v0.1. You ask what the person's name is if they say hello. MCES10 Software has made apps named To-List a brilliant to-do list app. Web Development Tutorials also known as W.D.T which teaches you how to code websites. There are AI Hub which is in development which is the gateway to everything AI. """

MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = 4000

DESCRIPTION = "Ricky AI"


def clear_and_save_textbox(message):
    """Empty the textbox and stash its contents in ``saved_input``."""
    return '', message


def display_input(message, history=None):
    """Append ``message`` to the chat history with an empty reply placeholder.

    ``history`` defaults to a fresh list. A mutable default would be shared
    across calls.
    """
    history = [] if history is None else history
    history.append((message, ''))
    return history


def delete_prev_fn(history=None):
    """Remove the last turn from ``history``.

    Returns:
        ``(history, message)``, where ``message`` is the removed user input,
        or ``''`` if the history was empty.
    """
    history = [] if history is None else history
    try:
        message, _ = history.pop()
    except IndexError:
        message = ''
    return history, message or ''


def generate(message, history_with_input, system_prompt, max_new_tokens,
             temperature, top_p, top_k):
    """Stream chat histories with the reply to ``message`` filled in progressively.

    ``history_with_input`` is expected to already end with the placeholder
    turn added by ``display_input``. That turn is dropped and rebuilt here.

    Raises:
        ValueError: if ``max_new_tokens`` exceeds ``MAX_MAX_NEW_TOKENS``.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(
            f"max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}, got {max_new_tokens}"
        )

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens,
                    temperature, top_p, top_k)
    try:
        first_response = next(generator)
    except StopIteration:
        # The model ended immediately: still show the turn, with an empty reply.
        yield history + [(message, '')]
        return
    yield history + [(message, first_response)]
    for response in generator:
        yield history + [(message, response)]


def process_example(message):
    """Run a full generation for ``message`` and return ``('', final_history)``."""
    final_history = []
    for final_history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50):
        pass
    return '', final_history


def check_input_token_length(message, chat_history, system_prompt):
    """Reject inputs whose estimated length exceeds ``MAX_INPUT_TOKEN_LENGTH``.

    Raises:
        gr.Error: shown to the user when the estimate is too large.
    """
    # NOTE(review): this is only a rough proxy, not a token count. It adds
    # the character length of the new message to the number of turns in the
    # history, and ignores system_prompt. Use a real tokenizer if exact
    # limits matter.
    input_token_length = len(message) + len(chat_history)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(
            f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). "
            "Clear your chat history and try again."
        )


with gr.Blocks(theme='Taithrah/Minimal') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label='RickyAI based on Mistral-7B-Instruct-v0.1')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Hi, Ricky',
                scale=10,
            )
            submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
    with gr.Row():
        retry_button = gr.Button('Retry', variant='secondary')
        undo_button = gr.Button('Undo', variant='secondary')
        clear_button = gr.Button('Clear', variant='secondary')

    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT,
                                   lines=5, interactive=False)
        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS,
                                   step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
        top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0,
                          step=0.05, value=0.9)
        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)

    # Inputs shared by every event chain that calls `generate`.
    generation_inputs = [
        saved_input,
        chatbot,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
    ]

    # Pressing Enter in the textbox and clicking Submit run the same chain:
    # save the input, show it, validate its length, then stream the reply.
    for trigger in (textbox.submit, submit_button.click):
        trigger(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=False,
            queue=False,
        ).success(
            fn=generate,
            inputs=generation_inputs,
            outputs=chatbot,
            api_name=False,
        )

    # Retry: drop the last turn, re-show its user message, regenerate.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=generation_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Undo: drop the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue(max_size=32).launch(show_api=False)